/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.mcmc;

import dr.inference.loggers.Logger;
import dr.inference.loggers.MCLogger;
import dr.inference.markovchain.MarkovChain;
import dr.inference.markovchain.MarkovChainListener;
import dr.inference.mcmc.MCMC;
import dr.inference.mcmc.MCMCCriterion;
import dr.inference.model.Model;
import dr.inference.model.PathLikelihood;
import dr.inference.operators.CombinedOperatorSchedule;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.OperatorAnalysisPrinter;
import dr.inference.operators.OperatorSchedule;
import dr.inference.operators.PathDependent;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.Identifiable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.BetaDistributionImpl;

/**
 * Marginal likelihood estimation by path sampling / stepping-stone sampling.
 * <p>
 * A single {@link MarkovChain} over a {@link PathLikelihood} is run repeatedly, once for every path
 * parameter value produced by an {@link Integrator}. For each value the chain is run for
 * {@code burnin} iterations starting from length 0, and then for {@code chainLength} iterations
 * continuing from where the previous path step stopped. The registered chain listener forwards to
 * the {@link MCLogger}s only those states whose index is at least {@code burnin}.
 * The sequence of path parameter values is selected by {@link PathScheme}.
 */
public class MarginalLikelihoodEstimator implements Runnable, Identifiable, Citable {

    public static final String MARGINAL_LIKELIHOOD_ESTIMATOR = "marginalLikelihoodEstimator";
    public static final String CHAIN_LENGTH = "chainLength";
    public static final String PATH_STEPS = "pathSteps";
    public static final String FIXED = "fixed";
    public static final String LINEAR = "linear";
    public static final String LACING = "lacing";
    public static final String SPAWN = "spawn";
    public static final String BURNIN = "burnin";
    public static final String MCMC = "samplers";
    public static final String PATH_SCHEME = "pathScheme";
    public static final String FIXED_VALUE = "fixedValues";
    public static final String ALPHA = "alpha";
    public static final String BETA = "beta";
    public static final String PRERUN = "prerun";
    public static final String PRINT_OPERATOR_ANALYSIS = "printOperatorAnalysis";

    /**
     * Whether {@link #integrate} prints an operator analysis after every path step.
     * This is process-wide state: the parser overwrites it whenever an element carries the
     * {@code printOperatorAnalysis} attribute.
     */
    private static boolean SHOW_OPERATOR_ANALYSIS = false;

