/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchmodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A {@link BranchModel} that assigns substitution models to branches by clade.
 * <p>
 * Index 0 of {@link #getSubstitutionModels()} is always the root substitution model given to the
 * constructor. A branch's mapping is resolved as follows:
 * <ol>
 *   <li>an external-branch assignment made by {@link #addExternalBranches} wins outright;</li>
 *   <li>otherwise, the branch takes the model of the innermost clade (registered via
 *       {@link #addClade}) that contains it. A clade's stem branch (the branch keyed by the
 *       clade's MRCA node) may instead be a weighted mix of the clade's model and the model of
 *       the enclosing region (the enclosing clade, or index 0 if there is none);</li>
 *   <li>otherwise {@link BranchModel#DEFAULT} is returned.</li>
 * </ol>
 * Clade-to-node mappings are rebuilt lazily after tree changes and state restores.
 */
public class BranchSpecificBranchModel
extends AbstractModel
implements BranchModel {
    private final TreeModel treeModel;
    /** Registered clades, keyed by the tip bit set returned from {@code TreeUtils.getTipsBitSetForTaxa}. */
    protected Map<BitSet, Clade> clades = new HashMap<BitSet, Clade>();
    /** Set whenever {@link #nodeMap} may be out of date with the tree or the registered clades. */
    private boolean updateNodeMaps = true;
    /** Clade-derived mappings; rebuilt from scratch by {@link #setupNodeMaps()}. */
    private final Map<NodeRef, BranchModel.Mapping> nodeMap = new HashMap<NodeRef, BranchModel.Mapping>();
    /** External-branch mappings; consulted before {@link #nodeMap}. */
    private final Map<NodeRef, BranchModel.Mapping> externalNodeMap = new HashMap<NodeRef, BranchModel.Mapping>();
    private final SubstitutionModel rootSubstitutionModel;
    private final List<SubstitutionModel> substitutionModels = new ArrayList<SubstitutionModel>();
    private boolean requiresMatrixConvolution = false;

    /**
     * Creates a branch model whose default (index 0) and root model is {@code substitutionModel}.
     *
     * @param treeModel         the tree whose branches are mapped; listened to for changes
     * @param substitutionModel the root substitution model
     */
    public BranchSpecificBranchModel(TreeModel treeModel, SubstitutionModel substitutionModel) {
        // NOTE(review): the model id looks copy-pasted from LocalClockModel; kept unchanged
        // because ids may be referenced elsewhere.
        super("localClockModel");
        this.treeModel = treeModel;
        this.addModel(treeModel);
        this.rootSubstitutionModel = substitutionModel;
        this.addModel(substitutionModel);
        this.substitutionModels.add(substitutionModel);
    }

    /**
     * Assigns {@code substitutionModel} to every branch below the MRCA of {@code taxonList}.
     *
     * @param taxonList         taxa defining the clade
     * @param substitutionModel model for the clade's branches; registered if not seen before
     * @param d                 stem weight in [0, 1]: the fraction of the stem branch evolving under
     *                          the clade's model (0 = stem uses the enclosing model only)
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     * @throws IllegalArgumentException        if {@code d} is not in [0, 1]
     */
    public void addClade(TaxonList taxonList, SubstitutionModel substitutionModel, double d) throws TreeUtils.MissingTaxonException {
        // The negated form also rejects NaN; weights outside [0, 1] would yield a negative 1 - d.
        if (!(d >= 0.0 && d <= 1.0)) {
            throw new IllegalArgumentException("Stem weight must be in [0, 1], got " + d);
        }
        int index = this.registerSubstitutionModel(substitutionModel);
        BitSet tips = TreeUtils.getTipsBitSetForTaxa(this.treeModel, taxonList);
        this.clades.put(tips, new Clade(index, tips, d));
        // Only a strictly fractional stem weight produces a two-model (convolved) mapping.
        if (d > 0.0 && d < 1.0) {
            this.requiresMatrixConvolution = true;
        }
        // Make sure a clade added after a previous rebuild is taken into account.
        this.updateNodeMaps = true;
    }

    /**
     * Assigns {@code substitutionModel} to external (tip) branches.
     *
     * @param taxonList         NOTE(review): currently not consulted - every external branch of the
     *                          tree is mapped; confirm this is the intended contract with callers
     * @param substitutionModel model for the external branches; registered if not seen before
     * @throws TreeUtils.MissingTaxonException declared for API compatibility
     */
    public void addExternalBranches(TaxonList taxonList, SubstitutionModel substitutionModel) throws TreeUtils.MissingTaxonException {
        BranchModel.Mapping mapping = singleModelMapping(this.registerSubstitutionModel(substitutionModel));
        for (int i = 0; i < this.treeModel.getExternalNodeCount(); ++i) {
            this.externalNodeMap.put(this.treeModel.getExternalNode(i), mapping);
        }
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    public void addBackbone(TaxonList taxonList, SubstitutionModel substitutionModel) throws TreeUtils.MissingTaxonException {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /**
     * Returns the mapping for the branch keyed by {@code nodeRef}, rebuilding clade mappings first
     * if they are stale. External-branch assignments take precedence over clade assignments;
     * unassigned branches get {@link BranchModel#DEFAULT}.
     */
    @Override
    public BranchModel.Mapping getBranchModelMapping(NodeRef nodeRef) {
        if (this.updateNodeMaps) {
            this.setupNodeMaps();
        }
        BranchModel.Mapping mapping = this.externalNodeMap.get(nodeRef);
        if (mapping != null) {
            return mapping;
        }
        mapping = this.nodeMap.get(nodeRef);
        if (mapping != null) {
            return mapping;
        }
        return BranchModel.DEFAULT;
    }

    /** Returns the registered models; index 0 is the root model. */
    @Override
    public List<SubstitutionModel> getSubstitutionModels() {
        return this.substitutionModels;
    }

    @Override
    public SubstitutionModel getRootSubstitutionModel() {
        return this.rootSubstitutionModel;
    }

    @Override
    public FrequencyModel getRootFrequencyModel() {
        return this.getRootSubstitutionModel().getFrequencyModel();
    }

    /** True once any clade with a stem weight strictly between 0 and 1 has been added. */
    @Override
    public boolean requiresMatrixConvolution() {
        return this.requiresMatrixConvolution;
    }

    /** Marks clade mappings stale on tree changes and propagates every change to listeners. */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == this.treeModel && this.clades.size() > 0) {
            this.updateNodeMaps = true;
        }
        this.fireModelChanged();
    }

    @Override
    protected final void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
    }

    @Override
    protected void storeState() {
    }

    /** The tree may have been restored to a different topology, so rebuild mappings lazily. */
    @Override
    protected void restoreState() {
        if (this.clades.size() > 0) {
            this.updateNodeMaps = true;
        }
    }

    @Override
    protected void acceptState() {
    }

    /**
     * Returns the index of {@code substitutionModel}, appending it (and listening to it) if it
     * has not been registered yet.
     */
    private int registerSubstitutionModel(SubstitutionModel substitutionModel) {
        int index = this.substitutionModels.indexOf(substitutionModel);
        if (index == -1) {
            index = this.substitutionModels.size();
            this.substitutionModels.add(substitutionModel);
            this.addModel(substitutionModel);
        }
        return index;
    }

    /** Rebuilds {@link #nodeMap} from scratch for the current tree topology. */
    private void setupNodeMaps() {
        // Drop mappings from the previous topology: stale entries would otherwise survive
        // topology changes and be misread as the "outer" model of a stem branch.
        this.nodeMap.clear();
        if (this.clades.size() > 0) {
            Map<NodeRef, Clade> cladeRoots = new HashMap<NodeRef, Clade>();
            this.findCladeRoots(this.treeModel, this.treeModel.getRoot(), new BitSet(), cladeRoots);
            this.assignCladeMappings(this.treeModel, this.treeModel.getRoot(), cladeRoots, null);
        }
        this.updateNodeMaps = false;
    }

    /**
     * Post-order pass: accumulates the tips below {@code node} into {@code tips} and records every
     * internal node whose tip set matches a registered clade.
     */
    private void findCladeRoots(Tree tree, NodeRef node, BitSet tips, Map<NodeRef, Clade> cladeRoots) {
        if (tree.isExternal(node)) {
            tips.set(node.getNumber());
            return; // only internal nodes can be clade roots
        }
        for (int i = 0; i < tree.getChildCount(node); ++i) {
            // Each child needs its own set so its clade lookup sees only its own tips.
            BitSet childTips = new BitSet();
            this.findCladeRoots(tree, tree.getChild(node, i), childTips, cladeRoots);
            tips.or(childTips);
        }
        Clade clade = this.clades.get(tips);
        if (clade != null) {
            cladeRoots.put(node, clade);
        }
    }

    /**
     * Pre-order pass: maps each branch to the innermost clade containing it, so nested clades
     * override their enclosing clade rather than being overwritten by it.
     *
     * @param enclosing innermost clade strictly containing the branch above {@code node}, or null
     */
    private void assignCladeMappings(Tree tree, NodeRef node, Map<NodeRef, Clade> cladeRoots, Clade enclosing) {
        Clade clade = cladeRoots.get(node);
        if (clade != null && clade.getStemWeight() > 0.0) {
            // Stem branch: mix the clade's model with the enclosing one (index 0 = root model).
            int outerIndex = enclosing != null ? enclosing.getIndex() : 0;
            this.nodeMap.put(node, stemMapping(clade.getIndex(), outerIndex, clade.getStemWeight()));
        } else if (enclosing != null) {
            this.nodeMap.put(node, singleModelMapping(enclosing.getIndex()));
        }
        if (!tree.isExternal(node)) {
            Clade inherited = clade != null ? clade : enclosing;
            for (int i = 0; i < tree.getChildCount(node); ++i) {
                this.assignCladeMappings(tree, tree.getChild(node, i), cladeRoots, inherited);
            }
        }
    }

    /**
     * Mapping for a clade's stem branch. A weight of 1 (or more) uses the clade's model alone, so
     * two-model mappings only arise for fractional weights - matching requiresMatrixConvolution.
     */
    private static BranchModel.Mapping stemMapping(int cladeIndex, int outerIndex, double stemWeight) {
        if (stemWeight >= 1.0) {
            return singleModelMapping(cladeIndex);
        }
        return fixedMapping(new int[]{cladeIndex, outerIndex}, new double[]{stemWeight, 1.0 - stemWeight});
    }

    /** Mapping that uses a single model with weight 1. */
    private static BranchModel.Mapping singleModelMapping(int index) {
        return fixedMapping(new int[]{index}, new double[]{1.0});
    }

    /** Immutable mapping; returns fresh copies so callers cannot alter the shared arrays. */
    private static BranchModel.Mapping fixedMapping(final int[] order, final double[] weights) {
        return new BranchModel.Mapping(){

            @Override
            public int[] getOrder() {
                return order.clone();
            }

            @Override
            public double[] getWeights() {
                return weights.clone();
            }
        };
    }

    /** A registered clade: model index, defining tip set and stem weight. */
    private static class Clade {
        private final int index;
        private final BitSet tips;
        private final double stemWeight;

        Clade(int n, BitSet bitSet, double d) {
            this.index = n;
            this.tips = bitSet;
            this.stemWeight = d;
        }

        public int getIndex() {
            return this.index;
        }

        public BitSet getTips() {
            return this.tips;
        }

        public double getStemWeight() {
            return this.stemWeight;
        }
    }
}

