/*
 * Decompiled with CFR 0.152.
 */
package jdplus.toolkit.base.core.math.splines;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import jdplus.toolkit.base.api.data.DoubleSeq;
import jdplus.toolkit.base.api.math.matrices.Matrix;
import jdplus.toolkit.base.core.data.DataBlock;
import jdplus.toolkit.base.core.data.DataBlockIterator;
import jdplus.toolkit.base.core.math.matrices.FastMatrix;
import jdplus.toolkit.base.core.math.matrices.LowerTriangularMatrix;
import jdplus.toolkit.base.core.math.matrices.decomposition.ElementaryTransformations;
import jdplus.toolkit.base.core.math.polynomials.UnitRoots;
import jdplus.toolkit.base.core.math.splines.BSplines;
import jdplus.toolkit.base.core.stats.Combinatorics;
import jdplus.toolkit.base.core.stats.linearmodel.LeastSquaresResults;
import jdplus.toolkit.base.core.stats.linearmodel.LinearModel;
import jdplus.toolkit.base.core.stats.linearmodel.Ols;
import lombok.Generated;
import org.jspecify.annotations.NonNull;
import org.jspecify.annotations.Nullable;

public class AdaptivePeriodicSpline {
    /** Problem definition supplied at construction time. */
    private final Specification spec;
    /** Initial step (lambda = 0) computed once in the constructor; every call to process starts from it. */
    private final Step step0;
    /** Steps recorded by the most recent call to process; cleared at the start of each call. */
    private final List<Step> steps = new ArrayList<Step>();
    /** Sum of squared residuals of the initial fit divided by the number of rows of B. */
    private final double sigma2;
    /** sigma2 divided by norm2() of the UnitRoots.D(1, k) coefficients; process multiplies the caller's lambda by it. */
    private final double scaling;
    /** Matrix returned by BSplines.splines for the periodic spline built on all knots, evaluated at spec.getX(). */
    private final FastMatrix B;
    /** q x 2q work matrix (q = number of knots): LB in the left block, weighted rows of D in the right block; overwritten by every step. */
    private final FastMatrix B2;
    /** q x q matrix filled from the coefficients of UnitRoots.D(1, k), k being the spline order. */
    private final FastMatrix D;
    /** Deep copy of the leading q x q block of the triangularized transpose of B. */
    private final FastMatrix LB;
    /** Product of the columns of B with the observations y. */
    private final DataBlock By;

