/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.factorAnalysis;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.factorAnalysis.FactorAnalysisOperatorAdaptor;
import dr.inference.operators.factorAnalysis.FactorAnalysisStatisticsProvider;
import dr.inferencexml.operators.factorAnalysis.LoadingsOperatorParserUtilities;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Gradient provider for the loadings matrix of a factor analysis model.
 *
 * <p>Every call to {@link #getGradientLogDensity()} first asks the adaptor to
 * draw factors, then refreshes the per-trait statistics obtained from the
 * {@link FactorAnalysisStatisticsProvider}, and finally combines them with the
 * current loadings values into a flat gradient array.
 */
public class SampledLoadingsGradient
implements GradientWrtParameterProvider {
    private static final String SAMPLED_LOADINGS_GRADIENT = "sampledLoadingsGradient";

    /** XML parser that builds a {@link SampledLoadingsGradient} from a parsed statistics provider. */
    public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser(){

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            FactorAnalysisStatisticsProvider provider = LoadingsOperatorParserUtilities.parseAdaptorAndStatistics(xo);
            return new SampledLoadingsGradient(provider);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return LoadingsOperatorParserUtilities.statisticsProviderRules;
        }

        @Override
        public String getParserDescription() {
            // Previously returned null; give callers a usable description instead.
            return "Gradient of the log-density with respect to the loadings of a factor analysis model, "
                    + "evaluated after drawing factors from the adaptor.";
        }

        @Override
        public Class getReturnType() {
            return SampledLoadingsGradient.class;
        }

        @Override
        public String getParserName() {
            return SAMPLED_LOADINGS_GRADIENT;
        }
    };

    private final MatrixParameterInterface loadings;
    private final FactorAnalysisStatisticsProvider statisticsProvider;
    private final FactorAnalysisOperatorAdaptor adaptor;
    /** Per-trait buffer, indexed {@code [trait][factor]}, filled by the statistics provider. */
    private final double[][] scaledFactorTraitProducts;
    /** Per-trait buffer, indexed {@code [trait][factor][factor]}, filled by the statistics provider. */
    private final double[][][] precisions;
    private final int nFactors;
    private final int nTraits;
    private final Likelihood likelihood;

    /**
     * Creates a gradient provider backed by the given statistics provider.
     *
     * <p>The adaptor, loadings, and dimensions are read from the provider's adaptor,
     * the work buffers are sized accordingly, and the adaptor's likelihoods are
     * wrapped in a single {@link CompoundLikelihood}.
     *
     * @param factorAnalysisStatisticsProvider source of the adaptor and of the per-trait statistics
     */
    SampledLoadingsGradient(FactorAnalysisStatisticsProvider factorAnalysisStatisticsProvider) {
        this.statisticsProvider = factorAnalysisStatisticsProvider;
        this.adaptor = factorAnalysisStatisticsProvider.getAdaptor();
        this.loadings = this.adaptor.getLoadings();
        this.nFactors = this.adaptor.getNumberOfFactors();
        this.nTraits = this.adaptor.getNumberOfTraits();
        this.scaledFactorTraitProducts = new double[this.nTraits][this.nFactors];
        this.precisions = new double[this.nTraits][this.nFactors][this.nFactors];
        this.likelihood = new CompoundLikelihood(this.adaptor.getLikelihoods());
    }

    /**
     * @return the compound likelihood built from the adaptor's likelihoods at construction time
     */
    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    /**
     * @return the loadings matrix parameter this gradient is taken with respect to
     */
    @Override
    public Parameter getParameter() {
        return this.loadings;
    }

    /**
     * @return the dimension reported by the loadings parameter
     */
    @Override
    public int getDimension() {
        return this.loadings.getDimension();
    }

    /**
     * Draws factors through the adaptor and then refills, for every trait, the
     * {@link #precisions} and {@link #scaledFactorTraitProducts} buffers from the
     * statistics provider.
     */
    private void updateStatistics() {
        this.adaptor.drawFactors();
        for (int trait = 0; trait < this.nTraits; ++trait) {
            this.statisticsProvider.getScaledFactorInnerProduct(trait, this.nFactors, this.precisions[trait]);
            this.statisticsProvider.getScaledFactorTraitProduct(trait, this.nFactors, this.scaledFactorTraitProducts[trait]);
        }
    }

    /**
     * Computes the gradient with respect to the loadings.
     *
     * <p>Statistics are refreshed on every call. For trait {@code i} and factor
     * {@code j} the entry is
     * {@code scaledFactorTraitProducts[i][j] - sum_k precisions[i][j][k] * loadings(i, k)}
     * and is stored at flat index {@code j * nTraits + i}.
     *
     * @return a newly allocated array of length {@link #getDimension()}
     */
    @Override
    public double[] getGradientLogDensity() {
        this.updateStatistics();
        double[] gradient = new double[this.getDimension()];
        for (int trait = 0; trait < this.nTraits; ++trait) {
            for (int factor = 0; factor < this.nFactors; ++factor) {
                // Start from the trait product and subtract terms in increasing k,
                // keeping the same floating-point evaluation order as before.
                double value = this.scaledFactorTraitProducts[trait][factor];
                for (int k = 0; k < this.nFactors; ++k) {
                    value = value - this.precisions[trait][factor][k] * this.loadings.getParameterValue(trait, k);
                }
                gradient[factor * this.nTraits + trait] = value;
            }
        }
        return gradient;
    }
}

