/*
 * Decompiled with CFR 0.152.
 */
package dr.util;

import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.util.LKJCholeskyTransformConstrained;

/**
 * LKJ transform between the strictly upper-triangular entries of a correlation matrix and their
 * canonical partial correlations (CPCs).
 *
 * <p>{@link #transform(double[])} treats its input as the off-diagonal entries of a correlation
 * matrix, Cholesky-decomposes it and delegates the resulting factor to
 * {@link LKJCholeskyTransformConstrained}. {@link #inverse(double[])} rebuilds the triangular
 * factor from CPCs via the superclass and returns the upper triangle of its transposed product.
 *
 * <p>The recursion-based helpers address vectors through {@link #pos(int, int)}, i.e. the strictly
 * upper triangle stored row by row. The matrix-based paths presumably use the same layout via
 * {@code SymmetricMatrix} helpers; TODO confirm.
 */
public class LKJTransformConstrained extends LKJCholeskyTransformConstrained {

    /** Compile-time switch for diagnostic printing and positive-definiteness checks. */
    private static final boolean DEBUG = false;

    /**
     * Creates the transform.
     *
     * @param dimVector dimension of the correlation matrix; presumably stored by the superclass
     *                  as {@code dimVector}
     */
    public LKJTransformConstrained(int dimVector) {
        super(dimVector);
    }

    /**
     * Maps CPCs to the strictly upper-triangular entries of the corresponding correlation matrix.
     *
     * @param cpcs canonical partial correlations
     * @return strictly upper-triangular entries of the correlation matrix
     */
    @Override
    protected double[] inverse(double[] cpcs) {
        // super.inverse is the counterpart of super.transform (which consumes the strictly
        // upper-triangular Cholesky factor), so fill in the diagonal to get the full factor.
        WrappedMatrix.WrappedUpperTriangularMatrix factor =
                WrappedMatrix.WrappedUpperTriangularMatrix.fillDiagonal(super.inverse(cpcs), dimVector);
        SymmetricMatrix correlation = factor.transposedProduct();
        if (DEBUG) {
            System.err.println("Z: " + SymmetricMatrix.compoundCorrelationSymmetricMatrix(cpcs, dimVector));
            System.err.println("R: " + correlation);
            checkPositiveDefinite(correlation);
        }
        return SymmetricMatrix.extractUpperTriangular(correlation);
    }

    /**
     * Maps the strictly upper-triangular entries of a correlation matrix to CPCs.
     *
     * @param correlations strictly upper-triangular entries of a correlation matrix
     * @return canonical partial correlations
     * @throws RuntimeException if the Cholesky decomposition reports an {@code IllegalDimension}
     *                          (which is attached as the cause)
     */
    @Override
    protected double[] transform(double[] correlations) {
        SymmetricMatrix correlation =
                SymmetricMatrix.compoundCorrelationSymmetricMatrix(correlations, dimVector);
        double[] choleskyUpper;
        try {
            choleskyUpper = new CholeskyDecomposition(correlation).getStrictlyUpperTriangular();
        } catch (IllegalDimension illegalDimension) {
            // Chain the original exception so the dimension problem is not lost.
            throw new RuntimeException("Unable to decompose matrix in LKJ inverse transform.", illegalDimension);
        }
        double[] cpcs = super.transform(choleskyUpper);
        if (DEBUG) {
            System.err.println("R: " + SymmetricMatrix.compoundCorrelationSymmetricMatrix(correlations, dimVector));
            System.err.println("L: " + new WrappedMatrix.WrappedStrictlyUpperTriangularMatrix(choleskyUpper, dimVector));
            System.err.println("Z: " + SymmetricMatrix.compoundCorrelationSymmetricMatrix(cpcs, dimVector));
        }
        return cpcs;
    }

    /**
     * Not supported: this transform only operates on the whole vector.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[] inverse(double[] values, int from, int to, double sum) {
        throw new UnsupportedOperationException("Not relevant for the LKJ transform.");
    }

    @Override
    public String getTransformName() {
        return "LKJTransform";
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[] gradientInverse(double[] values, int from, int to) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Log-Jacobian of the correlation-to-CPC map, evaluated at the given correlations:
     * {@code -1/2 * sum_{i<j} (dimVector - i - 2) * log(1 - z_ij^2)} with 0-based row {@code i}.
     *
     * @param correlations strictly upper-triangular entries of a correlation matrix
     * @return the log-Jacobian
     */
    @Override
    protected double getLogJacobian(double[] correlations) {
        double[] cpcs = transform(correlations);
        double sum = 0.0;
        int k = 0;
        // Row dimVector - 2 has weight zero, so the loop stops before it.
        for (int i = 0; i < dimVector - 2; ++i) {
            for (int j = i + 1; j < dimVector; ++j) {
                sum += (dimVector - i - 2) * Math.log(1.0 - Math.pow(cpcs[k], 2.0));
                ++k;
            }
        }
        return -0.5 * sum;
    }