    /**
     * Runs one re-weighted penalized fit, starting from the weights of {@code start}.
     * <p>
     * The shared work matrix {@code B2} is refilled with {@code LB} (left block) and with the
     * rows of {@code D} scaled by {@code sqrt(lambda * w)} (right block), triangularized, and
     * its left block is then used for two triangular solves on a copy of {@code By}. The
     * weights of non-fixed knots (w != 0) are replaced by {@code 1 / (d^2 + 1e-10)}, where d is
     * the product of the matching row of {@code D} with the new coefficients, and
     * {@code z = d^2 * w}. Finally an OLS regression on the selected knots provides the sum of
     * squares used for the information criteria and the spline values at all knots.
     *
     * @param start step whose weights are used for the penalty
     * @param lambda penalty factor (already scaled by the caller)
     * @return the new step, or {@code null} when fewer knots than the spline order are selected
     */
    private Step nextStep(Step start, double lambda) {
        final double[] weights = start.getW().clone();
        final int q = weights.length;
        final double[] selection = new double[q];

        // Left block of B2 is a view: after the triangularization it holds the updated factor.
        final FastMatrix factor = this.B2.extract(0, q, 0, q);
        factor.copy(this.LB);
        for (int col = 0; col < q; ++col) {
            final double c = Math.sqrt(lambda * weights[col]);
            this.B2.column(q + col).setAY(c, this.D.row(col));
        }
        ElementaryTransformations.fastGivensTriangularize(this.B2);

        final DataBlock coefficients = this.By.deepClone();
        LowerTriangularMatrix.solveLx(factor, coefficients);
        LowerTriangularMatrix.solvexL(factor, coefficients);

        // Knots with a zero weight (the fixed ones) keep w = 0 and z = 0.
        for (int i = 0; i < q; ++i) {
            if (weights[i] != 0.0) {
                final double diff = this.D.row(i).dot(coefficients);
                final double updated = 1.0 / (diff * diff + 1.0E-10);
                weights[i] = updated;
                selection[i] = diff * diff * updated;
            }
        }

        final int order = this.spec.getSplineOrder();
        final double[] retained = Step.selectedKnots(selection, this.spec);
        if (retained.length < order) {
            return null;
        }

        // Unpenalized regression on the retained knots only.
        final BSplines.BSpline spline = BSplines.periodic(order, retained, this.spec.getPeriod());
        final FastMatrix basis = BSplines.splines(spline, this.spec.getX());
        final LinearModel model = LinearModel.builder()
                .y(this.spec.getY())
                .addX((Matrix)basis)
                .build();
        final LeastSquaresResults ols = Ols.compute(model);

        // Fitted spline evaluated at every candidate knot.
        final double[] allKnots = this.spec.getKnots();
        final DataBlock atKnots = DataBlock.make(allKnots.length);
        final FastMatrix basisAtKnots = BSplines.splines(spline, DoubleSeq.of(allKnots));
        atKnots.product(DataBlock.of(ols.getCoefficients()), basisAtKnots.rowsIterator());

        final Step.Builder builder = Step.builder()
                .lambda(lambda)
                .a(coefficients.getStorage())
                .w(weights)
                .z(selection)
                .s(atKnots.getStorage());
        return this.completeStep(ols.getErrorSumOfSquares(), this.spec.getY().length(), retained.length, builder);
    }

    /**
     * Prepares everything that is shared by the iterations and computes the initial step.
     * <p>
     * The constructor builds the basis {@code B} on all knots, its triangular factor
     * {@code LB}, the cyclic difference matrix {@code D}, the product {@code By} and the work
     * matrix {@code B2}. It then solves the unpenalized problem (lambda = 0), which yields
     * {@code sigma2}, {@code step0} and the scaling factor applied to lambda in
     * {@link #process(double)}.
     *
     * @param spec problem definition; its knots, x, y, spline order and period are read here
     */
    AdaptivePeriodicSpline(Specification spec) {
        this.spec = spec;
        int k = spec.getSplineOrder();
        double[] knots = spec.getKnots();
        int q = knots.length;
        double P = spec.getPeriod();

        BSplines.BSpline bs = BSplines.periodic(k, knots, P);
        this.B = BSplines.splines(bs, spec.getX());
        FastMatrix Bt = this.B.transpose();
        ElementaryTransformations.fastGivensTriangularize(Bt);
        this.LB = Bt.extract(0, q, 0, q).deepClone();

        // Cyclic difference matrix: row r carries coeff(k), coeff(k-1), ..., coeff(0)
        // in columns r+1, r, ..., r-k+1 (column indexes taken modulo q).
        DoubleSeq coeff = UnitRoots.D(1, k).coefficients();
        this.D = FastMatrix.square(q);
        for (int i = -1; i < k; ++i) {
            this.D.subDiagonal(i).set(coeff.get(k - 1 - i));
        }
        // Wrap-around of the super-diagonal: for the last row, column q is column 0.
        // (Cell (0, q-1) belongs to subDiagonal(1 - q), which is filled by the loop below.)
        this.D.set(q - 1, 0, coeff.get(k));
        // Wrap-around of the sub-diagonals for the first rows.
        // NOTE(review): coeff.get(i + 1) matches coeff.get(k - 1 - i) only if the coefficients
        // are symmetric, which holds for (1 - B)^k with an even k (default order 4) - assuming
        // UnitRoots.D(1, k) is that polynomial. Confirm the intended sign for odd orders.
        for (int i = 1; i < k; ++i) {
            this.D.subDiagonal(i - q).set(coeff.get(i + 1));
        }

        // Initial weights/selection: every free knot is active.
        double[] w = new double[q];
        double[] z = new double[q];
        Arrays.fill(w, 1.0);
        Arrays.fill(z, 1.0);
        int[] fixedKnots = spec.getFixedKnots();
        if (fixedKnots != null) {
            for (int fixed : fixedKnots) {
                // Fixed knots are never penalized (w = 0). Their z is 0 as in nextStep:
                // Step.selectedKnots* always appends the fixed knots, so a non-zero z would
                // count them twice in the initial step.
                w[fixed] = 0.0;
                z[fixed] = 0.0;
            }
        }

        this.By = DataBlock.make(this.B.getColumnsCount());
        DataBlock Y = DataBlock.of(spec.getY());
        this.By.addAProduct(1.0, this.B.columnsIterator(), Y);

        this.B2 = FastMatrix.make(q, 2 * q);
        this.B2.extract(0, q, 0, q).copy(this.LB);

        // Unpenalized coefficients obtained with the two triangular solves on LB.
        DataBlock A = this.By.deepClone();
        LowerTriangularMatrix.solveLx(this.LB, A);
        LowerTriangularMatrix.solvexL(this.LB, A);

        // Residuals of the initial fit.
        DataBlock e = Y.deepClone();
        e.addAProduct(-1.0, this.B.rowsIterator(), A);
        double ssq = e.ssq();
        this.sigma2 = ssq / (double)this.B.getRowsCount();

        // Initial spline evaluated at the knots.
        DataBlock s = DataBlock.make(knots.length);
        FastMatrix Bs = BSplines.splines(bs, DoubleSeq.of(knots));
        s.product(A, Bs.rowsIterator());

        Step.Builder builder = Step.builder().lambda(0.0).a(A.getStorage()).s(s.getStorage()).w(w).z(z);
        // completeStep reads sigma2, which is assigned just above.
        this.step0 = this.completeStep(ssq, Y.length(), A.length(), builder);
        this.scaling = this.sigma2 / coeff.norm2();
    }

