/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.tree;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import dr.math.ConjugateDirectionSearch;
import dr.math.MultivariateFunction;
import dr.math.MultivariateMinimum;

/**
 * A copy of a {@link Tree} whose internal node heights (and optionally a single rate {@code mu})
 * are fitted by least squares to the branch lengths of the source tree.
 *
 * <p>The expected length of a branch is {@code mu} times the height difference across it. The
 * objective is the sum, over branches, of the squared difference between the source branch
 * length and that expectation. The two branches below the root are scored as one combined
 * branch (their summed source length against their summed height differences).
 *
 * <p>Internal nodes are parameterized as non-negative height increments above their tallest
 * child, so every parameter vector yields parent heights at least as large as child heights.
 * This class never modifies the heights of external nodes.
 */
public class LeastSquaresClockTree extends SimpleTree {

    /**
     * Objective handed to the minimizer. Arguments {@code 0 .. internalNodeCount-1} are the
     * internal node height increments. When {@code optimizeMu} is set, the argument at
     * {@code muIndex} is the rate {@code mu}.
     */
    private final MultivariateFunction leastSquaresClock = new MultivariateFunction() {

        @Override
        public double evaluate(double[] arguments) {
            System.arraycopy(arguments, 0, nodeValues, 0, nodeValues.length);
            setNodeHeightsFromValues(getRoot());
            if (optimizeMu) {
                mu = arguments[muIndex];
            }
            return getSumOfSquares();
        }

        @Override
        public int getNumArguments() {
            int internalNodes = getInternalNodeCount();
            return optimizeMu ? internalNodes + 1 : internalNodes;
        }

        @Override
        public double getLowerBound(int n) {
            // mu is kept strictly positive (smallest positive double); node increments may be zero.
            return (optimizeMu && n == muIndex) ? Double.MIN_VALUE : 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.MAX_VALUE;
        }
    };

    /** Number of internal nodes, captured when {@link #optimize()} starts. */
    private int nodeCount;
    /** Height increment of each internal node above its tallest child. */
    private double[] nodeValues;
    /** Tree whose branch lengths are being fitted. */
    private final Tree sourceTree;
    /** Rate converting height differences to expected branch lengths. */
    private double mu;
    /** Whether {@code mu} is a free parameter of the fit. */
    private final boolean optimizeMu;
    /** Index of {@code mu} in the minimizer's argument vector (only used if optimizeMu). */
    private int muIndex;

    /**
     * Creates a clock tree that fits both node heights and the rate {@code mu} (starting at 1.0).
     *
     * @param tree the tree to copy and whose branch lengths are fitted
     */
    public LeastSquaresClockTree(Tree tree) {
        this(tree, 1.0, true);
    }

    /**
     * Creates a clock tree that fits node heights under a fixed rate.
     *
     * @param tree the tree to copy and whose branch lengths are fitted
     * @param mu the fixed rate
     */
    public LeastSquaresClockTree(Tree tree, double mu) {
        this(tree, mu, false);
    }

    private LeastSquaresClockTree(Tree tree, double mu, boolean optimizeMu) {
        // Copy the source topology/nodes. Without this the SimpleTree would be empty.
        super(tree);
        this.sourceTree = tree;
        this.mu = mu;
        this.optimizeMu = optimizeMu;
    }

    /** Returns the current rate: fixed, or the fitted value after {@link #optimize()}. */
    public double getMu() {
        return this.mu;
    }

    /**
     * Runs a conjugate direction search to minimize {@link #getSumOfSquares()}. Every internal
     * node increment starts at 1.0, and {@code mu} starts at its current value. On return, the
     * tree's node heights (and {@code mu}, if free) are set to the point the minimizer reported.
     */
    public void optimize() {
        nodeCount = getInternalNodeCount();
        int argumentCount = nodeCount;
        if (optimizeMu) {
            muIndex = nodeCount;
            argumentCount++;
        }
        nodeValues = new double[nodeCount];

        double[] arguments = new double[argumentCount];
        for (int i = 0; i < nodeCount; i++) {
            arguments[i] = 1.0;
        }
        if (optimizeMu) {
            arguments[muIndex] = mu;
        }

        MultivariateMinimum minimizer = new ConjugateDirectionSearch();
        minimizer.optimize(leastSquaresClock, arguments, 1.0E-8, 1.0E-8);
        // MultivariateMinimum.optimize is expected to leave the minimizing point in `arguments`.
        // The minimizer's last function evaluation need not have been at that point, so apply it
        // explicitly to leave the node heights and mu at the fitted values.
        leastSquaresClock.evaluate(arguments);
    }

    /**
     * Computes the least-squares residual of the current node heights and rate against the source
     * tree's branch lengths.
     *
     * @return sum over branches of (source branch length - mu * height difference)^2, with the
     *     two root branches combined into one term
     * @throws IllegalArgumentException if the root does not have exactly two children
     */
    public double getSumOfSquares() {
        NodeRef root = getRoot();
        if (getChildCount(root) != 2) {
            throw new IllegalArgumentException("The tree must have a bifurcating root node");
        }
        NodeRef left = getChild(root, 0);
        NodeRef right = getChild(root, 1);

        double sum = getSumOfSquaresAtNode(left) + getSumOfSquaresAtNode(right);

        // The two root branches are scored together as one path from left to right.
        double observed = getSourceBranchLength(left) + getSourceBranchLength(right);
        double rootHeight = getNodeHeight(root);
        double elapsed = (rootHeight - getNodeHeight(left)) + (rootHeight - getNodeHeight(right));
        double residual = observed - elapsed * mu;
        return sum + residual * residual;
    }

    /**
     * Sums the squared residuals of every branch strictly below {@code node}, including branches
     * leading to external nodes.
     */
    private double getSumOfSquaresAtNode(NodeRef node) {
        if (isExternal(node)) {
            return 0.0;
        }
        double sum = 0.0;
        for (int i = 0; i < getChildCount(node); i++) {
            NodeRef child = getChild(node, i);
            sum += getScoreAtNode(child);
            sum += getSumOfSquaresAtNode(child);
        }
        return sum;
    }

    /** Squared residual for the branch above {@code node}. */
    private double getScoreAtNode(NodeRef node) {
        double elapsed = getNodeHeight(getParent(node)) - getNodeHeight(node);
        double residual = getSourceBranchLength(node) - elapsed * mu;
        return residual * residual;
    }

    /**
     * Branch length of the source-tree node with the same number as {@code node}. This assumes
     * the copy preserves node numbering.
     */
    private double getSourceBranchLength(NodeRef node) {
        return sourceTree.getBranchLength(sourceTree.getNode(node.getNumber()));
    }

    /**
     * Recursively sets each internal node's height to its tallest child's height plus its
     * increment in {@code nodeValues}. External node heights are left unchanged.
     *
     * @return the (possibly updated) height of {@code node}
     */
    private double setNodeHeightsFromValues(NodeRef node) {
        if (!isExternal(node)) {
            double tallestChild = setNodeHeightsFromValues(getChild(node, 0));
            for (int i = 1; i < getChildCount(node); i++) {
                double childHeight = setNodeHeightsFromValues(getChild(node, i));
                if (childHeight > tallestChild) {
                    tallestChild = childHeight;
                }
            }
            // NOTE(review): assumes internal nodes are numbered after all external nodes, so
            // getNumber() - getExternalNodeCount() indexes 0..internalNodeCount-1. Confirm
            // against SimpleTree's numbering.
            int index = node.getNumber() - getExternalNodeCount();
            setNodeHeight(node, tallestChild + nodeValues[index]);
        }
        return getNodeHeight(node);
    }
}