    /**
     * Gradient, with respect to the CPCs, of the log-Jacobian of the CPC-to-correlation map:
     * {@code -(dimVector - i - 2) * z_ij / (1 - z_ij^2)}.
     *
     * @param cpcs canonical partial correlations
     * @return gradient vector, same layout as {@code cpcs}; entries in the last row stay zero
     */
    @Override
    public double[] getGradientLogJacobianInverse(double[] cpcs) {
        double[] gradient = new double[cpcs.length];
        int k = 0;
        for (int i = 0; i < dimVector - 2; ++i) {
            for (int j = i + 1; j < dimVector; ++j) {
                gradient[k] = -(dimVector - i - 2) * cpcs[k] / (1.0 - Math.pow(cpcs[k], 2.0));
                ++k;
            }
        }
        return gradient;
    }

    /**
     * Jacobian of the CPC-to-correlation map.
     *
     * <p>Entry {@code [pos(a, b)][pos(i, j)]} holds {@code d r_ij / d z_ab}.
     *
     * @param cpcs canonical partial correlations
     * @return a {@code dim x dim} Jacobian matrix
     */
    @Override
    public double[][] computeJacobianMatrixInverse(double[] cpcs) {
        double[][] jacobian = new double[dim][dim];
        // First-row correlations equal their CPCs (the inverse recursion starts at row 1),
        // so that block of the Jacobian is the identity.
        for (int j = 1; j < dimVector; ++j) {
            jacobian[pos(0, j)][j - 1] = 1.0;
        }
        recursionJacobian(jacobian, cpcs);
        return jacobian;
    }

    /**
     * Fills the Jacobian entries for correlations in rows {@code i >= 1}.
     *
     * <p>Each derivative is computed by running the inverse recursion alongside its derivative
     * (forward mode). The entry for {@code d r_ij / d z_ab} is seeded with {@code z_ij}. Until the
     * recursion reaches the step where {@code z_ab} enters, the entry just tracks the running
     * value of {@code r_ij}; at that step it becomes a derivative (see
     * {@link #recursionFormulaJacobian}).
     */
    private void recursionJacobian(double[][] jacobian, double[] cpcs) {
        for (int i = 1; i < dimVector - 1; ++i) {
            for (int j = i + 1; j < dimVector; ++j) {
                int target = pos(i, j);

                // Derivative with respect to z_ij itself. The seed is overwritten by step 1.
                jacobian[target][target] = cpcs[target];
                for (int step = 1; step <= i; ++step) {
                    setUpperTriangular(jacobian[target], i, j,
                            recursionFormulaJacobian(jacobian[target], cpcs, i, j, step, i, j));
                }

                // Derivatives with respect to z_ai and z_aj for every earlier row a.
                for (int a = 0; a < i; ++a) {
                    double[] rowAI = jacobian[pos(a, i)];
                    double[] rowAJ = jacobian[pos(a, j)];
                    rowAI[target] = cpcs[target];
                    rowAJ[target] = cpcs[target];
                    for (int step = 1; step <= i; ++step) {
                        setUpperTriangular(rowAI, i, j, recursionFormulaJacobian(rowAI, cpcs, i, j, step, a, i));
                        setUpperTriangular(rowAJ, i, j, recursionFormulaJacobian(rowAJ, cpcs, i, j, step, a, j));
                    }
                }
            }
        }
    }

    /**
     * Performs one step of the inverse recursion for {@code r_ij}, differentiated with respect to
     * {@code z_ab}.
     *
     * @param row  Jacobian row for {@code z_ab}; its {@code (i, j)} entry holds the current state
     * @param cpcs canonical partial correlations
     * @param i    row of the correlation being built
     * @param j    column of the correlation being built
     * @param step recursion step; it mixes in {@code z_(i-step),i} and {@code z_(i-step),j}
     * @param a    row of the CPC being differentiated
     * @param b    column of the CPC being differentiated
     * @return the updated state for entry {@code (i, j)}
     */
    private double recursionFormulaJacobian(double[] row, double[] cpcs, int i, int j, int step, int a, int b) {
        int k = i - step;
        double zki = getUpperTriangular(cpcs, k, i);
        double zkj = getUpperTriangular(cpcs, k, j);
        double current = getUpperTriangular(row, i, j);

        if (i == a && j == b && step == 1) {
            // d/dz_ij of the first step, z_ij * sqrt(...) + zki * zkj.
            return Math.sqrt((1.0 - zki * zki) * (1.0 - zkj * zkj));
        }
        if (k == a && i == b) {
            // z_ab == zki enters here; 'current' still holds the running value of r_ij.
            return current * (-zki / Math.sqrt(1.0 - zki * zki)) * Math.sqrt(1.0 - zkj * zkj) + zkj;
        }
        if (k == a && j == b) {
            // z_ab == zkj enters here; 'current' still holds the running value of r_ij.
            return current * (-zkj / Math.sqrt(1.0 - zkj * zkj)) * Math.sqrt(1.0 - zki * zki) + zki;
        }
        if (k < a) {
            // z_ab has already entered; later steps only scale the derivative (chain rule).
            return current * Math.sqrt((1.0 - zki * zki) * (1.0 - zkj * zkj));
        }
        // z_ab has not entered yet: advance the running value of r_ij.
        return current * Math.sqrt((1.0 - zki * zki) * (1.0 - zkj * zkj)) + zki * zkj;
    }