    /**
     * Adds the log-likelihood and the information criteria to a partially filled step builder
     * and builds the step.
     * <p>
     * With {@code ll = -0.5 * ssq / sigma2}: {@code aic = 2 * (nparams - ll)},
     * {@code bic = -2 * ll + log(n) * nparams} and
     * {@code ebic = bic + 2 * logChoose(columns of B, nparams)}.
     *
     * @param ssq sum of squared residuals of the fit
     * @param n number of observations
     * @param nparams number of parameters of the fit
     * @param builder builder already holding lambda and the arrays of the step
     * @return the completed step
     */
    private Step completeStep(double ssq, int n, int nparams, Step.Builder builder) {
        final double ll = -0.5 * ssq / this.sigma2;
        final double k = (double)nparams;
        final double aic = 2.0 * (-ll + k);
        final double bic = -2.0 * ll + Math.log(n) * k;
        final double ebic = bic + 2.0 * Combinatorics.logChoose(this.B.getColumnsCount(), nparams);
        return builder
                .ll(ll)
                .aic(aic)
                .bic(bic)
                .ebic(ebic)
                .build();
    }

    /**
     * Creates a processor for the given specification. The initial (lambda = 0) fit is
     * computed immediately by the constructor.
     *
     * @param spec problem definition
     * @return a new processor; call {@link #process(double)} to run the iterations
     */
    public static AdaptivePeriodicSpline of(Specification spec) {
        return new AdaptivePeriodicSpline(spec);
    }