    /** XML parser for the {@code marginalLikelihoodEstimator} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        // NOTE(review): the "lacing" attribute is accepted here but never read by parseXMLObject.
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newIntegerRule(MarginalLikelihoodEstimator.CHAIN_LENGTH),
                AttributeRule.newIntegerRule(MarginalLikelihoodEstimator.PATH_STEPS, true),
                AttributeRule.newIntegerRule(MarginalLikelihoodEstimator.BURNIN, true),
                AttributeRule.newIntegerRule(MarginalLikelihoodEstimator.PRERUN, true),
                AttributeRule.newBooleanRule(MarginalLikelihoodEstimator.LINEAR, true),
                AttributeRule.newBooleanRule(MarginalLikelihoodEstimator.LACING, true),
                AttributeRule.newBooleanRule(MarginalLikelihoodEstimator.SPAWN, true),
                AttributeRule.newBooleanRule(MarginalLikelihoodEstimator.PRINT_OPERATOR_ANALYSIS, true),
                AttributeRule.newStringRule(MarginalLikelihoodEstimator.PATH_SCHEME, true),
                AttributeRule.newDoubleArrayRule(MarginalLikelihoodEstimator.FIXED_VALUE, true),
                AttributeRule.newDoubleRule(MarginalLikelihoodEstimator.ALPHA, true),
                AttributeRule.newDoubleRule(MarginalLikelihoodEstimator.BETA, true),
                new ElementRule(MarginalLikelihoodEstimator.MCMC,
                        new XMLSyntaxRule[]{new ElementRule(MCMC.class, 1, Integer.MAX_VALUE)}, false),
                new ElementRule(PathLikelihood.class),
                new ElementRule(MCLogger.class, 1, Integer.MAX_VALUE)
        };

        @Override
        public String getParserName() {
            return MarginalLikelihoodEstimator.MARGINAL_LIKELIHOOD_ESTIMATOR;
        }

        /**
         * Builds a {@link MarginalLikelihoodEstimator} from its XML element.
         * <p>
         * All schedule attributes are validated before any sampler is pre-run, so configuration
         * errors are reported before any expensive equilibration work.
         *
         * @throws XMLParseException if neither {@code pathSteps} nor {@code fixedValues} is given,
         *                           if {@code pathScheme} is unknown, or if the chosen scheme
         *                           lacks the attribute it needs
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            PathLikelihood pathLikelihood = (PathLikelihood) xo.getChild(PathLikelihood.class);
            List<MCLogger> loggers = collectLoggers(xo);

            int chainLength = xo.getIntegerAttribute(MarginalLikelihoodEstimator.CHAIN_LENGTH);

            boolean hasPathSteps = xo.hasAttribute(MarginalLikelihoodEstimator.PATH_STEPS);
            boolean hasFixedValues = xo.hasAttribute(MarginalLikelihoodEstimator.FIXED_VALUE);
            if (!hasPathSteps && !hasFixedValues) {
                throw new XMLParseException("Either a number of path steps or predefined beta values need to be provided.");
            }
            // -1 / null mark "not supplied"; which one is required depends on the path scheme.
            int pathSteps = hasPathSteps ? xo.getIntegerAttribute(MarginalLikelihoodEstimator.PATH_STEPS) : -1;
            double[] fixedValues = hasFixedValues ? xo.getDoubleArrayAttribute(MarginalLikelihoodEstimator.FIXED_VALUE) : null;

            if (xo.hasAttribute(MarginalLikelihoodEstimator.PRINT_OPERATOR_ANALYSIS)) {
                SHOW_OPERATOR_ANALYSIS = xo.getBooleanAttribute(MarginalLikelihoodEstimator.PRINT_OPERATOR_ANALYSIS);
            }

            // -1 asks the estimator to fall back to its default burnin (10% of chainLength).
            int burnin = xo.hasAttribute(MarginalLikelihoodEstimator.BURNIN)
                    ? xo.getIntegerAttribute(MarginalLikelihoodEstimator.BURNIN) : -1;
            int prerun = xo.hasAttribute(MarginalLikelihoodEstimator.PRERUN)
                    ? xo.getIntegerAttribute(MarginalLikelihoodEstimator.PRERUN) : -1;

            PathScheme pathScheme = readPathScheme(xo);
            validateSchedule(pathScheme, pathSteps, fixedValues);

            CombinedOperatorSchedule schedule = buildOperatorSchedule(xo, prerun);
            if (schedule.getScheduleCount() == 0) {
                System.err.println("Error: no mcmc objects provided in construction. Bayes Factor estimation will likely fail.");
            }

            MarginalLikelihoodEstimator estimator = new MarginalLikelihoodEstimator(
                    MarginalLikelihoodEstimator.MARGINAL_LIKELIHOOD_ESTIMATOR, chainLength, burnin, pathSteps,
                    fixedValues, pathScheme, pathLikelihood, schedule, loggers);

            if (!xo.getAttribute(MarginalLikelihoodEstimator.SPAWN, true).booleanValue()) {
                estimator.setSpawnable(false);
            }
            if (xo.hasAttribute(MarginalLikelihoodEstimator.ALPHA)) {
                estimator.setAlphaFactor(xo.getAttribute(MarginalLikelihoodEstimator.ALPHA, 0.5));
            }
            if (xo.hasAttribute(MarginalLikelihoodEstimator.BETA)) {
                estimator.setBetaFactor(xo.getAttribute(MarginalLikelihoodEstimator.BETA, 0.5));
            }

            java.util.logging.Logger.getLogger("dr.inference").info("\nCreating the Marginal Likelihood Estimator chain:\n  chainLength=" + chainLength + "\n  pathSteps=" + pathSteps + "\n  pathScheme=" + pathScheme.getText() + describeSchemeParameters(pathScheme, estimator));
            return estimator;
        }

        /** Collects every direct child of {@code xo} that is an {@link MCLogger}, in document order. */
        private List<MCLogger> collectLoggers(XMLObject xo) throws XMLParseException {
            List<MCLogger> loggers = new ArrayList<MCLogger>();
            for (int i = 0; i < xo.getChildCount(); i++) {
                Object child = xo.getChild(i);
                if (child instanceof MCLogger) {
                    loggers.add((MCLogger) child);
                }
            }
            return loggers;
        }

