/*
 * Decompiled with CFR 0.152.
 */
package jdplus.toolkit.base.core.stats.linearmodel;

import jdplus.toolkit.base.api.data.DoubleSeq;
import jdplus.toolkit.base.api.math.matrices.Matrix;
import jdplus.toolkit.base.api.stats.StatisticalTest;
import jdplus.toolkit.base.api.stats.TestType;
import jdplus.toolkit.base.core.data.DataBlock;
import jdplus.toolkit.base.core.dstats.F;
import jdplus.toolkit.base.core.math.matrices.FastMatrix;
import jdplus.toolkit.base.core.math.matrices.LowerTriangularMatrix;
import jdplus.toolkit.base.core.math.matrices.MatrixFactory;
import jdplus.toolkit.base.core.math.matrices.SymmetricMatrix;
import jdplus.toolkit.base.core.stats.likelihood.ConcentratedLikelihoodWithMissing;
import jdplus.toolkit.base.core.stats.tests.TestsUtility;
import lombok.NonNull;

/**
 * Joint F-test on (a subset of, or linear combinations of) the coefficients of a
 * linear model.
 *
 * <p>Given coefficients {@code b}, their unscaled covariance {@code V}, the residual
 * sum of squares {@code rss} and the number of observations {@code n}, the test
 * statistic built by {@link #build()} is
 * <pre>
 *   f = [ (Rb - alpha)' (R V R')^-1 (Rb - alpha) / nx ] / [ rss / df ]
 * </pre>
 * where {@code nx} is the number of tested restrictions and {@code df} is given by
 * {@link #ml()} / {@link #blue()}. The p-value is taken from the upper tail of an
 * F(nx, df) distribution.
 *
 * <p>The hypothesis to test is configured with exactly one of the following
 * (the last call wins):
 * <ul>
 *   <li>{@link #variableSelection(int[])} / {@link #variableSelection(int, int)}:
 *       tests that the selected coefficients are all zero;</li>
 *   <li>{@link #constraints(FastMatrix, DoubleSeq)}: tests {@code R b = alpha};</li>
 *   <li>nothing: tests that all coefficients are zero.</li>
 * </ul>
 *
 * <p>Instances are mutable builders and are not synchronized.
 */
public class JointTest {
    /** Estimated coefficients. */
    private final DoubleSeq b;
    /** Unscaled covariance of the coefficients (it is scaled by rss/df in the statistic). */
    private final FastMatrix bvar;
    /** Residual sum of squares. */
    private final double rss;
    /** Number of observations. */
    private final int n;
    /** Number of hyper-parameters, subtracted from the degrees of freedom in BLUE mode. */
    private int hyperParameters;
    /** Restriction matrix; null unless {@link #constraints} was the last selection call. */
    private FastMatrix R;
    /** Right-hand side of the restrictions; non-null exactly when {@code R} is non-null. */
    private DoubleSeq alpha;
    /** Indexes of the tested coefficients; null means "no explicit selection". */
    private int[] coef;
    /** True for BLUE degrees of freedom, false for ML degrees of freedom. */
    private boolean blue = true;

    /**
     * Creates a test from explicit estimation results.
     *
     * @param coefficients     estimated coefficients
     * @param unscaledVariance unscaled covariance matrix of the coefficients
     * @param rss              residual sum of squares
     * @param n                number of observations
     */
    public JointTest(DoubleSeq coefficients, FastMatrix unscaledVariance, double rss, int n) {
        this.b = coefficients;
        this.bvar = unscaledVariance;
        this.rss = rss;
        this.n = n;
    }

    /**
     * Creates a test from a concentrated likelihood: its coefficients, unscaled
     * covariance, sum of squares ({@code ssq()}) and dimension ({@code dim()}).
     *
     * @param ll the concentrated likelihood
     */
    public JointTest(ConcentratedLikelihoodWithMissing ll) {
        this.b = ll.coefficients();
        this.bvar = FastMatrix.of((Matrix)ll.unscaledCovariance());
        this.rss = ll.ssq();
        this.n = ll.dim();
    }

    /**
     * Tests that the coefficients at the given indexes are jointly zero.
     * Any previously defined constraints are discarded.
     *
     * @param variableSelection indexes of the coefficients to test; {@code null}
     *                          means that all coefficients are tested
     * @return this object
     */
    public JointTest variableSelection(int[] variableSelection) {
        this.coef = variableSelection;
        this.R = null;
        this.alpha = null;
        return this;
    }

    /**
     * Tests that the {@code n} consecutive coefficients starting at {@code start}
     * are jointly zero. Any previously defined selection or constraints are
     * discarded.
     *
     * @param start index of the first tested coefficient
     * @param n     number of tested coefficients
     * @return this object
     */
    public JointTest variableSelection(int start, int n) {
        if (start == 0 && n == this.b.length()) {
            // The range covers every coefficient, which is the default hypothesis.
            // Clear any earlier selection/constraints, as the other configuration
            // methods do; otherwise they would silently remain in force.
            this.coef = null;
            this.R = null;
            this.alpha = null;
            return this;
        }
        this.coef = new int[n];
        for (int i = 0; i < n; ++i) {
            this.coef[i] = start + i;
        }
        this.R = null;
        this.alpha = null;
        return this;
    }

