/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.speciation;

import dr.evolution.tree.Tree;
import dr.evomodel.speciation.SpeciationLikelihood;
import dr.evomodel.speciation.SpeciationModelGradientProvider;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;

import java.util.Objects;

/**
 * Gradient of a {@link SpeciationLikelihood}'s log density with respect to one parameter
 * (or parameter group) selected by a {@link WrtParameter}.
 *
 * <p>The derivative computations themselves are delegated to the
 * {@link SpeciationModelGradientProvider} returned by
 * {@link SpeciationLikelihood#getGradientProvider()}. The class is also {@link Reportable}
 * and {@link Loggable}; both report the result of
 * {@link GradientWrtParameterProvider#getReportAndCheckForError}.
 */
public class SpeciationLikelihoodGradient
        implements GradientWrtParameterProvider, Reportable, Loggable {

    private final SpeciationLikelihood likelihood;
    private final Parameter parameter;
    private final WrtParameter wrtParameter;
    private final TreeModel tree;
    private final SpeciationModelGradientProvider provider;

    /**
     * Creates a gradient for {@code speciationLikelihood} with respect to the parameter
     * selected by {@code wrtParameter}.
     *
     * @param speciationLikelihood the likelihood whose gradient provider performs the computation
     * @param treeModel            the tree passed to the gradient provider
     * @param wrtParameter         which parameter the gradient is taken with respect to
     * @throws NullPointerException     if {@code speciationLikelihood} or {@code wrtParameter} is null
     * @throws IllegalArgumentException if the likelihood does not supply a gradient provider
     */
    public SpeciationLikelihoodGradient(SpeciationLikelihood speciationLikelihood, TreeModel treeModel, WrtParameter wrtParameter) {
        this.likelihood = Objects.requireNonNull(speciationLikelihood, "speciationLikelihood");
        this.tree = treeModel;
        this.wrtParameter = Objects.requireNonNull(wrtParameter, "wrtParameter");
        this.provider = speciationLikelihood.getGradientProvider();
        if (this.provider == null) {
            // Every WrtParameter dereferences the provider (when building the parameter or when
            // evaluating the gradient); fail here with a clear message instead of a later NPE.
            throw new IllegalArgumentException(
                    "SpeciationLikelihood does not supply a SpeciationModelGradientProvider; "
                            + "cannot compute gradient with respect to " + wrtParameter);
        }
        this.parameter = wrtParameter.getParameter(this.provider, treeModel);
    }

    /** @return the speciation likelihood this gradient belongs to */
    @Override
    public Likelihood getLikelihood() {
        return likelihood;
    }

    /** @return the parameter the gradient is taken with respect to */
    @Override
    public Parameter getParameter() {
        return parameter;
    }

    /** @return the dimension of {@link #getParameter()} */
    @Override
    public int getDimension() {
        return parameter.getDimension();
    }

    /** @return the gradient of the log density, as computed by the selected {@link WrtParameter} */
    @Override
    public double[] getGradientLogDensity() {
        return wrtParameter.getGradientLogDensity(provider, tree);
    }

    /** @return the tree model passed at construction */
    public TreeModel getTree() {
        return tree;
    }

    /** @return log columns derived from {@link #getReport()} */
    @Override
    public LogColumn[] getColumns() {
        return Loggable.getColumnsFromReport(this, "SpeciationLikelihoodGradient check");
    }

    /**
     * Returns the report produced by
     * {@link GradientWrtParameterProvider#getReportAndCheckForError}.
     *
     * <p>NOTE(review): the arguments {@code 0.0}, {@code +Infinity} and {@code 0.001} look like
     * lower/upper parameter bounds and a tolerance for the check. Confirm against that method.
     */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, 0.001);
    }

    /**
     * Selects the parameter the gradient is taken with respect to. Each constant knows how to
     * obtain that parameter and its gradient from a {@link SpeciationModelGradientProvider}.
     */
    public enum WrtParameter {
        NODE_HEIGHT("nodeHeight") {
            @Override
            double[] getGradientLogDensity(SpeciationModelGradientProvider provider, Tree tree) {
                final int internalCount = tree.getInternalNodeCount();
                final int externalCount = tree.getExternalNodeCount();
                double[] gradient = new double[internalCount];
                // Internal nodes are read at node indices externalCount .. externalCount + internalCount - 1.
                for (int i = 0; i < internalCount; ++i) {
                    gradient[i] = provider.getNodeHeightGradient(tree, tree.getNode(externalCount + i));
                }
                return gradient;
            }

            @Override
            Parameter getParameter(SpeciationModelGradientProvider provider, TreeModel treeModel) {
                return new NodeHeightProxyParameter("nodeHeightProxyParameter", treeModel, true);
            }

            @Override
            double[] filter(double[] gradient) {
                return gradient;
            }

            @Override
            double[] filter(double[] gradient, int count) {
                return gradient;
            }
        },

        BIRTH_RATE("birthRate") {
            @Override
            double[] getGradientLogDensity(SpeciationModelGradientProvider provider, Tree tree) {
                return provider.getBirthRateGradient(tree, null);
            }

            @Override
            Parameter getParameter(SpeciationModelGradientProvider provider, TreeModel treeModel) {
                return provider.getBirthRateParameter();
            }

            @Override
            double[] filter(double[] gradient) {
                return new double[]{gradient[0]};
            }

            @Override
            double[] filter(double[] gradient, int count) {
                return strided(gradient, count, 0);
            }
        },

        DEATH_RATE("deathRate") {
            @Override
            double[] getGradientLogDensity(SpeciationModelGradientProvider provider, Tree tree) {
                return provider.getDeathRateGradient(tree, null);
            }

            @Override
            Parameter getParameter(SpeciationModelGradientProvider provider, TreeModel treeModel) {
                return provider.getDeathRateParameter();
            }

            @Override
            double[] filter(double[] gradient) {
                return new double[]{gradient[1]};
            }

            @Override
            double[] filter(double[] gradient, int count) {
                return strided(gradient, count, 1);
            }
        },

        SAMPLING_RATE("samplingRate") {
            @Override
            double[] getGradientLogDensity(SpeciationModelGradientProvider provider, Tree tree) {
                return provider.getSamplingRateGradient(tree, null);
            }

            @Override
            Parameter getParameter(SpeciationModelGradientProvider provider, TreeModel treeModel) {
                return provider.getSamplingRateParameter();
            }

            @Override
            double[] filter(double[] gradient) {
                return new double[]{gradient[2]};
            }

            @Override
            double[] filter(double[] gradient, int count) {
                return strided(gradient, count, 2);
            }
        },

        SAMPLING_PROBABILITY("samplingProbability") {
            @Override
            double[] getGradientLogDensity(SpeciationModelGradientProvider provider, Tree tree) {
                return provider.getSamplingProbabilityGradient(tree, null);
            }

            @Override
            Parameter getParameter(SpeciationModelGradientProvider provider, TreeModel treeModel) {
                return provider.getSamplingProbabilityParameter();
            }

            @Override
            double[] filter(double[] gradient) {
                return new double[]{gradient[3]};
            }

            @Override
            double[] filter(double[] gradient, int count) {
                return strided(gradient, count, 3);
            }
        },

        TREATMENT_PROBABILITY("treatmentProbability") {
            @Override
            double[] getGradientLogDensity(SpeciationModelGradientProvider provider, Tree tree) {
                return provider.getTreatmentProbabilityGradient(tree, null);
            }

            @Override
            Parameter getParameter(SpeciationModelGradientProvider provider, TreeModel treeModel) {
                return provider.getTreatmentProbabilityParameter();
            }

            @Override
            double[] filter(double[] gradient) {
                return new double[]{gradient[4]};
            }

            @Override
            double[] filter(double[] gradient, int count) {
                return strided(gradient, count, 4);
            }
        },

        ALL("all") {
            @Override
            double[] getGradientLogDensity(SpeciationModelGradientProvider provider, Tree tree) {
                throw new UnsupportedOperationException("Not yet implemented");
            }

            @Override
            Parameter getParameter(SpeciationModelGradientProvider provider, TreeModel treeModel) {
                // Same order as the offsets used by the individual constants' filter methods.
                CompoundParameter all = new CompoundParameter("allSpeciationParameters");
                all.addParameter(provider.getBirthRateParameter());
                all.addParameter(provider.getDeathRateParameter());
                all.addParameter(provider.getSamplingRateParameter());
                all.addParameter(provider.getSamplingProbabilityParameter());
                all.addParameter(provider.getTreatmentProbabilityParameter());
                return all;
            }

            @Override
            double[] filter(double[] gradient) {
                return gradient;
            }

            @Override
            double[] filter(double[] gradient, int count) {
                return gradient;
            }
        };

        /**
         * Stride used by the strided {@code filter} overloads. NOTE(review): presumably the
         * input interleaves the five speciation parameters (birth, death, sampling rate,
         * sampling probability, treatment probability); confirm against callers.
         */
        private static final int INTERLEAVED_PARAMETER_COUNT = 5;

        /** Case-insensitive lookup key used by {@link #factory(String)}. */
        private final String name;

        WrtParameter(String name) {
            this.name = name;
        }

        /** Computes the gradient of the log density with respect to this constant's parameter. */
        abstract double[] getGradientLogDensity(SpeciationModelGradientProvider provider, Tree tree);

        /** Returns (or builds) the parameter this constant differentiates with respect to. */
        abstract Parameter getParameter(SpeciationModelGradientProvider provider, TreeModel treeModel);

        /** Extracts this constant's entry (or entries) from a single combined gradient vector. */
        abstract double[] filter(double[] gradient);

        /** Extracts this constant's {@code count} entries from an interleaved gradient vector. */
        abstract double[] filter(double[] gradient, int count);

        /**
         * Copies {@code count} entries of {@code values}, starting at {@code offset} and
         * stepping by {@link #INTERLEAVED_PARAMETER_COUNT}.
         */
        private static double[] strided(double[] values, int count, int offset) {
            double[] result = new double[count];
            for (int i = 0; i < count; ++i) {
                result[i] = values[i * INTERLEAVED_PARAMETER_COUNT + offset];
            }
            return result;
        }

        /**
         * Looks up a constant by its lookup name, ignoring case.
         *
         * @param string the name to look up, e.g. {@code "birthRate"}
         * @return the matching constant, or {@code null} if none matches (including a null name)
         */
        public static WrtParameter factory(String string) {
            for (WrtParameter candidate : values()) {
                if (candidate.name.equalsIgnoreCase(string)) {
                    return candidate;
                }
            }
            return null;
        }
    }
}