        /**
         * Resolves the path scheme. An explicit {@code pathScheme} attribute wins. Otherwise
         * {@code linear} (default true) selects LINEAR, and false selects GEOMETRIC.
         */
        private PathScheme readPathScheme(XMLObject xo) throws XMLParseException {
            boolean linear = xo.getAttribute(MarginalLikelihoodEstimator.LINEAR, true).booleanValue();
            if (!xo.hasAttribute(MarginalLikelihoodEstimator.PATH_SCHEME)) {
                return linear ? PathScheme.LINEAR : PathScheme.GEOMETRIC;
            }
            String schemeName = xo.getAttribute(MarginalLikelihoodEstimator.PATH_SCHEME, PathScheme.LINEAR.getText());
            PathScheme scheme = PathScheme.parseFromString(schemeName);
            if (scheme == null) {
                throw new XMLParseException("Unknown " + MarginalLikelihoodEstimator.PATH_SCHEME + " '" + schemeName
                        + "'; expected one of: " + PathScheme.listTexts());
            }
            return scheme;
        }

        /**
         * Checks that the selected scheme has the attribute it needs. The fixed scheme needs
         * {@code fixedValues}; every other scheme needs a positive {@code pathSteps}.
         */
        private void validateSchedule(PathScheme pathScheme, int pathSteps, double[] fixedValues) throws XMLParseException {
            if (pathScheme == PathScheme.FIXED) {
                if (fixedValues == null) {
                    throw new XMLParseException("The '" + pathScheme.getText() + "' " + MarginalLikelihoodEstimator.PATH_SCHEME
                            + " requires the '" + MarginalLikelihoodEstimator.FIXED_VALUE + "' attribute.");
                }
            } else if (pathSteps < 1) {
                throw new XMLParseException("The '" + pathScheme.getText() + "' " + MarginalLikelihoodEstimator.PATH_SCHEME
                        + " requires a positive '" + MarginalLikelihoodEstimator.PATH_STEPS + "' attribute (or set "
                        + MarginalLikelihoodEstimator.PATH_SCHEME + "=\"" + PathScheme.FIXED.getText() + "\" to use '"
                        + MarginalLikelihoodEstimator.FIXED_VALUE + "').");
            }
        }

        /**
         * Combines the operator schedules of all MCMC samplers under the {@code samplers} element.
         * When {@code prerun > 0}, each sampler's own loggers are stopped and its chain is first run
         * for {@code prerun} iterations. If an {@link OperatorSchedule} child is present, it is
         * added once per sampler in place of that sampler's own schedule.
         */
        private CombinedOperatorSchedule buildOperatorSchedule(XMLObject xo, int prerun) throws XMLParseException {
            CombinedOperatorSchedule combined = new CombinedOperatorSchedule();
            OperatorSchedule explicitSchedule = (OperatorSchedule) xo.getChild(OperatorSchedule.class);
            XMLObject samplers = (XMLObject) xo.getChild(MarginalLikelihoodEstimator.MCMC);

            for (int i = 0; i < samplers.getChildCount(); i++) {
                Object child = samplers.getChild(i);
                if (!(child instanceof MCMC)) {
                    continue;
                }
                MCMC sampler = (MCMC) child;
                if (prerun > 0) {
                    java.util.logging.Logger.getLogger("dr.inference").info("Path Sampling Marginal Likelihood Estimator:\n\tEquilibrating chain " + sampler.getId() + " for " + prerun + " iterations.");
                    for (Logger logger : sampler.getLoggers()) {
                        logger.stopLogging();
                    }
                    sampler.getMarkovChain().runChain(prerun, false);
                }
                combined.addOperatorSchedule(explicitSchedule != null ? explicitSchedule : sampler.getOperatorSchedule());
            }
            return combined;
        }