    /**
     * Runs the iterative re-weighted estimation for the given penalty.
     * <p>
     * Previously recorded steps are discarded. The list of steps always starts with the initial
     * step; each accepted iteration whose log-likelihood differs from the previous one is
     * appended. Iterations stop when
     * <ul>
     * <li>no further step can be computed ({@code nextStep} returns {@code null}),</li>
     * <li>fewer knots than the spline order remain selected,</li>
     * <li>the distance between two successive coefficient vectors is below the precision of the
     * specification, or</li>
     * <li>the maximum number of iterations is reached.</li>
     * </ul>
     *
     * @param lambda penalty factor; it is multiplied by the internal scaling before use
     * @return {@code true} if the iterations stopped before the maximum number of iterations
     */
    public boolean process(double lambda) {
        this.steps.clear();
        Step current = this.step0;
        this.steps.add(current);

        final double scaledLambda = lambda * this.scaling;
        final int maxIter = this.spec.getMaxIter();
        int niter = 0;
        while (niter < maxIter) {
            Step next = this.nextStep(current, scaledLambda);
            if (next == null) {
                break;
            }
            if (Step.selectedKnotsCount(next.getZ(), this.spec) < this.spec.getSplineOrder()) {
                break;
            }
            double da = DoubleSeq.of(current.getA()).distance(DoubleSeq.of(next.getA()));
            if (da < this.spec.getPrecision()) {
                break;
            }
            // Record the new step (not the one we started from, which is already in the list).
            if (current.getLl() != next.getLl()) {
                this.steps.add(next);
            }
            current = next;
            ++niter;
        }
        return niter < maxIter;
    }

    /**
     * @return the specification this processor was created with
     */
    public Specification getSpecification() {
        return this.spec;
    }

    /**
     * @return the number of recorded steps (0 before the first call to process; the initial
     * step is part of the count afterwards)
     */
    public int getiterationCount() {
        return this.steps.size();
    }

    /**
     * @param i position of the step in the list of recorded steps (0 is the initial step)
     * @return the recorded step at that position
     * @throws IndexOutOfBoundsException if {@code i} is not a valid position
     */
    public Step step(int i) {
        return this.steps.get(i);
    }

    /**
     * @return a read-only view of the recorded steps; the view reflects later calls to process
     */
    public List<Step> allSteps() {
        return Collections.unmodifiableList(this.steps);
    }

    /**
     * @return the last recorded step
     * @throws IndexOutOfBoundsException if no step has been recorded yet (process not called)
     */
    public Step result() {
        return this.steps.get(this.steps.size() - 1);
    }

    /**
     * Number of knots selected by a recorded step.
     *
     * @param pos position of the step in the list of recorded steps
     * @return the count derived from the z values of the step, or the total number of knots
     * when the step has no z values
     */
    public int selectedKnotsCount(int pos) {
        final double[] selection = this.steps.get(pos).getZ();
        return selection == null
                ? this.spec.getKnots().length
                : Step.selectedKnotsCount(selection, this.spec);
    }

    /**
     * Sorted indexes (in the knots of the specification) of the knots selected by a recorded step.
     *
     * @param pos position of the step in the list of recorded steps
     * @return the selected indexes, or {@code null} when the step has no z values
     */
    public int[] selectedKnotsPosition(int pos) {
        final double[] selection = this.steps.get(pos).getZ();
        if (selection != null) {
            return Step.selectedKnotsPosition(selection, this.spec);
        }
        return null;
    }

    /**
     * Values of the knots selected by a recorded step.
     *
     * @param pos position of the step in the list of recorded steps
     * @return the selected knots, or {@code null} when the step has no z values
     */
    public double[] selectedKnots(int pos) {
        final double[] selection = this.steps.get(pos).getZ();
        if (selection != null) {
            return Step.selectedKnots(selection, this.spec);
        }
        return null;
    }

    /**
     * Collects the {@code a} arrays of the recorded steps.
     *
     * @return a matrix with one row per recorded step and one column per knot
     */
    public FastMatrix A() {
        final int nsteps = this.steps.size();
        final FastMatrix all = FastMatrix.make(nsteps, this.spec.getKnots().length);
        for (int i = 0; i < nsteps; ++i) {
            all.row(i).copy(DataBlock.of(this.steps.get(i).getA()));
        }
        return all;
    }