    /** Reads entry {@code (row, col)} of an implicit unit-diagonal upper-triangular matrix. */
    private double getUpperTriangular(double[] values, int row, int col) {
        assert (row <= col);
        if (row == col) {
            return 1.0;
        }
        return values[pos(row, col)];
    }

    /** Writes strictly upper-triangular entry {@code (row, col)}. */
    private void setUpperTriangular(double[] values, int row, int col, double value) {
        assert (row < col);
        values[pos(row, col)] = value;
    }

    /** Index of strictly upper-triangular entry {@code (row, col)}, with rows stored in order. */
    private int pos(int row, int col) {
        return row * (2 * dimVector - row - 1) / 2 + (col - row - 1);
    }

    /**
     * Maps CPCs to correlations using the explicit recursion instead of a Cholesky factor.
     *
     * @param values canonical partial correlations, each in [-1, 1]
     * @param from   must be 0
     * @param to     must be {@code values.length}
     * @return strictly upper-triangular correlations
     */
    public double[] inverseRecursion(double[] values, int from, int to) {
        assert (from == 0 && to == values.length) : "The transform function can only be applied to the whole array of values.";
        assert (dimVector * (dimVector - 1) / 2 == values.length) : "The transform function can only be applied to the whole array of values.";
        for (int i = 0; i < dim; ++i) {
            assert (values[i] <= 1.0 && values[i] >= -1.0) : "CPCs must be between -1.0 and 1.0";
        }
        double[] result = values.clone();
        recursionInverse(result, values);
        return result;
    }

    /**
     * Maps correlations to CPCs using the explicit recursion instead of a Cholesky factor.
     *
     * @param values strictly upper-triangular correlations
     * @param from   must be 0
     * @param to     must be {@code values.length}
     * @return canonical partial correlations
     */
    public double[] transformRecursion(double[] values, int from, int to) {
        assert (from == 0 && to == values.length) : "The transform function can only be applied to the whole array of values.";
        double[] result = values.clone();
        recursion(result);
        if (DEBUG) {
            // NOTE(review): this checks the input correlation matrix, not the returned CPCs.
            checkPositiveDefinite(SymmetricMatrix.compoundCorrelationSymmetricMatrix(values, dimVector));
        }
        return result;
    }

    /**
     * Applies the inverse recursion in place.
     *
     * @param values starts as a copy of the CPCs and ends as the correlations; row 0 is unchanged
     * @param cpcs   the original CPCs, read-only
     */
    private void recursionInverse(double[] values, double[] cpcs) {
        for (int i = 1; i < dimVector; ++i) {
            for (int j = i + 1; j < dimVector; ++j) {
                for (int step = 1; step <= i; ++step) {
                    setUpperTriangular(values, i, j, recursionInverseFormula(values, cpcs, i, j, step));
                }
            }
        }
    }

    /**
     * Computes one inverse step:
     * {@code r <- r * sqrt((1 - z_ki^2)(1 - z_kj^2)) + z_ki * z_kj} with {@code k = i - step}.
     */
    private double recursionInverseFormula(double[] values, double[] cpcs, int i, int j, int step) {
        double zki = getUpperTriangular(cpcs, i - step, i);
        double zkj = getUpperTriangular(cpcs, i - step, j);
        return getUpperTriangular(values, i, j) * Math.sqrt((1.0 - zki * zki) * (1.0 - zkj * zkj)) + zki * zkj;
    }

    /**
     * Applies the forward recursion in place, turning correlations into CPCs.
     *
     * <p>Rows are processed in increasing order, so every row read by
     * {@link #recursionFormula} already holds CPCs.
     */
    private void recursion(double[] values) {
        for (int i = 1; i < dimVector; ++i) {
            for (int j = i + 1; j < dimVector; ++j) {
                for (int step = 1; step <= i; ++step) {
                    setUpperTriangular(values, i, j, recursionFormula(values, i, j, step));
                }
            }
        }
    }

    /**
     * Computes one forward step:
     * {@code w <- (w - z_ki * z_kj) / sqrt((1 - z_ki^2)(1 - z_kj^2))} with {@code k = step - 1}.
     */
    private double recursionFormula(double[] values, int i, int j, int step) {
        double zki = getUpperTriangular(values, step - 1, i);
        double zkj = getUpperTriangular(values, step - 1, j);
        return (getUpperTriangular(values, i, j) - zki * zkj) / Math.sqrt((1.0 - zki * zki) * (1.0 - zkj * zkj));
    }

    /**
     * Debug helper: throws if {@code matrix} is not positive definite.
     *
     * <p>An {@code IllegalDimension} raised by the check is printed rather than propagated.
     */
    private static void checkPositiveDefinite(SymmetricMatrix matrix) {
        try {
            if (!matrix.isPD()) {
                throw new RuntimeException("The LKJ transform should produce a Positive Definite matrix.");
            }
        } catch (IllegalDimension illegalDimension) {
            illegalDimension.printStackTrace();
        }
    }
}

