/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchmodel;

import dr.evolution.datatype.DataType;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchmodel.TransitionMatrixProviderBranchModel;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.substmodel.ComplexSubstitutionModel;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.markovjumps.SericolaSeriesMarkovReward;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

/**
 * Branch model whose per-branch matrices come from a Markov-reward computation
 * ({@link SericolaSeriesMarkovReward}) instead of plain exponentiation of the generator.
 *
 * <p>States are reordered by increasing reward rate, and the underlying substitution model's
 * infinitesimal matrix is permuted into that order. For every non-root node the matrix is
 * {@code computePdf(branchRate, branchLength)} on the permuted generator. The result is then
 * permuted back into the original state order.
 *
 * <p>NOTE(review): the branch rate is passed as the first argument of {@code computePdf};
 * presumably it plays the role of the accumulated reward. Confirm against
 * {@link SericolaSeriesMarkovReward}.
 *
 * <p>Results are cached. They are recomputed only after the underlying substitution model, the
 * branch-rate model or the reward-rate parameter reports a change.
 *
 * <p>NOTE(review): the tree itself is not registered as a sub-model. Branch-length changes
 * therefore only invalidate the cache if they reach this model through one of those
 * sub-models. Confirm this is intended.
 */
public class RewardsAwareBranchModel
extends AbstractModel
implements TransitionMatrixProviderBranchModel,
Citable,
Reportable {
    public static final String REWARDS_AWARE_BRANCH_MODEL = "RewardsAwareBranchModel";
    /** Reward rate per state; assumed to hold (at least) one value per state. */
    private final Parameter rewardRates;
    private final SubstitutionModel underlyingSubstitutionModel;
    private final TreeModel tree;
    private final ArbitraryBranchRates branchRateModel;
    private final int nstates;
    // Generator of the underlying model in the original state order (row-major, nstates x nstates).
    private final double[] unsortedQ;
    // Per-node matrices as returned by the reward computation, i.e. still in reward-sorted order.
    private final double[][] unsortedW;
    // Reward rates in increasing order (sortedRewardRates[k] == rewardRates[perm[k]]).
    private final double[] sortedRewardRates;
    // Generator with rows and columns permuted into reward-sorted order.
    private final double[] Q;
    // Per-node matrices in the original state order. Rows are handed out by reference (see
    // getTransitionMatrix / buildSubstitutionModels), so they are only ever updated in place.
    private final double[][] W;
    // perm[k] = original index of the state with the k-th smallest reward rate.
    private final int[] perm;
    private List<SubstitutionModel> substitutionModels;
    private boolean knownTransitionMatrices;
    private boolean knownSortedQ;
    private boolean knownSortedRewardRates;
    boolean DUMMYTESTING = false;
    boolean DEBUG = false;
    boolean ignoreModelChangedEvent = false;
    // Deep snapshot of W taken by storeState().
    double[][] storedW;
    // Whether W was valid at the time of the last storeState().
    private boolean storedKnownTransitionMatrices;

    /**
     * Creates the model and registers the substitution model, branch-rate model and reward-rate
     * parameter as listened-to components.
     *
     * @param treeModel tree whose nodes receive matrices
     * @param substitutionModel underlying CTMC; must not be null
     * @param parameter reward rate per state
     * @param arbitraryBranchRates per-branch rates fed to the reward computation
     * @throws IllegalArgumentException if {@code substitutionModel} is null
     */
    public RewardsAwareBranchModel(TreeModel treeModel, SubstitutionModel substitutionModel, Parameter parameter, ArbitraryBranchRates arbitraryBranchRates) {
        super(REWARDS_AWARE_BRANCH_MODEL);
        this.underlyingSubstitutionModel = substitutionModel;
        this.rewardRates = parameter;
        this.branchRateModel = arbitraryBranchRates;
        this.tree = treeModel;
        if (substitutionModel == null) {
            throw new IllegalArgumentException("RewardsAwareBranchModel must be provided with an underlying substitution model");
        }
        this.nstates = substitutionModel.getDataType().getStateCount();
        final int matrixSize = this.nstates * this.nstates;
        this.sortedRewardRates = new double[this.nstates];
        this.unsortedQ = new double[matrixSize];
        this.Q = new double[matrixSize];
        this.W = new double[treeModel.getNodeCount()][matrixSize];
        this.storedW = new double[treeModel.getNodeCount()][matrixSize];
        this.unsortedW = new double[treeModel.getNodeCount()][matrixSize];
        this.perm = new int[this.nstates];
        this.knownSortedRewardRates = false;
        this.knownSortedQ = false;
        this.knownTransitionMatrices = false;
        this.storedKnownTransitionMatrices = false;
        this.addModel(substitutionModel);
        this.addModel(arbitraryBranchRates);
        this.addVariable(parameter);
    }

    @Override
    public FrequencyModel getRootFrequencyModel() {
        return this.underlyingSubstitutionModel.getFrequencyModel();
    }

    @Override
    public SubstitutionModel getRootSubstitutionModel() {
        return this.underlyingSubstitutionModel;
    }

    /**
     * Returns the matrix for the branch above node {@code n}.
     *
     * <p>The matrix is row-major, {@code nstates x nstates}, in the original state order. The
     * returned array is the internal cache row, not a copy.
     *
     * <p>With {@code DUMMYTESTING} set, the row is instead filled with the underlying model's
     * ordinary transition probabilities for the node's branch length.
     *
     * @param n node number
     * @return cached matrix row for node {@code n}
     */
    public double[] getTransitionMatrix(int n) {
        if (this.DUMMYTESTING) {
            NodeRef nodeRef = this.tree.getNode(n);
            double branchLength = this.tree.getBranchLength(nodeRef);
            this.getRootSubstitutionModel().getTransitionProbabilities(branchLength, this.W[n]);
        } else {
            this.computeTransitionMatrices();
        }
        return this.W[n];
    }

    @Override
    public double[] getTransitionMatrix(NodeRef nodeRef) {
        return this.getTransitionMatrix(nodeRef.getNumber());
    }

    /** Recomputes all non-root branch matrices if the cache has been invalidated. */
    private void computeTransitionMatrices() {
        if (this.DEBUG) {
            System.out.println("computeTransitionMatrices");
        }
        if (!this.knownTransitionMatrices) {
            this.sortRewardRates();
            this.sortQ();
            SericolaSeriesMarkovReward sericolaSeriesMarkovReward = new SericolaSeriesMarkovReward(this.Q, this.sortedRewardRates, this.nstates);
            for (int i = 0; i < this.tree.getNodeCount(); ++i) {
                NodeRef nodeRef = this.tree.getNode(i);
                if (this.tree.isRoot(nodeRef)) continue;
                double branchRate = this.branchRateModel.getBranchRate(this.tree, nodeRef);
                double branchLength = this.tree.getBranchLength(nodeRef);
                // Result is in reward-sorted state order; sortW() maps it back.
                this.unsortedW[i] = sericolaSeriesMarkovReward.computePdf(branchRate, branchLength);
            }
            this.sortW();
            this.knownTransitionMatrices = true;
        }
    }

    /**
     * Rebuilds {@code perm} and {@code sortedRewardRates} from the current parameter values if
     * they are stale. A new permutation also invalidates the permuted generator.
     */
    private void sortRewardRates() {
        if (!this.knownSortedRewardRates) {
            if (this.DEBUG) {
                System.out.println("Sorting reward rates");
            }
            double[] rates = this.rewardRates.getParameterValues();
            Integer[] order = new Integer[this.nstates];
            for (int k = 0; k < this.nstates; ++k) {
                order[k] = k;
            }
            Arrays.sort(order, Comparator.comparingDouble(k -> rates[k]));
            for (int k = 0; k < this.nstates; ++k) {
                this.perm[k] = order[k];
                this.sortedRewardRates[k] = rates[this.perm[k]];
            }
            this.knownSortedQ = false;
            this.knownSortedRewardRates = true;
        }
    }

    /** Fills {@code Q} with the underlying generator permuted into reward-sorted order. */
    private void sortQ() {
        if (!this.knownSortedQ) {
            this.underlyingSubstitutionModel.getInfinitesimalMatrix(this.unsortedQ);
            for (int i = 0; i < this.nstates; ++i) {
                int srcRow = this.perm[i] * this.nstates;
                for (int j = 0; j < this.nstates; ++j) {
                    this.Q[i * this.nstates + j] = this.unsortedQ[srcRow + this.perm[j]];
                }
            }
            this.knownSortedQ = true;
        }
    }

    /**
     * Maps each non-root node's reward-sorted matrix back to the original state order.
     * Writes in place into the existing {@code W} rows.
     */
    private void sortW() {
        // inverse[originalState] = position of that state in the reward-sorted order.
        int[] inverse = new int[this.nstates];
        for (int k = 0; k < this.nstates; ++k) {
            inverse[this.perm[k]] = k;
        }
        for (int node = 0; node < this.tree.getNodeCount(); ++node) {
            if (this.tree.isRoot(this.tree.getNode(node))) continue;
            double[] sorted = this.unsortedW[node];
            double[] target = this.W[node];
            for (int i = 0; i < this.nstates; ++i) {
                int srcRow = inverse[i] * this.nstates;
                int dstRow = i * this.nstates;
                for (int j = 0; j < this.nstates; ++j) {
                    target[dstRow + j] = sorted[srcRow + inverse[j]];
                }
            }
        }
    }

    /** Maps every branch to a single "model" indexed by the node number, with weight 1. */
    @Override
    public BranchModel.Mapping getBranchModelMapping(NodeRef nodeRef) {
        final double[] weights = new double[]{1.0};
        final int[] order = new int[]{nodeRef.getNumber()};
        return new BranchModel.Mapping(){

            @Override
            public int[] getOrder() {
                return order;
            }

            @Override
            public double[] getWeights() {
                return weights;
            }
        };
    }

    @Override
    public boolean requiresMatrixConvolution() {
        return false;
    }

    public TreeModel getTree() {
        return this.tree;
    }

    public Parameter getRewardRates() {
        return this.rewardRates;
    }

    public BranchRateModel getRateBranchModel() {
        return this.branchRateModel;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (this.ignoreModelChangedEvent) {
            return;
        }
        if (this.DEBUG) {
            System.out.println("Model changed event");
        }
        if (model == this.underlyingSubstitutionModel) {
            // A new generator invalidates the permuted Q and every branch matrix.
            this.knownSortedQ = false;
            this.knownTransitionMatrices = false;
            this.fireModelChanged();
        } else if (model == this.branchRateModel) {
            // Branch rates only enter the per-branch computation; the permuted Q is unaffected.
            this.knownTransitionMatrices = false;
            this.fireModelChanged();
        } else {
            throw new IllegalArgumentException("Unknown model: " + model);
        }
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable != this.rewardRates) {
            throw new IllegalArgumentException("Unknown variable: " + variable);
        }
        if (this.DEBUG) {
            System.out.println("Reward rates changed");
        }
        // Re-sorting will also invalidate Q if the permutation is rebuilt.
        this.knownSortedRewardRates = false;
        this.knownTransitionMatrices = false;
        this.fireModelChanged();
    }

    /**
     * Takes a deep snapshot of the branch matrices and of their validity flag.
     *
     * <p>A deep copy is required because {@code W} rows are overwritten in place on recomputation.
     */
    @Override
    protected void storeState() {
        for (int i = 0; i < this.W.length; ++i) {
            System.arraycopy(this.W[i], 0, this.storedW[i], 0, this.W[i].length);
        }
        this.storedKnownTransitionMatrices = this.knownTransitionMatrices;
    }

    /**
     * Restores the snapshot taken by {@link #storeState()}.
     *
     * <p>Contents are copied back into the existing {@code W} rows rather than swapping arrays.
     * Wrappers built by {@link #buildSubstitutionModels()} keep references to those rows.
     */
    @Override
    protected void restoreState() {
        for (int i = 0; i < this.W.length; ++i) {
            System.arraycopy(this.storedW[i], 0, this.W[i], 0, this.W[i].length);
        }
        this.knownTransitionMatrices = this.storedKnownTransitionMatrices;
        // The sorted caches may have been rebuilt for the rejected state; force a rebuild.
        this.knownSortedRewardRates = false;
        this.knownSortedQ = false;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.SUBSTITUTION_MODELS;
    }

    @Override
    public String getDescription() {
        return "Rewards Aware Branch model";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(new Citation(new Author[]{new Author("F", "Monti"), new Author("MA", "Suchard")}, "Dependencies between CTMCs", 2025, "TOBE", 1, 1, 1, Citation.Status.IN_PRESS));
    }

    /** Lists every non-root branch matrix, in node-number order, as space-separated values. */
    @Override
    public String getReport() {
        if (!this.DUMMYTESTING) {
            this.computeTransitionMatrices();
        }
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("W matrix: ");
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            // Skip the root explicitly; its node number is not necessarily the last index.
            if (this.tree.isRoot(this.tree.getNode(i))) continue;
            for (double d : this.W[i]) {
                stringBuilder.append(d).append(" ");
            }
        }
        return stringBuilder.toString();
    }

    @Override
    public List<SubstitutionModel> getSubstitutionModels() {
        if (this.DEBUG) {
            System.out.println("getSubstitutionModels");
        }
        this.buildSubstitutionModels();
        return this.substitutionModels;
    }

    /**
     * Builds one fixed-matrix wrapper per non-root node, in node-number order.
     *
     * <p>Each wrapper aliases the node's {@code W} row. It therefore sees later in-place updates.
     */
    protected void buildSubstitutionModels() {
        this.substitutionModels = new ArrayList<SubstitutionModel>();
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            NodeRef nodeRef = this.tree.getNode(i);
            if (this.tree.isRoot(nodeRef)) continue;
            TransitionMatrixProvider transitionMatrixProvider;
            // NOTE(review): presumably suppresses events triggered while the superclass
            // constructor wires up the frequency model; confirm.
            this.ignoreModelChangedEvent = true;
            try {
                transitionMatrixProvider = new TransitionMatrixProvider("RewardsAwareSubstitutionModel", this.underlyingSubstitutionModel.getDataType(), this.underlyingSubstitutionModel.getFrequencyModel(), this.W[nodeRef.getNumber()]);
            } finally {
                this.ignoreModelChangedEvent = false;
            }
            this.substitutionModels.add(transitionMatrixProvider);
        }
    }

    /**
     * Substitution-model wrapper that always reports the same externally owned matrix.
     *
     * <p>The branch length passed to {@link #getTransitionProbabilities} is ignored.
     */
    class TransitionMatrixProvider
    extends ComplexSubstitutionModel {
        // Aliased (not copied) so that in-place updates by the outer model are visible.
        private double[] transitionMatrix;

        public TransitionMatrixProvider(String string, DataType dataType, FrequencyModel frequencyModel, double[] dArray) {
            super(string, dataType, frequencyModel, (Parameter)null);
            this.transitionMatrix = dArray;
        }

        @Override
        public void getTransitionProbabilities(double d, double[] dArray) {
            System.arraycopy(this.transitionMatrix, 0, dArray, 0, this.transitionMatrix.length);
        }

        @Override
        protected void handleModelChangedEvent(Model model, Object object, int n) {
        }

        @Override
        protected void frequenciesChanged() {
        }

        @Override
        protected void ratesChanged() {
        }

        @Override
        protected void setupRelativeRates(double[] dArray) {
        }
    }
}