    /**
     * Collects the {@code z} arrays of the recorded steps.
     *
     * @return a matrix with one row per recorded step and one column per knot
     */
    public FastMatrix Z() {
        final int nsteps = this.steps.size();
        final FastMatrix all = FastMatrix.make(nsteps, this.spec.getKnots().length);
        for (int i = 0; i < nsteps; ++i) {
            all.row(i).copy(DataBlock.of(this.steps.get(i).getZ()));
        }
        return all;
    }

    /**
     * Collects the {@code w} arrays of the recorded steps.
     *
     * @return a matrix with one row per recorded step and one column per knot
     */
    public FastMatrix W() {
        final int nsteps = this.steps.size();
        final FastMatrix all = FastMatrix.make(nsteps, this.spec.getKnots().length);
        for (int i = 0; i < nsteps; ++i) {
            all.row(i).copy(DataBlock.of(this.steps.get(i).getW()));
        }
        return all;
    }

    /**
     * Collects the {@code s} arrays of the recorded steps.
     *
     * @return a matrix with one row per recorded step and one column per knot
     */
    public FastMatrix S() {
        final int nsteps = this.steps.size();
        final FastMatrix all = FastMatrix.make(nsteps, this.spec.getKnots().length);
        for (int i = 0; i < nsteps; ++i) {
            all.row(i).copy(DataBlock.of(this.steps.get(i).getS()));
        }
        return all;
    }

    /**
     * Result of one iteration: the penalty, the arrays produced by the fit and the
     * log-likelihood with its information criteria.
     * <p>
     * The arrays are stored and returned without copying.
     */
    public static final class Step {
        /** Penalty used for the step. */
        private final double lambda;
        /** Coefficients obtained by the fit. */
        private final double[] a;
        /** Weights, one per knot (0 for fixed knots). */
        private final double[] w;
        /** Selection values, one per knot; compared with the selection threshold. */
        private final double[] z;
        /** Spline values at the knots of the specification. */
        private final double[] s;
        private final double aic;
        private final double bic;
        private final double ebic;
        /** Log-likelihood. */
        private final double ll;

        /**
         * Counts the selected knots: all fixed knots plus every knot whose z value is at least
         * the selection threshold.
         *
         * @param z selection values
         * @param spec specification providing the fixed knots and the threshold
         * @return the number of selected knots
         */
        static int selectedKnotsCount(double[] z, Specification spec) {
            final int[] fixed = spec.getFixedKnots();
            int count = 0;
            if (fixed != null) {
                count = fixed.length;
            }
            for (double value : z) {
                if (value >= spec.getSelectionThreshold()) {
                    ++count;
                }
            }
            return count;
        }

        /**
         * Indexes of the selected knots (see {@link #selectedKnotsCount(double[], Specification)}),
         * in increasing order.
         *
         * @param z selection values
         * @param spec specification providing the fixed knots and the threshold
         * @return the sorted indexes of the selected knots
         */
        static int[] selectedKnotsPosition(double[] z, Specification spec) {
            final int[] positions = new int[Step.selectedKnotsCount(z, spec)];
            int cursor = 0;
            for (int idx = 0; idx < z.length; ++idx) {
                if (z[idx] >= spec.getSelectionThreshold()) {
                    positions[cursor++] = idx;
                }
            }
            final int[] fixed = spec.getFixedKnots();
            if (fixed != null) {
                for (int f : fixed) {
                    positions[cursor++] = f;
                }
            }
            Arrays.sort(positions);
            return positions;
        }

        /**
         * Values of the selected knots, in the order of their indexes.
         *
         * @param z selection values
         * @param spec specification providing the knots, the fixed knots and the threshold
         * @return the selected knots
         */
        static double[] selectedKnots(double[] z, Specification spec) {
            final int[] positions = Step.selectedKnotsPosition(z, spec);
            if (positions == null) {
                return null;
            }
            final double[] allKnots = spec.getKnots();
            final double[] selected = new double[positions.length];
            int out = 0;
            for (int position : positions) {
                selected[out++] = allKnots[position];
            }
            return selected;
        }

