/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.epidemiology.casetocase;

import dr.evomodel.coalescent.demographicmodel.DemographicModel;
import dr.evomodel.epidemiology.casetocase.AbstractCase;
import dr.evomodel.epidemiology.casetocase.AbstractOutbreak;
import dr.evomodel.epidemiology.casetocase.BadPartitionException;
import dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood;
import dr.evomodel.epidemiology.casetocase.CategoryOutbreak;
import dr.evomodel.epidemiology.casetocase.SpatialKernel;
import dr.evomodel.epidemiology.casetocase.periodpriors.AbstractPeriodPriorDistribution;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;

public class CaseToCaseTransmissionLikelihood
extends AbstractModelLikelihood
implements Loggable {
    // Compile-time debug switch (off).
    private static final boolean DEBUG = false;
    // Outbreak supplying the cases, their weights, infectious categories and kernel values.
    private CategoryOutbreak outbreak;
    // Tree likelihood supplying infection/infectious times, infectors and the tree log probability.
    private CaseToCaseTreeLikelihood treeLikelihood;
    // Optional spatial kernel; null when none was supplied to the constructor.
    private SpatialKernel spatialKernel;
    // Transmission rate parameter; only dimension 0 is read.
    private Parameter transmissionRate;
    // Cache-validity flags for the total and for each component, with the copies taken by storeState().
    private boolean likelihoodKnown;
    private boolean storedLikelihoodKnown;
    private boolean transProbKnown;
    private boolean storedTransProbKnown;
    private boolean periodsProbKnown;
    private boolean storedPeriodsProbKnown;
    private boolean treeProbKnown;
    private boolean storedTreeProbKnown;
    // Cached total log likelihood and its three components, with the copies taken by storeState().
    private double logLikelihood;
    private double storedLogLikelihood;
    private double transLogProb;
    private double storedTransLogProb;
    private double periodsLogProb;
    private double storedPeriodsLogProb;
    private double treeLogProb;
    private double storedTreeLogProb;
    // Optional prior on the time of the first infection; null when not supplied.
    private ParametricDistributionModel initialInfectionTimePrior;
    // Weight of each ever-infected case divided by the summed weight of all ever-infected cases.
    private HashMap<AbstractCase, Double> indexCasePrior;
    // True when a spatial kernel was supplied.
    private final boolean hasGeography;
    // Taken from treeLikelihood.hasLatentPeriods() at construction.
    private final boolean hasLatentPeriods;
    // Events ordered by time; set to null whenever they must be rebuilt by sortEvents().
    private ArrayList<TreeEvent> sortedTreeEvents;
    private ArrayList<TreeEvent> storedSortedTreeEvents;
    // Case of the earliest event in sortedTreeEvents; nulled together with that list.
    private AbstractCase indexCase;
    private AbstractCase storedIndexCase;
    // XML element name handled by PARSER; also passed as the model name by the parser.
    public static final String CASE_TO_CASE_TRANSMISSION_LIKELIHOOD = "caseToCaseTransmissionLikelihood";
    /**
     * XML parser that builds a {@link CaseToCaseTransmissionLikelihood} from a
     * {@code caseToCaseTransmissionLikelihood} element. The element must contain a
     * {@link CaseToCaseTreeLikelihood} and a {@code transmissionRate} parameter, and may
     * contain a {@link SpatialKernel} and an {@code initialInfectionTimePrior}.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        public static final String TRANSMISSION_RATE = "transmissionRate";
        public static final String INITIAL_INFECTION_TIME_PRIOR = "initialInfectionTimePrior";
        // Element names come from the constants above so the rules and parseXMLObject cannot drift apart.
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
            new ElementRule(CaseToCaseTreeLikelihood.class, "The tree likelihood"),
            new ElementRule(SpatialKernel.class, "The spatial kernel", 0, 1),
            new ElementRule(TRANSMISSION_RATE, Parameter.class, "The transmission rate"),
            new ElementRule(INITIAL_INFECTION_TIME_PRIOR, ParametricDistributionModel.class, "The prior probability distribution of the first infection", true)
        };

        /**
         * Reads the child elements and constructs the likelihood.
         *
         * @param xMLObject the {@code caseToCaseTransmissionLikelihood} element
         * @return a new {@link CaseToCaseTransmissionLikelihood}
         * @throws XMLParseException if a child element cannot be read
         */
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            CaseToCaseTreeLikelihood caseToCaseTreeLikelihood = (CaseToCaseTreeLikelihood)xMLObject.getChild(CaseToCaseTreeLikelihood.class);
            // The kernel is passed on as returned; the constructor accepts null for "no geography".
            SpatialKernel spatialKernel = (SpatialKernel)xMLObject.getChild(SpatialKernel.class);
            Parameter parameter = (Parameter)xMLObject.getElementFirstChild(TRANSMISSION_RATE);
            // The prior on the first infection time is optional and stays null when absent.
            ParametricDistributionModel parametricDistributionModel = null;
            if (xMLObject.hasChildNamed(INITIAL_INFECTION_TIME_PRIOR)) {
                parametricDistributionModel = (ParametricDistributionModel)xMLObject.getElementFirstChild(INITIAL_INFECTION_TIME_PRIOR);
            }
            return new CaseToCaseTransmissionLikelihood(CaseToCaseTransmissionLikelihood.CASE_TO_CASE_TRANSMISSION_LIKELIHOOD, (CategoryOutbreak)caseToCaseTreeLikelihood.getOutbreak(), caseToCaseTreeLikelihood, spatialKernel, parameter, parametricDistributionModel);
        }

        /** @return the syntax rules describing the permitted child elements */
        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        /** @return a human-readable description of the element */
        @Override
        public String getParserDescription() {
            return "This element represents a probability distribution for epidemiological parameters of an outbreak given a phylogenetic tree";
        }

        /** @return the class of the objects produced by this parser */
        @Override
        public Class getReturnType() {
            return CaseToCaseTransmissionLikelihood.class;
        }

        /** @return the XML element name handled by this parser */
        @Override
        public String getParserName() {
            return CaseToCaseTransmissionLikelihood.CASE_TO_CASE_TRANSMISSION_LIKELIHOOD;
        }
    };

    /**
     * Creates the likelihood, registers the models and variable it listens to, builds the
     * index-case prior from the outbreak's weight map and performs an initial event sort.
     *
     * @param string                      model name passed to the superclass
     * @param categoryOutbreak            outbreak providing cases and weights
     * @param caseToCaseTreeLikelihood    tree likelihood providing timings and infectors
     * @param spatialKernel               spatial kernel, or null for no geography
     * @param parameter                   transmission rate parameter
     * @param parametricDistributionModel prior on the first infection time, or null
     */
    public CaseToCaseTransmissionLikelihood(String string, CategoryOutbreak categoryOutbreak, CaseToCaseTreeLikelihood caseToCaseTreeLikelihood, SpatialKernel spatialKernel, Parameter parameter, ParametricDistributionModel parametricDistributionModel) {
        super(string);
        outbreak = categoryOutbreak;
        treeLikelihood = caseToCaseTreeLikelihood;
        this.spatialKernel = spatialKernel;
        hasGeography = spatialKernel != null;
        if (hasGeography) {
            addModel(spatialKernel);
        }
        transmissionRate = parameter;
        addModel(caseToCaseTreeLikelihood);
        addVariable(parameter);
        likelihoodKnown = false;
        hasLatentPeriods = caseToCaseTreeLikelihood.hasLatentPeriods();
        initialInfectionTimePrior = parametricDistributionModel;

        // Total weight over every ever-infected case present in the weight map.
        HashMap<AbstractCase, Double> weights = categoryOutbreak.getWeightMap();
        double totalWeight = 0.0;
        for (AbstractCase weighted : weights.keySet()) {
            if (weighted.wasEverInfected) {
                totalWeight += weights.get(weighted).doubleValue();
            }
        }

        // Each ever-infected case's prior probability of being the index case.
        indexCasePrior = new HashMap<AbstractCase, Double>();
        for (AbstractCase candidate : categoryOutbreak.getCases()) {
            if (candidate.wasEverInfected) {
                indexCasePrior.put(candidate, weights.get(candidate) / totalWeight);
            }
        }

        sortEvents();
    }

    /**
     * Invalidates the cached components affected by a change in a sub-model. The total
     * likelihood is always marked unknown.
     *
     * @param model  the model that changed
     * @param object the object accompanying the change notification
     * @param n      index accompanying the change notification (unused)
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        // Classify the source; only the first matching category applies.
        final boolean fromTree = model instanceof CaseToCaseTreeLikelihood;
        final boolean fromKernel = !fromTree && model instanceof SpatialKernel;
        final boolean fromOutbreak = !fromTree && !fromKernel && model instanceof AbstractOutbreak;

        if (fromTree) {
            treeProbKnown = false;
        }

        // A tree change reported with a DemographicModel leaves the event ordering intact.
        final boolean eventsStale = fromOutbreak || (fromTree && !(object instanceof DemographicModel));
        if (eventsStale) {
            periodsProbKnown = false;
            sortedTreeEvents = null;
            indexCase = null;
        }
        if (eventsStale || fromKernel) {
            transProbKnown = false;
        }

        likelihoodKnown = false;
    }

    /**
     * Marks the total likelihood unknown on any variable change, and the transmission
     * component unknown when the changed variable is the transmission rate.
     *
     * @param variable   the variable that changed
     * @param n          index of the changed dimension (unused)
     * @param changeType kind of change (unused)
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        final boolean rateChanged = variable == transmissionRate;
        transProbKnown = transProbKnown && !rateChanged;
        likelihoodKnown = false;
    }

    /**
     * Snapshots the cached log probabilities, their validity flags, the sorted event list
     * and the index case so that {@link #restoreState()} can reinstate them.
     */
    @Override
    protected void storeState() {
        storedLogLikelihood = logLikelihood;
        storedLikelihoodKnown = likelihoodKnown;
        storedPeriodsLogProb = periodsLogProb;
        storedPeriodsProbKnown = periodsProbKnown;
        storedTransLogProb = transLogProb;
        storedTransProbKnown = transProbKnown;
        storedTreeLogProb = treeLogProb;
        storedTreeProbKnown = treeProbKnown;
        // The event list is nulled whenever it is invalidated and only rebuilt lazily, so it
        // may legitimately be absent here; copying it unconditionally would throw an NPE.
        storedSortedTreeEvents = sortedTreeEvents == null ? null : new ArrayList<TreeEvent>(sortedTreeEvents);
        storedIndexCase = indexCase;
    }

    /**
     * Reinstates everything captured by the last {@link #storeState()} call.
     */
    @Override
    protected void restoreState() {
        // Total likelihood.
        likelihoodKnown = storedLikelihoodKnown;
        logLikelihood = storedLogLikelihood;

        // Transmission component.
        transProbKnown = storedTransProbKnown;
        transLogProb = storedTransLogProb;

        // Tree component.
        treeProbKnown = storedTreeProbKnown;
        treeLogProb = storedTreeLogProb;

        // Infectious-period component.
        periodsProbKnown = storedPeriodsProbKnown;
        periodsLogProb = storedPeriodsLogProb;

        // Event ordering and index case.
        indexCase = storedIndexCase;
        sortedTreeEvents = storedSortedTreeEvents;
    }

    /**
     * Nothing to do when a proposed state is accepted.
     */
    @Override
    protected void acceptState() {
    }

    /**
     * @return the spatial kernel given at construction, or null if there was none
     */
    public SpatialKernel getSpatialKernel() {
        return spatialKernel;
    }

    /**
     * @return this object, which is its own model
     */
    @Override
    public Model getModel() {
        return this;
    }

    /**
     * @return the tree likelihood given at construction
     */
    public CaseToCaseTreeLikelihood getTreeLikelihood() {
        return treeLikelihood;
    }

    /**
     * Returns the total log likelihood: the sum of the tree, infectious-period and
     * transmission log probabilities. Only components whose "known" flag is false are
     * recomputed; the total is cached until something marks it unknown again.
     *
     * @return the log likelihood; negative infinity when the event sequence raises a
     *         BadPartitionException or when any component evaluates to positive infinity
     */
    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            if (!this.treeProbKnown) {
                // presumably refreshes the timings read below — confirm in CaseToCaseTreeLikelihood
                this.treeLikelihood.prepareTimings();
            }
            if (!this.transProbKnown) {
                try {
                    this.transLogProb = 0.0;
                    if (this.sortedTreeEvents == null) {
                        this.sortEvents();
                    }
                    // d: the transmission rate.
                    double d = this.transmissionRate.getParameterValue(0);
                    // object: cases that have become infectious so far, in the order they did so.
                    ArrayList<AbstractCase> object = new ArrayList<AbstractCase>();
                    // bl: true until the first infection event has been processed.
                    boolean bl = true;
                    for (TreeEvent treeEvent : this.sortedTreeEvents) {
                        double d2 = treeEvent.getTime();
                        AbstractCase abstractCase = treeEvent.getCase();
                        if (treeEvent.getType() == EventType.INFECTION) {
                            if (bl) {
                                // First infection: contributes the index-case prior and, if present,
                                // the prior on the initial infection time.
                                if (this.indexCasePrior != null) {
                                    this.transLogProb += Math.log(this.indexCasePrior.get(abstractCase));
                                }
                                if (this.initialInfectionTimePrior != null) {
                                    this.transLogProb += this.initialInfectionTimePrior.logPdf(d2);
                                }
                                // Without latent periods a case is infectious from the moment of infection.
                                if (!this.hasLatentPeriods) {
                                    object.add(abstractCase);
                                }
                                bl = false;
                                continue;
                            }
                            AbstractCase abstractCase2 = treeEvent.getInfector();
                            if (abstractCase.wasEverInfected()) {
                                // Consistency checks on the ordering of this infection relative to the
                                // infectee's and infector's infectious windows.
                                if (object.contains(abstractCase)) {
                                    throw new BadPartitionException(abstractCase.caseID + " infected after it was infectious");
                                }
                                if (treeEvent.getTime() > abstractCase.endOfInfectiousTime) {
                                    throw new BadPartitionException(abstractCase.caseID + " ceased to be infected before it was infected");
                                }
                                if (abstractCase2.endOfInfectiousTime < treeEvent.getTime()) {
                                    throw new BadPartitionException(abstractCase.caseID + " infected by " + abstractCase2.caseID + " after the latter ceased to be infectious");
                                }
                                if (this.treeLikelihood.getInfectiousTime(abstractCase2) > treeEvent.getTime()) {
                                    throw new BadPartitionException(abstractCase.caseID + " infected by " + abstractCase2.caseID + " before the latter became infectious");
                                }
                                if (!object.contains(abstractCase2)) {
                                    throw new RuntimeException("Infector not previously infected");
                                }
                            }
                            // For every case already infectious, subtract rate * exposure time, where the
                            // exposure runs from that case's infectious time to the earlier of its end of
                            // infectiousness and this event. This also runs for never-infected cases.
                            for (AbstractCase abstractCase3 : object) {
                                double d3 = abstractCase3.endOfInfectiousTime < treeEvent.getTime() ? abstractCase3.endOfInfectiousTime - this.treeLikelihood.getInfectiousTime(abstractCase3) : treeEvent.getTime() - this.treeLikelihood.getInfectiousTime(abstractCase3);
                                if (d3 < 0.0) {
                                    throw new RuntimeException("negative time");
                                }
                                double d4 = d;
                                if (this.hasGeography) {
                                    d4 *= this.outbreak.getKernelValue(abstractCase, abstractCase3, this.spatialKernel);
                                }
                                this.transLogProb += -d4 * d3;
                            }
                            // An actual infection adds the log of the (kernel-scaled) rate for its infector.
                            if (abstractCase.wasEverInfected()) {
                                double d5 = d;
                                if (this.hasGeography) {
                                    d5 *= this.outbreak.getKernelValue(abstractCase, abstractCase2, this.spatialKernel);
                                }
                                this.transLogProb += Math.log(d5);
                            }
                            if (this.hasLatentPeriods) continue;
                            object.add(abstractCase);
                            continue;
                        }
                        // END events and INFECTIOUSNESS events with infinite time are skipped; cases are
                        // never removed from the infectious list (the exposure above is capped instead).
                        if (treeEvent.getType() != EventType.INFECTIOUSNESS || !(treeEvent.getTime() < Double.POSITIVE_INFINITY)) continue;
                        if (treeEvent.getTime() > treeEvent.getCase().endOfInfectiousTime) {
                            throw new BadPartitionException(treeEvent.getCase().caseID + " noninfectious beforeinfectious");
                        }
                        if (bl) {
                            throw new RuntimeException("First event is not an infection");
                        }
                        object.add(abstractCase);
                    }
                    this.transProbKnown = true;
                }
                catch (BadPartitionException badPartitionException) {
                    // An inconsistent event sequence gives zero likelihood; cache that result and
                    // skip the remaining components.
                    this.transLogProb = Double.NEGATIVE_INFINITY;
                    this.transProbKnown = true;
                    this.logLikelihood = Double.NEGATIVE_INFINITY;
                    this.likelihoodKnown = true;
                    return this.logLikelihood;
                }
            }
            if (!this.periodsProbKnown) {
                Serializable serializable;
                this.periodsLogProb = 0.0;
                // Group the infectious periods of ever-infected cases by infectious category.
                HashMap hashMap = new HashMap();
                for (AbstractCase abstractCase : this.outbreak.getCases()) {
                    if (!abstractCase.wasEverInfected()) continue;
                    String string = this.outbreak.getInfectiousCategory(abstractCase);
                    if (!hashMap.keySet().contains(string)) {
                        hashMap.put(string, new ArrayList());
                    }
                    serializable = (ArrayList)hashMap.get(string);
                    ((ArrayList)serializable).add(this.treeLikelihood.getInfectiousPeriod(abstractCase));
                }
                // Evaluate each category's period prior on the periods collected for it.
                // NOTE(review): a category with no ever-infected case has no map entry, so
                // hashMap.get(string) would be null here — confirm this cannot happen.
                for (String string : this.outbreak.getInfectiousCategories()) {
                    Double[] doubleArray = ((ArrayList)hashMap.get(string)).toArray(new Double[((ArrayList)hashMap.get(string)).size()]);
                    serializable = this.outbreak.getInfectiousCategoryPrior(string);
                    double[] dArray = new double[doubleArray.length];
                    for (int i = 0; i < doubleArray.length; ++i) {
                        dArray[i] = doubleArray[i];
                    }
                    this.periodsLogProb += ((AbstractPeriodPriorDistribution)serializable).getLogLikelihood(dArray);
                }
                this.periodsProbKnown = true;
            }
            if (!this.treeProbKnown) {
                this.treeLogProb = this.treeLikelihood.getLogLikelihood();
                this.treeProbKnown = true;
            }
            // A component of +infinity is reported and treated as zero likelihood. These returns do
            // not update logLikelihood or likelihoodKnown, so the check is repeated on the next call.
            if (this.transLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("TransLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            if (this.periodsLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("PeriodsLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            if (this.treeLogProb == Double.POSITIVE_INFINITY) {
                System.out.println("TreeLogProb +INF");
                return Double.NEGATIVE_INFINITY;
            }
            this.logLikelihood = this.treeLogProb + this.periodsLogProb + this.transLogProb;
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    /**
     * Invalidates every cached component, discards the sorted events and index case, and
     * marks the tree likelihood dirty as well.
     */
    @Override
    public void makeDirty() {
        // Drop all cached probabilities.
        treeProbKnown = false;
        periodsProbKnown = false;
        transProbKnown = false;
        likelihoodKnown = false;

        // Force the event ordering to be rebuilt on next use.
        sortedTreeEvents = null;
        treeLikelihood.makeDirty();
        indexCase = null;
    }

    /**
     * Rebuilds the time-ordered event list. Every case gets an infection event; cases that
     * were ever infected also get an end event and, when latent periods are modelled, an
     * infectiousness event. The case of the earliest event becomes the index case.
     */
    private void sortEvents() {
        final ArrayList<TreeEvent> events = new ArrayList<TreeEvent>();
        for (AbstractCase aCase : outbreak.getCases()) {
            final double infectionTime = treeLikelihood.getInfectionTime(aCase);
            events.add(new TreeEvent(infectionTime, aCase, treeLikelihood.getInfector(outbreak.getCaseIndex(aCase))));
            if (aCase.wasEverInfected()) {
                final double endTime = aCase.endOfInfectiousTime;
                events.add(new TreeEvent(EventType.END, endTime, aCase));
                if (hasLatentPeriods) {
                    final double infectiousTime = treeLikelihood.getInfectiousTime(aCase);
                    events.add(new TreeEvent(EventType.INFECTIOUSNESS, infectiousTime, aCase));
                }
            }
        }
        Collections.sort(events, new EventComparator());
        indexCase = events.get(0).getCase();
        sortedTreeEvents = events;
    }

    /**
     * Builds the log columns contributed by this likelihood: the transmission and period
     * log probabilities, the columns passed up by the tree likelihood, the columns of each
     * infectious-period prior, the first infection time and the index of the index case.
     *
     * @return the columns, in the order listed above
     */
    @Override
    public LogColumn[] getColumns() {
        ArrayList<LogColumn> columns = new ArrayList<LogColumn>();
        columns.add(new LogColumn.Abstract("trans_LL"){

            @Override
            protected String getFormattedValue() {
                return String.valueOf(CaseToCaseTransmissionLikelihood.this.transLogProb);
            }
        });
        columns.add(new LogColumn.Abstract("period_LL"){

            @Override
            protected String getFormattedValue() {
                return String.valueOf(CaseToCaseTransmissionLikelihood.this.periodsLogProb);
            }
        });
        columns.addAll(Arrays.asList(this.treeLikelihood.passColumns()));
        for (AbstractPeriodPriorDistribution periodPrior : this.outbreak.getInfectiousMap().values()) {
            columns.addAll(Arrays.asList(periodPrior.getColumns()));
        }
        columns.add(new LogColumn.Abstract("FirstInfectionTime"){

            @Override
            protected String getFormattedValue() {
                // The index case is discarded together with the event list; rebuild both if needed.
                if (CaseToCaseTransmissionLikelihood.this.sortedTreeEvents == null) {
                    CaseToCaseTransmissionLikelihood.this.sortEvents();
                }
                return String.valueOf(CaseToCaseTransmissionLikelihood.this.treeLikelihood.getInfectionTime(CaseToCaseTransmissionLikelihood.this.indexCase));
            }
        });
        columns.add(new LogColumn.Abstract("IndexCaseIndex"){

            @Override
            protected String getFormattedValue() {
                // Same guard as "FirstInfectionTime", so this column does not depend on that
                // one having been formatted first to repopulate the index case.
                if (CaseToCaseTransmissionLikelihood.this.sortedTreeEvents == null) {
                    CaseToCaseTransmissionLikelihood.this.sortEvents();
                }
                return String.valueOf(CaseToCaseTransmissionLikelihood.this.treeLikelihood.getOutbreak().getCaseIndex(CaseToCaseTransmissionLikelihood.this.indexCase));
            }
        });
        return columns.toArray(new LogColumn[columns.size()]);
    }

    private class TreeEvent {
        private EventType type;
        private double time;
        private AbstractCase aCase;
        private AbstractCase infectorCase;

        private TreeEvent(EventType eventType, double d, AbstractCase abstractCase) {
            this.type = eventType;
            this.time = d;
            this.aCase = abstractCase;
            this.infectorCase = null;
        }

        private TreeEvent(double d, AbstractCase abstractCase, AbstractCase abstractCase2) {
            this.type = EventType.INFECTION;
            this.time = d;
            this.aCase = abstractCase;
            this.infectorCase = abstractCase2;
        }

        public double getTime() {
            return this.time;
        }

        public EventType getType() {
            return this.type;
        }

        public AbstractCase getCase() {
            return this.aCase;
        }

        public AbstractCase getInfector() {
            return this.infectorCase;
        }
    }

    private static enum EventType {
        INFECTION,
        INFECTIOUSNESS,
        END;

    }

    /**
     * Orders events by ascending time only; type and case play no part in the comparison.
     */
    private class EventComparator
    implements Comparator<TreeEvent> {
        private EventComparator() {
        }

        /**
         * @param treeEvent  first event
         * @param treeEvent2 second event
         * @return the result of {@code Double.compare} on the two event times
         */
        @Override
        public int compare(TreeEvent treeEvent, TreeEvent treeEvent2) {
            final double firstTime = treeEvent.getTime();
            final double secondTime = treeEvent2.getTime();
            return Double.compare(firstTime, secondTime);
        }
    }
}

