/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.preorder;

import beagle.Beagle;
import dr.evolution.alignment.PatternList;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.EvolutionaryProcessDelegate;
import dr.evomodel.treedatalikelihood.ProcessOnTreeDelegate;
import dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate;
import dr.inference.model.Model;
import dr.math.matrixAlgebra.WrappedVector;
import java.util.List;

/**
 * Base class for simulation delegates that obtain node derivatives (gradient and
 * Hessian arrays) from a BEAGLE instance using pre-order partials.
 *
 * <p>The class seeds the root pre-order partial with the root state frequencies,
 * asks BEAGLE to update the remaining pre-order partials, and then defers to
 * {@link #getNodeDerivatives(Tree, double[], double[])} for the actual derivative
 * computation. Subclasses decide how long the gradient is and how it is filled.
 *
 * <p>Pre-order partial buffers are addressed as "node number + partial buffer
 * count of the likelihood delegate"; see {@link #getPreOrderPartialIndex(int)}.
 */
public abstract class AbstractBeagleGradientDelegate
        extends ProcessSimulationDelegate.AbstractDelegate {

    /** Default trait name returned by {@link #getGradientTraitName()}. */
    private static final String GRADIENT_TRAIT_NAME = "Gradient";
    /** Default trait name returned by {@link #getHessianTraitName()}. */
    private static final String HESSIAN_TRAIT_NAME = "Hessian";
    /** Debug switch; not read anywhere in this (decompiled) class body. */
    private static final boolean DEBUG_TRANSPOSE = false;

    protected final BeagleDataLikelihoodDelegate likelihoodDelegate;
    /** BEAGLE instance shared with {@link #likelihoodDelegate}. */
    protected final Beagle beagle;
    protected EvolutionaryProcessDelegate evolutionaryProcessDelegate;
    protected final SiteRateModel siteRateModel;
    protected final PatternList patternList;
    protected final int patternCount;
    protected final int stateCount;
    protected final int categoryCount;
    /** Added to a node number to obtain its pre-order partial buffer index. */
    private final int preOrderPartialOffset;
    /** Lazily allocated in {@link #simulate(int[], int, int)}; refilled on every call. */
    protected double[] gradient;
    /**
     * Cleared on construction and on every model change / restore. This class never
     * sets it to {@code true}; presumably subclasses do - confirm against them.
     */
    protected boolean substitutionProcessKnown;
    protected Tree tree;

    /** Counter switch; not read anywhere in this (decompiled) class body. */
    private static final boolean COUNT_TOTAL_OPERATIONS = true;
    final boolean DEBUG = false;

    // Usage counters, reported by toString().
    private long simulateCount = 0L;
    private long getTraitCount = 0L;
    private long updatePrePartialCount = 0L;

    /**
     * Wires this delegate to the given likelihood delegate: caches its BEAGLE
     * instance, evolutionary process, site-rate model and pattern dimensions, and
     * registers this object as model listener and model-restore listener on it.
     *
     * @param string                       forwarded to the superclass constructor
     * @param tree                         tree forwarded to the superclass and kept in {@link #tree}
     * @param beagleDataLikelihoodDelegate likelihood delegate supplying BEAGLE and model state;
     *                                     asserted to have pre-order use enabled
     */
    protected AbstractBeagleGradientDelegate(String string, Tree tree, BeagleDataLikelihoodDelegate beagleDataLikelihoodDelegate) {
        super(string, tree);
        this.tree = tree;

        likelihoodDelegate = beagleDataLikelihoodDelegate;
        beagle = beagleDataLikelihoodDelegate.getBeagleInstance();
        assert likelihoodDelegate.isUsePreOrder();

        evolutionaryProcessDelegate = beagleDataLikelihoodDelegate.getEvolutionaryProcessDelegate();
        siteRateModel = beagleDataLikelihoodDelegate.getSiteRateModel();

        // Dimensions of the partials buffers.
        patternList = beagleDataLikelihoodDelegate.getPatternList();
        patternCount = patternList.getPatternCount();
        stateCount = patternList.getDataType().getStateCount();
        categoryCount = siteRateModel.getCategoryCount();

        preOrderPartialOffset = beagleDataLikelihoodDelegate.getPartialBufferCount();

        beagleDataLikelihoodDelegate.addModelListener(this);
        beagleDataLikelihoodDelegate.addModelRestoreListener(this);

        substitutionProcessKnown = false;
    }

    /**
     * @return the length of the arrays allocated for the gradient and the Hessian
     */
    protected abstract int getGradientLength();

    /**
     * Dumps a block of square matrices to {@code System.err}, one
     * {@code stateCount}-long row per line, grouped by rate category.
     *
     * @param matrices flat array read as consecutive {@code stateCount x stateCount}
     *                 matrices, one per site-rate category
     */
    private void printMatrix(double[] matrices) {
        for (int category = 0; category < siteRateModel.getCategoryCount(); ++category) {
            System.err.println("\nRate = " + category);
            for (int row = 0; row < stateCount; ++row) {
                final double[] rowValues = new double[stateCount];
                final int rowStart = (category * stateCount + row) * stateCount;
                System.arraycopy(matrices, rowStart, rowValues, 0, stateCount);
                System.err.println(new WrappedVector.Raw(rowValues));
            }
        }
    }

    /**
     * Debug helper: prints the transition matrix referenced by slot 4 of an
     * operation, has BEAGLE transpose it into matrix buffer 1, and prints the result.
     * Not called from within this class.
     *
     * @param operation one vectorized operation; only element 4 is read
     */
    private void debugMatrixTranspose(int[] operation) {
        final double[] matrices = new double[stateCount * stateCount * siteRateModel.getCategoryCount()];

        final int sourceIndex = operation[4];
        beagle.getTransitionMatrix(sourceIndex, matrices);
        printMatrix(matrices);

        final int targetIndex = 1;
        beagle.transposeTransitionMatrices(new int[]{sourceIndex}, new int[]{targetIndex}, 1);
        beagle.getTransitionMatrix(targetIndex, matrices);
        printMatrix(matrices);
    }

    /**
     * Seeds the root pre-order partial, lets BEAGLE update the pre-order partials
     * for the supplied operations, and recomputes {@link #gradient}.
     *
     * @param nArray vectorized operations, passed unchanged to BEAGLE
     * @param n      operation count passed to BEAGLE; also added to the pre-partial counter
     * @param n2     node number handed to {@link #simulateRoot(int)}
     */
    @Override
    public void simulate(int[] nArray, int n, int n2) {
        simulateRoot(n2);

        // NOTE(review): -1 is presumably BEAGLE's "none" sentinel for the third argument - confirm.
        beagle.updatePrePartials(nArray, n, -1);

        if (gradient == null) {
            gradient = new double[getGradientLength()];
        }
        getNodeDerivatives(tree, gradient, null);

        simulateCount++;
        updatePrePartialCount += n;
    }

    /**
     * Unsupported in this delegate.
     *
     * @throws RuntimeException always
     */
    @Override
    public void setupStatistics() {
        throw new RuntimeException("Not used (?) with BEAGLE");
    }

    /**
     * Fills the pre-order partial buffer of node {@code n} with the root state
     * frequencies, repeated once for every (pattern, category) pair.
     *
     * @param n node number whose pre-order partial is set
     */
    @Override
    protected void simulateRoot(int n) {
        final double[] frequencies = evolutionaryProcessDelegate.getRootStateFrequencies();

        final int copies = patternCount * categoryCount;
        final double[] rootPartials = new double[stateCount * copies];
        for (int copy = 0; copy < copies; ++copy) {
            System.arraycopy(frequencies, 0, rootPartials, copy * stateCount, stateCount);
        }

        beagle.setPartials(getPreOrderPartialIndex(n), rootPartials);
    }

    /**
     * Unsupported in this delegate.
     *
     * @throws RuntimeException always
     */
    @Override
    protected void simulateNode(int n, int n2, int n3, int n4, int n5) {
        throw new RuntimeException("Not used with BEAGLE");
    }

    /**
     * @return the trait name used for the gradient ({@code "Gradient"} unless overridden)
     */
    protected String getGradientTraitName() {
        return GRADIENT_TRAIT_NAME;
    }

    /**
     * @return the trait name used for the Hessian ({@code "Hessian"} unless overridden)
     */
    protected String getHessianTraitName() {
        return HESSIAN_TRAIT_NAME;
    }

    /**
     * Caches the simulated traits for {@code nodeRef} and then computes second
     * derivatives into a freshly allocated array.
     *
     * @param tree    tree passed on to {@link #getNodeDerivatives(Tree, double[], double[])}
     * @param nodeRef node handed to the simulation process for trait caching
     * @return a new array of length {@link #getGradientLength()} filled by the subclass
     */
    double[] getHessian(Tree tree, NodeRef nodeRef) {
        simulationProcess.cacheSimulatedTraits(nodeRef);

        final double[] hessian = new double[getGradientLength()];
        getNodeDerivatives(tree, null, hessian);
        return hessian;
    }

    /**
     * Caches the simulated traits for {@code nodeRef} and returns a copy of the
     * current gradient.
     *
     * @param nodeRef node handed to the simulation process for trait caching
     * @return a clone of {@link #gradient}, so callers cannot modify the cached array
     */
    protected double[] getGradient(NodeRef nodeRef) {
        getTraitCount++;
        simulationProcess.cacheSimulatedTraits(nodeRef);
        return gradient.clone();
    }

    /**
     * Computes node derivatives. Within this class it is called with exactly one of
     * the two output arrays non-null: the gradient from
     * {@link #simulate(int[], int, int)}, the Hessian from
     * {@link #getHessian(Tree, NodeRef)}.
     *
     * @param var1 tree to operate on
     * @param var2 output array for first derivatives, or {@code null}
     * @param var3 output array for second derivatives, or {@code null}
     */
    protected abstract void getNodeDerivatives(Tree var1, double[] var2, double[] var3);

    /**
     * @param n index forwarded to the evolutionary process delegate
     * @return that delegate's infinitesimal-matrix buffer index for {@code n}
     */
    protected int getFirstDerivativeMatrixBufferIndex(int n) {
        return evolutionaryProcessDelegate.getInfinitesimalMatrixBufferIndex(n);
    }

    /**
     * @param n index forwarded to the evolutionary process delegate
     * @return that delegate's squared-infinitesimal-matrix buffer index for {@code n}
     */
    protected int getSecondDerivativeMatrixBufferIndex(int n) {
        return evolutionaryProcessDelegate.getInfinitesimalSquaredMatrixBufferIndex(n);
    }

    /** Any model change marks the substitution process as unknown. */
    @Override
    public void modelChangedEvent(Model model, Object object, int n) {
        substitutionProcessKnown = false;
    }

    /** A model restore marks the substitution process as unknown. */
    @Override
    public void modelRestored(Model model) {
        substitutionProcessKnown = false;
    }

    /**
     * Flattens node operations into the int array handed to BEAGLE, writing
     * {@link #getSingleOperationSize()} consecutive entries per operation.
     *
     * @param list   node operations to encode
     * @param nArray output array; must hold at least {@code 7 * list.size()} entries
     * @return the number of operations encoded, i.e. {@code list.size()}
     */
    @Override
    public int vectorizeNodeOperations(List<ProcessOnTreeDelegate.NodeOperation> list, int[] nArray) {
        int cursor = 0;
        for (ProcessOnTreeDelegate.NodeOperation operation : list) {
            final int left = operation.getLeftChild();
            final int node = operation.getNodeNumber();
            final int right = operation.getRightChild();

            nArray[cursor++] = getPreOrderPartialIndex(left);
            // NOTE(review): the two -1 slots are presumably "no scaling" markers - confirm against BEAGLE.
            nArray[cursor++] = -1;
            nArray[cursor++] = -1;
            nArray[cursor++] = getPreOrderPartialIndex(node);
            nArray[cursor++] = evolutionaryProcessDelegate.getMatrixIndex(left);
            nArray[cursor++] = getPostOrderPartialIndex(right);
            nArray[cursor++] = evolutionaryProcessDelegate.getMatrixIndex(right);
        }
        return list.size();
    }

    /**
     * @return the number of ints written per operation by
     *         {@link #vectorizeNodeOperations(List, int[])}
     */
    @Override
    public int getSingleOperationSize() {
        return 7;
    }

    /**
     * @param n node number
     * @return the likelihood delegate's partial buffer index for {@code n}
     */
    protected int getPostOrderPartialIndex(int n) {
        return likelihoodDelegate.getPartialBufferIndex(n);
    }

    /**
     * @param n node number
     * @return {@code n} shifted by the likelihood delegate's partial buffer count
     */
    protected int getPreOrderPartialIndex(int n) {
        return preOrderPartialOffset + n;
    }

    /**
     * @return a three-line, tab-indented report of the usage counters
     */
    @Override
    public String toString() {
        return new StringBuilder()
                .append("\tsimulateCount = ").append(simulateCount)
                .append("\n\tgetTraitCount = ").append(getTraitCount)
                .append("\n\tupPrePartialCount = ").append(updatePrePartialCount)
                .append("\n")
                .toString();
    }
}