    /**
     * Tests the linear restrictions {@code R b = alpha}. Any previously defined
     * variable selection is discarded.
     *
     * @param R     restriction matrix (one row per restriction, one column per coefficient)
     * @param alpha right-hand side of the restrictions (one value per row of {@code R})
     * @return this object
     * @throws NullPointerException     if {@code R} or {@code alpha} is null
     * @throws IllegalArgumentException if the dimensions of {@code R} and {@code alpha}
     *                                  do not match each other or the coefficients
     */
    public JointTest constraints(@NonNull FastMatrix R, @NonNull DoubleSeq alpha) {
        if (R == null) {
            throw new NullPointerException("R is marked non-null but is null");
        }
        if (alpha == null) {
            throw new NullPointerException("alpha is marked non-null but is null");
        }
        if (R.getRowsCount() != alpha.length()) {
            throw new IllegalArgumentException("R has " + R.getRowsCount()
                    + " rows but alpha has " + alpha.length() + " elements");
        }
        // Validate early: a mismatch would otherwise only fail later, inside build().
        if (R.getColumnsCount() != this.b.length()) {
            throw new IllegalArgumentException("R has " + R.getColumnsCount()
                    + " columns but there are " + this.b.length() + " coefficients");
        }
        this.R = R;
        this.alpha = alpha;
        this.coef = null;
        return this;
    }

    /**
     * Uses {@code n} as the degrees of freedom of the residual variance.
     *
     * @return this object
     */
    public JointTest ml() {
        this.blue = false;
        return this;
    }

    /**
     * Uses {@code n - (number of coefficients) - (hyper-parameters)} as the degrees
     * of freedom of the residual variance (the default).
     *
     * @return this object
     */
    public JointTest blue() {
        this.blue = true;
        return this;
    }

    /**
     * Sets the number of hyper-parameters, which reduces the degrees of freedom in
     * BLUE mode.
     *
     * @param nhp number of hyper-parameters
     * @return this object
     */
    public JointTest hyperParametersCount(int nhp) {
        this.hyperParameters = nhp;
        return this;
    }

    /**
     * Currently has no effect; the argument is ignored.
     *
     * @param det ignored
     * @return this object
     */
    public JointTest deterministicRegressors(boolean det) {
        return this;
    }

    /**
     * Computes the F statistic for the configured hypothesis and wraps it in an
     * upper-tail test against F(nx, df).
     *
     * @return the statistical test
     */
    public StatisticalTest build() {
        DataBlock rb = this.rb();
        FastMatrix rwr = this.rwr();
        int nx = rb.length();
        int df = this.df();
        // With R V R' = L L', solving L y = (Rb - alpha) gives
        // y'y = (Rb - alpha)' (R V R')^-1 (Rb - alpha). rwr and rb are
        // presumably overwritten in place; both are private copies (see rwr/rb).
        SymmetricMatrix.lcholesky(rwr);
        LowerTriangularMatrix.solveLx(rwr, rb);
        double f = rb.ssq() / (double)nx / (this.rss / (double)df);
        F fdist = new F(nx, df);
        return TestsUtility.testOf(f, fdist, TestType.Upper);
    }

    /**
     * Returns a fresh matrix R V R' for the configured hypothesis. It is a copy
     * because build() factorizes it in place and bvar must not be altered.
     */
    private FastMatrix rwr() {
        if (this.coef != null) {
            return MatrixFactory.select(this.bvar, this.coef, this.coef);
        }
        if (this.R != null) {
            return SymmetricMatrix.XSXt(this.bvar, this.R);
        }
        return this.bvar.deepClone();
    }

    /** Returns a fresh vector Rb - alpha (or the selected coefficients) for the configured hypothesis. */
    private DataBlock rb() {
        double[] rb;
        if (this.coef != null) {
            rb = new double[this.coef.length];
            for (int i = 0; i < rb.length; ++i) {
                rb[i] = this.b.get(this.coef[i]);
            }
        } else if (this.R != null) {
            rb = new double[this.R.getRowsCount()];
            for (int i = 0; i < rb.length; ++i) {
                rb[i] = this.R.row(i).dot(this.b) - this.alpha.get(i);
            }
        } else {
            rb = this.b.toArray();
        }
        return DataBlock.of(rb);
    }

    /** Degrees of freedom of the residual variance estimate rss / df. */
    private int df() {
        if (this.blue) {
            // All coefficients are estimated, not only the tested ones.
            return this.n - this.bvar.getRowsCount() - this.hyperParameters;
        }
        return this.n;
    }
}

