package com.xiaomi.ai.nlp.optimization;

import com.xiaomi.ai.nlp.loss.DiffFunction;
import com.xiaomi.ai.nlp.utils.MLMath;
import java.util.Arrays;

/* loaded from: classes4.dex */
/**
 * Backtracking line search with a sufficient-decrease (Armijo-style) acceptance test.
 *
 * <p>Starting from a unit step along the search direction, each rejected trial shrinks the step
 * by {@code rho}, for at most {@code maxIterations} trials. Optionally, each trial point is
 * projected so that no coordinate leaves the sign orthant chosen from the start point (or, for
 * zero coordinates, from the negated gradient). This resembles the projection step of OWL-QN.
 *
 * <p>Instances reuse a single internal result buffer, so one instance must not be used by
 * several threads concurrently.
 */
public class LineSearch {
    /** Step shrink factor used by {@link #LineSearch(int)}. */
    private static final double DEFAULT_RHO = 0.1d;

    /** Sufficient-decrease constant used by {@link #LineSearch(int)}. */
    private static final double DEFAULT_C = 0.01d;

    /** Maximum number of trial steps per {@link #nextX} call when not specified explicitly. */
    private static final int DEFAULT_MAX_ITERATIONS = 10;

    private final double c;
    private final int maxIterations;
    /** Reusable result buffer; returned directly by {@link #nextX}. */
    private final double[] nextX;
    private final double rho;

    /**
     * Creates a line search with {@code rho = 0.1}, {@code c = 0.01} and at most 10 trials.
     *
     * @param dimension length of the internal result buffer (number of coordinates)
     */
    public LineSearch(int dimension) {
        this(dimension, DEFAULT_RHO, DEFAULT_C);
    }

    /**
     * Creates a line search with at most 10 trial steps per call.
     *
     * @param dimension length of the internal result buffer (number of coordinates)
     * @param rho factor the step is multiplied by after each rejected trial
     * @param c sufficient-decrease constant applied to the directional derivative
     */
    public LineSearch(int dimension, double rho, double c) {
        this(dimension, rho, c, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a line search with an explicit cap on the number of trial steps.
     *
     * @param dimension length of the internal result buffer (number of coordinates)
     * @param rho factor the step is multiplied by after each rejected trial
     * @param c sufficient-decrease constant applied to the directional derivative
     * @param maxIterations maximum number of trial steps per {@link #nextX} call; must be >= 1
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     */
    public LineSearch(int dimension, double rho, double c, int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("maxIterations must be >= 1, got " + maxIterations);
        }
        this.rho = rho;
        this.c = c;
        this.maxIterations = maxIterations;
        this.nextX = new double[dimension];
    }

    /**
     * Zeroes every coordinate of {@code candidate} whose sign differs from the orthant chosen for
     * it. The orthant sign is {@code sign(x[i])} when {@code x[i]} is non-zero, otherwise
     * {@code sign(-gradient[i])}.
     */
    private void constrainNextX(double[] x, double[] gradient, double[] candidate) {
        for (int i = 0; i < candidate.length; i++) {
            // Written as "> 0 || < 0" rather than "!= 0" so that a NaN x[i] falls back to the
            // gradient sign.
            boolean nonZero = x[i] > 0.0d || x[i] < 0.0d;
            double orthant = nonZero ? Math.signum(x[i]) : Math.signum(-gradient[i]);
            if (Math.signum(candidate[i]) != orthant) {
                candidate[i] = 0.0d;
            }
        }
    }

    /** @return the sufficient-decrease constant */
    public double getC() {
        return this.c;
    }

    /** @return the maximum number of trial steps per {@link #nextX} call */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** @return the step shrink factor */
    public double getRho() {
        return this.rho;
    }

    /**
     * Searches along {@code direction} from {@code x} for a point with sufficient decrease.
     *
     * <p>Trial steps are 1, rho, rho^2, ... up to {@code maxIterations} trials. A trial is
     * accepted when {@code f(candidate) <= f(x) + step * c * <grad f(x), direction>}. If no trial
     * is accepted, the last (smallest-step) trial point is returned.
     *
     * <p>NOTE(review): when the orthant projection alters the candidate, the acceptance test
     * still uses {@code step * <grad, direction>}, not {@code <grad, candidate - x>}.
     *
     * @param x start point; not modified
     * @param direction search direction; not modified
     * @param function objective providing value and derivative
     * @param constrainToOrthant whether to project each trial point onto the start orthant
     * @return this instance's internal buffer holding the chosen point. It is overwritten by the
     *     next call, so copy it if it must be retained. Passing it back as {@code x} is safe.
     */
    public double[] nextX(double[] x, double[] direction, DiffFunction function,
            boolean constrainToOrthant) {
        // The result buffer is cleared and rewritten below. If the caller hands back the array
        // returned by a previous call, copy it first so the input is not destroyed mid-search.
        double[] start = (x == this.nextX) ? x.clone() : x;
        double[] dir = (direction == this.nextX) ? direction.clone() : direction;

        Arrays.fill(this.nextX, 0.0d);
        double startValue = function.valueAt(start);
        double[] gradient = function.derivativeAt(start);
        // c * <gradient, direction>; scaled by the step size in each acceptance test.
        double slope = this.c * MLMath.dotProd(gradient, dir);

        double step = 1.0d;
        for (int trial = 0; trial < this.maxIterations; trial++) {
            // Presumably writes 1.0 * start + step * dir into nextX -- confirm against MLMath.
            MLMath.plusTo(start, 1.0d, dir, step, this.nextX);
            if (constrainToOrthant) {
                constrainNextX(start, gradient, this.nextX);
            }
            // Double.compare (not <=) is kept from the original: a NaN trial value compares
            // greater than any threshold and is therefore rejected.
            if (Double.compare(function.valueAt(this.nextX), (step * slope) + startValue) <= 0) {
                break;
            }
            step *= this.rho;
        }
        return this.nextX;
    }
}