        /** Returns the scheme's shape parameters formatted for the creation log message ("" if none). */
        private String describeSchemeParameters(PathScheme pathScheme, MarginalLikelihoodEstimator estimator) {
            switch (pathScheme) {
                case ONE_SIDED_BETA:
                    return "(1," + estimator.getBetaFactor() + ")";
                case BETA:
                    return "(" + estimator.getAlphaFactor() + "," + estimator.getBetaFactor() + ")";
                case BETA_QUANTILE:
                case SIGMOID:
                    return "(" + estimator.getAlphaFactor() + ")";
                default:
                    return "";
            }
        }

        @Override
        public String getParserDescription() {
            return "This element returns an MCMC chain and runs the chain as a side effect.";
        }

        // NOTE(review): parseXMLObject actually returns a MarginalLikelihoodEstimator; MCMC.class is
        // kept for compatibility with existing parser tooling — confirm before changing.
        @Override
        public Class getReturnType() {
            return MCMC.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    private final MarkovChain mc;
    private final OperatorSchedule schedule;
    private String id;
    /** Index of the most recent state reported by the chain. */
    private long currentState;
    private final long chainLength;
    /** Effective burnin for the current integration; derived from {@link #burninLength}. */
    private long burnin;
    /** Configured burnin; any negative value means "10% of chainLength". */
    private final long burninLength;
    private final int pathSteps;
    private final PathScheme scheme;
    private double alphaFactor = 0.5;
    private double betaFactor = 0.5;
    private final double[] fixedRunValues;
    private double pathParameter;
    private final List<MCLogger> loggers;
    private final PathLikelihood pathLikelihood;
    private boolean spawnable = true;

    /**
     * Tracks the chain position. It forwards post-burnin states to the loggers and, when the chain
     * finishes, prints an operator analysis and stops every logger.
     */
    private final MarkovChainListener chainListener = new MarkovChainListener() {

        @Override
        public void currentState(long state, MarkovChain markovChain, Model model) {
            currentState = state;
            if (currentState >= burnin) {
                for (MCLogger logger : loggers) {
                    logger.log(state);
                }
            }
        }

        @Override
        public void bestState(long state, MarkovChain markovChain, Model model) {
            currentState = state;
        }

        @Override
        public void finished(long state, MarkovChain markovChain) {
            currentState = state;
            OperatorAnalysisPrinter.showOperatorAnalysis(System.out, schedule, false);
            for (MCLogger logger : loggers) {
                logger.stopLogging();
            }
        }
    };

    /**
     * Creates the estimator and its underlying Markov chain. The path parameter starts at 1.0.
     *
     * @param id             identifier of this estimator
     * @param chainLength    iterations sampled (after burnin) at each path parameter value
     * @param burninLength   burnin iterations per path step; negative means 10% of chainLength
     * @param pathSteps      number of steps for the non-fixed schemes
     * @param fixedValues    path parameter values for {@link PathScheme#FIXED} (may be null otherwise)
     * @param scheme         how the path parameter values are generated
     * @param pathLikelihood likelihood whose path parameter is varied
     * @param schedule       operator schedule driving the chain; only a
     *                       {@link CombinedOperatorSchedule} gets reset between path steps
     * @param loggers        loggers that receive post-burnin states
     */
    public MarginalLikelihoodEstimator(String id, int chainLength, int burninLength, int pathSteps, double[] fixedValues,
                                       PathScheme scheme, PathLikelihood pathLikelihood,
                                       OperatorSchedule schedule, List<MCLogger> loggers) {
        this.id = id;
        this.chainLength = chainLength;
        this.burninLength = burninLength;
        this.pathSteps = pathSteps;
        this.fixedRunValues = fixedValues;
        this.scheme = scheme;
        this.schedule = schedule;
        this.loggers = loggers;
        this.pathLikelihood = pathLikelihood;

        this.pathParameter = 1.0;
        pathLikelihood.setPathParameter(this.pathParameter);

        this.mc = new MarkovChain(pathLikelihood, schedule, new MCMCCriterion(), 0L, 0, 0.0, true, false);
    }

    /** Sets {@link #burnin} from the configured length, using 10% of chainLength when it is negative. */
    private void setDefaultBurnin() {
        this.burnin = burninLength < 0 ? (long) (int) (0.1 * (double) chainLength) : burninLength;
    }

    /**
     * Runs the chain at every path parameter value produced by {@code integrator}, stopping at the
     * first negative value.
     * <p>
     * For each value: the path likelihood and all {@link PathDependent} operators are updated; the
     * chain is run for {@code burnin} iterations starting from length 0; then it is restored to the
     * length reached by the previous step and run for {@code chainLength} more iterations.
     */
    public void integrate(Integrator integrator) {
        setDefaultBurnin();
        mc.setCurrentLength(burnin);
        integrator.init();
        resetSchedule();

        pathParameter = integrator.nextPathParameter();
        while (pathParameter >= 0.0) {
            pathLikelihood.setPathParameter(pathParameter);
            reportIteration(pathParameter, chainLength, burnin, integrator.pathSteps, integrator.step);
            updatePathDependentOperators(pathParameter);

            // Burnin restarts the length counter at 0 (keeping those states below the logging
            // threshold); sampling then resumes from where the previous path step stopped.
            long resumeLength = mc.getCurrentLength();
            mc.setCurrentLength(0L);
            mc.runChain(burnin, false);
            mc.setCurrentLength(resumeLength);
            mc.runChain(chainLength, false);

            if (SHOW_OPERATOR_ANALYSIS) {
                OperatorAnalysisPrinter.showOperatorAnalysis(System.out, schedule, false);
            }
            resetSchedule();
            pathParameter = integrator.nextPathParameter();
        }
    }

    /** Pushes the current path parameter into every operator that depends on it. */
    private void updatePathDependentOperators(double theta) {
        for (int i = 0; i < schedule.getOperatorCount(); i++) {
            MCMCOperator operator = schedule.getOperator(i);
            if (operator instanceof PathDependent) {
                ((PathDependent) operator).setPathParameter(theta);
            }
        }
    }

    /**
     * Resets operator statistics between path steps. Only {@link CombinedOperatorSchedule} exposes
     * a reset here; any other schedule type is left untouched instead of failing with a
     * ClassCastException.
     */
    private void resetSchedule() {
        if (schedule instanceof CombinedOperatorSchedule) {
            ((CombinedOperatorSchedule) schedule).reset();
        }
    }

    /** Prints a one-line progress message for the path step that is about to run. */
    private void reportIteration(double theta, long iterations, long burninIterations, long totalSteps, long currentStep) {
        if (scheme == PathScheme.FIXED) {
            System.out.println("Attempting fixed theta (" + currentStep + "/" + totalSteps + ") = " + theta + " for " + iterations + " iterations + " + burninIterations + " burnin.");
        } else {
            System.out.println("Attempting theta (" + currentStep + "/" + (totalSteps + 1L) + ") = " + theta + " for " + iterations + " iterations + " + burninIterations + " burnin.");
        }
    }

    /**
     * Starts all loggers, then integrates along the path selected by {@link #scheme}. The chain
     * listener is always detached afterwards, even if integration throws.
     */
    @Override
    public void run() {
        for (MCLogger logger : loggers) {
            logger.startLogging();
        }
        mc.addMarkovChainListener(chainListener);
        try {
            integrate(createIntegrator());
        } finally {
            mc.removeMarkovChainListener(chainListener);
        }
    }

    /** Instantiates the integrator matching {@link #scheme}. */
    private Integrator createIntegrator() {
        switch (scheme) {
            case FIXED:
                return new FixedThetaRun(fixedRunValues);
            case LINEAR:
                return new LinearIntegrator(pathSteps);
            case GEOMETRIC:
                return new GeometricIntegrator(pathSteps);
            case ONE_SIDED_BETA:
                return new BetaIntegrator(1.0, betaFactor, pathSteps);
            case BETA:
                return new BetaIntegrator(alphaFactor, betaFactor, pathSteps);
            case BETA_QUANTILE:
                return new BetaQuantileIntegrator(alphaFactor, pathSteps);
            case SIGMOID:
                return new SigmoidIntegrator(alphaFactor, pathSteps);
            default:
                throw new RuntimeException("Illegal path scheme");
        }
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.FRAMEWORK;
    }

    @Override
    public String getDescription() {
        return "Marginal likelihood estimation using path sampling / stepping-stone sampling (first 2 citations) and generalized stepping-stone sampling (3rd citation)";
    }

    @Override
    public List<Citation> getCitations() {
        return Arrays.asList(
                new Citation(
                        new Author[]{new Author("G", "Baele"), new Author("P", "Lemey"), new Author("T", "Bedford"),
                                new Author("A", "Rambaut"), new Author("MA", "Suchard"), new Author("AV", "Alekseyenko")},
                        "Improving the accuracy of demographic and molecular clock model comparison while accommodating phylogenetic uncertainty",
                        2012, "Mol. Biol. Evol.", 29, 2157, 2167, Citation.Status.PUBLISHED),
                new Citation(
                        new Author[]{new Author("G", "Baele"), new Author("WLS", "Li"), new Author("AJ", "Drummond"),
                                new Author("MA", "Suchard"), new Author("P", "Lemey")},
                        "Accurate model selection of relaxed molecular clocks in Bayesian phylogenetics",
                        2013, "Mol. Biol. Evol.", 30, 239, 243, Citation.Status.PUBLISHED),
                new Citation(
                        new Author[]{new Author("G", "Baele"), new Author("P", "Lemey"), new Author("MA", "Suchard")},
                        "Genealogical working distributions for Bayesian model testing with phylogenetic uncertainty",
                        2016, "Syst. Biol.", 65, 250, 264, Citation.Status.PUBLISHED));
    }

    public boolean getSpawnable() {
        return this.spawnable;
    }

    public void setSpawnable(boolean spawnable) {
        this.spawnable = spawnable;
    }

    public void setAlphaFactor(double alphaFactor) {
        this.alphaFactor = alphaFactor;
    }

    public void setBetaFactor(double betaFactor) {
        this.betaFactor = betaFactor;
    }

    public double getAlphaFactor() {
        return this.alphaFactor;
    }

    public double getBetaFactor() {
        return this.betaFactor;
    }

    @Override
    public String getId() {
        return this.id;
    }

    @Override
    public void setId(String id) {
        this.id = id;
    }

    /** Available schedules of path parameter values, keyed by their XML spelling. */
    static enum PathScheme {
        FIXED("fixed"),
        LINEAR("linear"),
        GEOMETRIC("geometric"),
        BETA("beta"),
        ONE_SIDED_BETA("oneSidedBeta"),
        BETA_QUANTILE("betaQuantile"),
        SIGMOID("sigmoid");

        private final String text;

        PathScheme(String text) {
            this.text = text;
        }

        public String getText() {
            return this.text;
        }

        /** Case-insensitive lookup by XML spelling; returns null when nothing matches. */
        public static PathScheme parseFromString(String text) {
            for (PathScheme candidate : values()) {
                if (candidate.getText().compareToIgnoreCase(text) == 0) {
                    return candidate;
                }
            }
            return null;
        }

        /** Comma-separated list of all XML spellings, for error messages. */
        static String listTexts() {
            StringBuilder names = new StringBuilder();
            for (PathScheme candidate : values()) {
                if (names.length() > 0) {
                    names.append(", ");
                }
                names.append(candidate.getText());
            }
            return names.toString();
        }
    }

    /**
     * Produces a sequence of path parameter values. {@link #nextPathParameter()} returns a negative
     * value once the sequence is exhausted.
     */
    public abstract class Integrator {
        /** Number of values handed out since {@link #init()}. */
        protected int step;
        protected int pathSteps;

        protected Integrator(int pathSteps) {
            this.pathSteps = pathSteps;
        }

        /** Rewinds the sequence to its first value. */
        public void init() {
            this.step = 0;
        }

        /** Returns the next path parameter, or a negative value when there are no more. */
        abstract double nextPathParameter();
    }

    /** Halving schedule: 1, 1/2, 1/4, ..., 2^-(pathSteps-1), then 0.0. */
    public class GeometricIntegrator extends Integrator {
        public GeometricIntegrator(int pathSteps) {
            super(pathSteps);
        }

        @Override
        double nextPathParameter() {
            if (step > pathSteps) {
                return -1.0;
            }
            if (step == pathSteps) {
                step++;
                return 0.0;
            }
            double theta = Math.pow(2.0, -step);
            step++;
            return theta;
        }
    }

    /**
     * Beta-quantile schedule. It returns 1.0 at step 0, then
     * {@code 1 - Q(step / (pathSteps - 1))} while {@code step + 1 < pathSteps} (Q = inverse CDF of
     * Beta(alpha, beta)), then 0.0 for the remaining steps up to and including {@code pathSteps}.
     * NOTE(review): this means 0.0 is produced twice for pathSteps >= 2 — confirm intended.
     */
    public class BetaIntegrator extends Integrator {
        private final BetaDistributionImpl betaDistribution;

        public BetaIntegrator(double alpha, double beta, int pathSteps) {
            super(pathSteps);
            this.betaDistribution = new BetaDistributionImpl(alpha, beta);
        }

        /**
         * @throws RuntimeException wrapping the {@link MathException} if the beta quantile cannot
         *                          be computed; continuing with a made-up value would corrupt the
         *                          path schedule
         */
        @Override
        double nextPathParameter() {
            if (step > pathSteps) {
                return -1.0;
            }
            if (step == 0) {
                step++;
                return 1.0;
            }
            if (step + 1 < pathSteps) {
                double ratio = (double) step / (double) (pathSteps - 1);
                double quantile;
                try {
                    quantile = betaDistribution.inverseCumulativeProbability(ratio);
                } catch (MathException e) {
                    throw new RuntimeException("Unable to compute the beta quantile at " + ratio
                            + " for path step " + step, e);
                }
                // Advance only once the value is known, so each call consumes exactly one step.
                step++;
                return 1.0 - quantile;
            }
            step++;
            return 0.0;
        }
    }

    /** Power schedule: {@code ((pathSteps - step) / pathSteps)^(1/alpha)} for step = 0..pathSteps. */
    public class BetaQuantileIntegrator extends Integrator {
        private final double alpha;

        public BetaQuantileIntegrator(double alpha, int pathSteps) {
            super(pathSteps);
            this.alpha = alpha;
        }

        @Override
        double nextPathParameter() {
            if (step > pathSteps) {
                return -1.0;
            }
            double theta = Math.pow((double) (pathSteps - step) / (double) pathSteps, 1.0 / alpha);
            step++;
            return theta;
        }
    }

    /**
     * Logistic schedule. It returns 1.0 at step 0 and 0.0 at step pathSteps. In between it returns
     * {@code e^(a x) / (e^(a x) + e^(-a x))} with {@code x = (pathSteps - step) / pathSteps - 0.5}.
     */
    public class SigmoidIntegrator extends Integrator {
        private final double alpha;

        public SigmoidIntegrator(double alpha, int pathSteps) {
            super(pathSteps);
            this.alpha = alpha;
        }

        @Override
        double nextPathParameter() {
            if (step == 0) {
                step++;
                return 1.0;
            }
            if (step == pathSteps) {
                step++;
                return 0.0;
            }
            if (step > pathSteps) {
                return -1.0;
            }
            double x = (double) (pathSteps - step) / (double) pathSteps - 0.5;
            step++;
            double up = Math.exp(alpha * x);
            if (Double.isInfinite(up)) {
                // e^(a x) overflowed; Inf / Inf would be NaN, but the curve has saturated at 1.
                return 1.0;
            }
            return up / (up + Math.exp(-alpha * x));
        }
    }

    /** Evenly spaced schedule: {@code 1 - step / pathSteps} for step = 0..pathSteps. */
    public class LinearIntegrator extends Integrator {
        public LinearIntegrator(int pathSteps) {
            super(pathSteps);
        }

        @Override
        double nextPathParameter() {
            if (step > pathSteps) {
                return -1.0;
            }
            double theta = 1.0 - (double) step / (double) pathSteps;
            step++;
            return theta;
        }
    }

    /**
     * Replays user-supplied path parameter values in order. A negative entry also ends the run,
     * because {@link #integrate} stops at the first negative value.
     */
    public class FixedThetaRun extends Integrator {
        private final double[] value;

        public FixedThetaRun(double[] values) {
            super(values.length);
            this.value = values;
        }

        @Override
        double nextPathParameter() {
            if (step < value.length) {
                return value[step++];
            }
            return -1.0;
        }
    }
}