        @Generated
        Step(double lambda, double[] a, double[] w, double[] z, double[] s, double aic, double bic, double ebic, double ll) {
            this.lambda = lambda;
            this.a = a;
            this.w = w;
            this.z = z;
            this.s = s;
            this.aic = aic;
            this.bic = bic;
            this.ebic = ebic;
            this.ll = ll;
        }

        /** @return a new, empty builder */
        @Generated
        public static @NonNull Builder builder() {
            return new Builder();
        }

        @Generated
        public double getLambda() {
            return lambda;
        }

        @Generated
        public double[] getA() {
            return a;
        }

        @Generated
        public double[] getW() {
            return w;
        }

        @Generated
        public double[] getZ() {
            return z;
        }

        @Generated
        public double[] getS() {
            return s;
        }

        @Generated
        public double getAic() {
            return aic;
        }

        @Generated
        public double getBic() {
            return bic;
        }

        @Generated
        public double getEbic() {
            return ebic;
        }

        @Generated
        public double getLl() {
            return ll;
        }

        /** Two steps are equal when all their scalar values and array contents are equal. */
        @Generated
        public boolean equals(@Nullable Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Step)) {
                return false;
            }
            final Step that = (Step)o;
            return Double.compare(lambda, that.lambda) == 0
                    && Double.compare(aic, that.aic) == 0
                    && Double.compare(bic, that.bic) == 0
                    && Double.compare(ebic, that.ebic) == 0
                    && Double.compare(ll, that.ll) == 0
                    && Arrays.equals(a, that.a)
                    && Arrays.equals(w, that.w)
                    && Arrays.equals(z, that.z)
                    && Arrays.equals(s, that.s);
        }

        @Generated
        public int hashCode() {
            int h = 1;
            h = 59 * h + Double.hashCode(lambda);
            h = 59 * h + Double.hashCode(aic);
            h = 59 * h + Double.hashCode(bic);
            h = 59 * h + Double.hashCode(ebic);
            h = 59 * h + Double.hashCode(ll);
            h = 59 * h + Arrays.hashCode(a);
            h = 59 * h + Arrays.hashCode(w);
            h = 59 * h + Arrays.hashCode(z);
            h = 59 * h + Arrays.hashCode(s);
            return h;
        }

        @Generated
        public @NonNull String toString() {
            return "AdaptivePeriodicSpline.Step(lambda=" + lambda
                    + ", a=" + Arrays.toString(a)
                    + ", w=" + Arrays.toString(w)
                    + ", z=" + Arrays.toString(z)
                    + ", s=" + Arrays.toString(s)
                    + ", aic=" + aic
                    + ", bic=" + bic
                    + ", ebic=" + ebic
                    + ", ll=" + ll
                    + ")";
        }

        /** Fluent builder of {@link Step}; unset values keep their Java defaults (0 / null). */
        @Generated
        public static class Builder {
            @Generated
            private double lambda;
            @Generated
            private double[] a;
            @Generated
            private double[] w;
            @Generated
            private double[] z;
            @Generated
            private double[] s;
            @Generated
            private double aic;
            @Generated
            private double bic;
            @Generated
            private double ebic;
            @Generated
            private double ll;

            @Generated
            Builder() {
            }

            @Generated
            public @NonNull Builder lambda(double lambda) {
                this.lambda = lambda;
                return this;
            }

            @Generated
            public @NonNull Builder a(double[] a) {
                this.a = a;
                return this;
            }

            @Generated
            public @NonNull Builder w(double[] w) {
                this.w = w;
                return this;
            }

            @Generated
            public @NonNull Builder z(double[] z) {
                this.z = z;
                return this;
            }

            @Generated
            public @NonNull Builder s(double[] s) {
                this.s = s;
                return this;
            }

            @Generated
            public @NonNull Builder aic(double aic) {
                this.aic = aic;
                return this;
            }

            @Generated
            public @NonNull Builder bic(double bic) {
                this.bic = bic;
                return this;
            }

            @Generated
            public @NonNull Builder ebic(double ebic) {
                this.ebic = ebic;
                return this;
            }

            @Generated
            public @NonNull Builder ll(double ll) {
                this.ll = ll;
                return this;
            }

            /** @return a step holding the current values of the builder */
            @Generated
            public @NonNull Step build() {
                return new Step(lambda, a, w, z, s, aic, bic, ebic, ll);
            }

            @Generated
            public @NonNull String toString() {
                return "AdaptivePeriodicSpline.Step.Builder(lambda=" + lambda
                        + ", a=" + Arrays.toString(a)
                        + ", w=" + Arrays.toString(w)
                        + ", z=" + Arrays.toString(z)
                        + ", s=" + Arrays.toString(s)
                        + ", aic=" + aic
                        + ", bic=" + bic
                        + ", ebic=" + ebic
                        + ", ll=" + ll
                        + ")";
            }
        }
    }

    /**
     * Definition of an adaptive periodic spline problem: data, candidate knots and tuning
     * parameters.
     * <p>
     * {@link #builder()} presets the spline order (4), the precision (1e-6), the selection
     * threshold (0.99) and the maximum number of iterations (20). Arrays are stored and
     * returned without copying.
     */
    public static final class Specification {
        /** Abscissas at which the spline basis is evaluated. */
        private final DoubleSeq x;
        /** Observations. */
        private final DoubleSeq y;
        private final int splineOrder;
        private final double period;
        /** Candidate knots. */
        private final double[] knots;
        /** Indexes (in {@code knots}) of the knots that are always kept; may be null. */
        private final int[] fixedKnots;
        /** Iterations stop when successive coefficient vectors are closer than this value. */
        private final double precision;
        /** Minimum z value for a knot to be considered as selected. */
        private final double selectionThreshold;
        private final int maxIter;

        /**
         * @return a builder preset with spline order 4, precision 1e-6, selection threshold
         * 0.99 and 20 iterations at most
         */
        public static Builder builder() {
            final Builder preset = new Builder();
            preset.splineOrder(4);
            preset.precision(1.0E-6);
            preset.selectionThreshold(0.99);
            preset.maxIter(20);
            return preset;
        }

        /** @return the number of fixed knots, 0 when none were specified */
        public int getFixedKnotsCount() {
            if (fixedKnots == null) {
                return 0;
            }
            return fixedKnots.length;
        }

        @Generated
        Specification(DoubleSeq x, DoubleSeq y, int splineOrder, double period, double[] knots, int[] fixedKnots, double precision, double selectionThreshold, int maxIter) {
            this.x = x;
            this.y = y;
            this.splineOrder = splineOrder;
            this.period = period;
            this.knots = knots;
            this.fixedKnots = fixedKnots;
            this.precision = precision;
            this.selectionThreshold = selectionThreshold;
            this.maxIter = maxIter;
        }

        @Generated
        public DoubleSeq getX() {
            return x;
        }

        @Generated
        public DoubleSeq getY() {
            return y;
        }

        @Generated
        public int getSplineOrder() {
            return splineOrder;
        }

        @Generated
        public double getPeriod() {
            return period;
        }

        @Generated
        public double[] getKnots() {
            return knots;
        }

        @Generated
        public int[] getFixedKnots() {
            return fixedKnots;
        }

        @Generated
        public double getPrecision() {
            return precision;
        }

        @Generated
        public double getSelectionThreshold() {
            return selectionThreshold;
        }

        @Generated
        public int getMaxIter() {
            return maxIter;
        }

        /** Two specifications are equal when all their values and array contents are equal. */
        @Generated
        public boolean equals(@Nullable Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Specification)) {
                return false;
            }
            final Specification that = (Specification)o;
            if (splineOrder != that.splineOrder
                    || Double.compare(period, that.period) != 0
                    || Double.compare(precision, that.precision) != 0
                    || Double.compare(selectionThreshold, that.selectionThreshold) != 0
                    || maxIter != that.maxIter) {
                return false;
            }
            final boolean sameX = x == null ? that.x == null : x.equals(that.x);
            if (!sameX) {
                return false;
            }
            final boolean sameY = y == null ? that.y == null : y.equals(that.y);
            if (!sameY) {
                return false;
            }
            return Arrays.equals(knots, that.knots)
                    && Arrays.equals(fixedKnots, that.fixedKnots);
        }

        @Generated
        public int hashCode() {
            int h = 1;
            h = 59 * h + splineOrder;
            h = 59 * h + Double.hashCode(period);
            h = 59 * h + Double.hashCode(precision);
            h = 59 * h + Double.hashCode(selectionThreshold);
            h = 59 * h + maxIter;
            h = 59 * h + (x == null ? 43 : x.hashCode());
            h = 59 * h + (y == null ? 43 : y.hashCode());
            h = 59 * h + Arrays.hashCode(knots);
            h = 59 * h + Arrays.hashCode(fixedKnots);
            return h;
        }

        @Generated
        public @NonNull String toString() {
            return "AdaptivePeriodicSpline.Specification(x=" + String.valueOf(x)
                    + ", y=" + String.valueOf(y)
                    + ", splineOrder=" + splineOrder
                    + ", period=" + period
                    + ", knots=" + Arrays.toString(knots)
                    + ", fixedKnots=" + Arrays.toString(fixedKnots)
                    + ", precision=" + precision
                    + ", selectionThreshold=" + selectionThreshold
                    + ", maxIter=" + maxIter
                    + ")";
        }

        /** Fluent builder of {@link Specification}; unset values keep their Java defaults (0 / null). */
        @Generated
        public static class Builder {
            @Generated
            private DoubleSeq x;
            @Generated
            private DoubleSeq y;
            @Generated
            private int splineOrder;
            @Generated
            private double period;
            @Generated
            private double[] knots;
            @Generated
            private int[] fixedKnots;
            @Generated
            private double precision;
            @Generated
            private double selectionThreshold;
            @Generated
            private int maxIter;

            @Generated
            Builder() {
            }

            @Generated
            public @NonNull Builder x(DoubleSeq x) {
                this.x = x;
                return this;
            }

            @Generated
            public @NonNull Builder y(DoubleSeq y) {
                this.y = y;
                return this;
            }

            @Generated
            public @NonNull Builder splineOrder(int splineOrder) {
                this.splineOrder = splineOrder;
                return this;
            }

            @Generated
            public @NonNull Builder period(double period) {
                this.period = period;
                return this;
            }

            @Generated
            public @NonNull Builder knots(double[] knots) {
                this.knots = knots;
                return this;
            }

            @Generated
            public @NonNull Builder fixedKnots(int[] fixedKnots) {
                this.fixedKnots = fixedKnots;
                return this;
            }

            @Generated
            public @NonNull Builder precision(double precision) {
                this.precision = precision;
                return this;
            }

            @Generated
            public @NonNull Builder selectionThreshold(double selectionThreshold) {
                this.selectionThreshold = selectionThreshold;
                return this;
            }

            @Generated
            public @NonNull Builder maxIter(int maxIter) {
                this.maxIter = maxIter;
                return this;
            }

            /** @return a specification holding the current values of the builder */
            @Generated
            public @NonNull Specification build() {
                return new Specification(x, y, splineOrder, period, knots, fixedKnots, precision, selectionThreshold, maxIter);
            }

            @Generated
            public @NonNull String toString() {
                return "AdaptivePeriodicSpline.Specification.Builder(x=" + String.valueOf(x)
                        + ", y=" + String.valueOf(y)
                        + ", splineOrder=" + splineOrder
                        + ", period=" + period
                        + ", knots=" + Arrays.toString(knots)
                        + ", fixedKnots=" + Arrays.toString(fixedKnots)
                        + ", precision=" + precision
                        + ", selectionThreshold=" + selectionThreshold
                        + ", maxIter=" + maxIter
                        + ")";
            }
        }
    }
}

