/*
 * Decompiled with CFR 0.152.
 */
package eu.amidst.core.distribution;

import eu.amidst.core.distribution.Distribution;
import eu.amidst.core.distribution.Normal;
import eu.amidst.core.distribution.UnivariateDistribution;
import eu.amidst.core.exponentialfamily.EF_UnivariateDistribution;
import eu.amidst.core.variables.Variable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.DoubleStream;

/**
 * Univariate Gaussian mixture: a weighted sum of {@link Normal} components over a single variable.
 *
 * <p>Component {@code k} has mixing coefficient {@code coefficients[k]} and density
 * {@code terms.get(k)}. The flat parameter vector used by {@link #getParameters()} and
 * {@link #setParameters(double[])} stores one {@code (coefficient, mean, variance)} triple per
 * component, in component order.
 *
 * <p>Concrete subclasses must provide {@link #toEFUnivariateDistribution()}.
 */
public abstract class GaussianMixture extends UnivariateDistribution {
    private static final long serialVersionUID = 3362372347079403247L;

    /** Mixture components; index-aligned with {@link #coefficients}. */
    private List<Normal> terms;

    /** Mixing coefficient of each component; index-aligned with {@link #terms}. */
    private double[] coefficients;

    /**
     * Creates a single-component mixture for {@code var1}: one standard normal
     * (mean 0, variance 1) with weight 1.
     *
     * @param var1 the variable this distribution is defined over
     */
    public GaussianMixture(Variable var1) {
        this.var = var1;
        this.terms = new ArrayList<Normal>();
        Normal aux = new Normal(var1);
        aux.setMean(0.0);
        aux.setVariance(1.0);
        this.terms.add(aux);
        this.coefficients = new double[1];
        this.coefficients[0] = 1.0;
    }

    /**
     * Creates a mixture from existing components and coefficients. The variable is taken from the
     * first component. Both arguments are stored by reference, not copied.
     *
     * @param list   the mixture components; must be non-empty
     * @param coeffs mixing coefficients, one per component (same order as {@code list})
     * @throws IllegalArgumentException if {@code list} is empty or its size differs from
     *                                  {@code coeffs.length}
     */
    public GaussianMixture(List<Normal> list, double[] coeffs) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("A Gaussian mixture needs at least one term");
        }
        if (list.size() != coeffs.length) {
            throw new IllegalArgumentException(
                    "Got " + list.size() + " terms but " + coeffs.length + " coefficients");
        }
        this.var = list.get(0).getVariable();
        this.terms = list;
        this.coefficients = coeffs;
    }

    /**
     * Creates a mixture from a flat parameter vector; see {@link #setParameters(double[])}.
     *
     * <p>NOTE(review): this constructor never assigns {@code var}, so the components are built with
     * whatever {@code var} holds at this point (presumably {@code null}). It also invokes the
     * overridable {@code setParameters} from a constructor. Both are kept for compatibility.
     *
     * @param params {@code (coefficient, mean, variance)} triples, one per component
     * @throws UnsupportedOperationException if {@code params.length} is not a multiple of 3
     */
    public GaussianMixture(double[] params) {
        this.setParameters(params);
    }

    /**
     * Returns the log-density of the mixture at {@code value}, i.e.
     * {@code log(sum_k coefficients[k] * p_k(value))}.
     *
     * <p>The sum is evaluated with the log-sum-exp trick on the components' log-densities. Values
     * far in the tails therefore do not underflow to {@code -Infinity}, which would happen when
     * summing raw densities.
     *
     * @param value the point at which to evaluate the density
     * @return the log-density at {@code value}
     */
    @Override
    public double getLogProbability(double value) {
        int numTerms = this.terms.size();
        double[] logDensities = new double[numTerms];
        double maxLog = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < numTerms; ++k) {
            logDensities[k] = this.terms.get(k).getLogProbability(value);
            maxLog = Math.max(maxLog, logDensities[k]);
        }
        // -Infinity: every component density is zero (or there are no components).
        // +Infinity: some component reports an infinite log-density at value.
        // In both cases shifting by maxLog would yield NaN, so return it directly.
        if (Double.isInfinite(maxLog)) {
            return maxLog;
        }
        double scaledSum = 0.0;
        for (int k = 0; k < numTerms; ++k) {
            scaledSum += this.coefficients[k] * Math.exp(logDensities[k] - maxLog);
        }
        return maxLog + Math.log(scaledSum);
    }

    /**
     * Draws one value from the mixture. A component is picked with probability equal to its
     * coefficient, and then that component is sampled using the same {@code rand}.
     *
     * <p>Assumes the coefficients sum to 1. Any mass left over by floating-point rounding (or by
     * coefficients summing to less than 1) is assigned to the last component.
     *
     * @param rand source of randomness
     * @return a sampled value
     * @throws IllegalStateException if the mixture has no components
     */
    @Override
    public double sample(Random rand) {
        int last = this.coefficients.length - 1;
        if (last < 0) {
            throw new IllegalStateException("Cannot sample from a Gaussian mixture with no terms");
        }
        double u = rand.nextDouble();
        double cumulative = 0.0;
        for (int k = 0; k < last; ++k) {
            cumulative += this.coefficients[k];
            if (u < cumulative) {
                return this.terms.get(k).sample(rand);
            }
        }
        return this.terms.get(last).sample(rand);
    }

    /**
     * Returns the parameters as a flat array of {@code (coefficient, mean, variance)} triples,
     * one per component, in component order.
     *
     * @return a new array of length {@code 3 * numberOfComponents}
     */
    @Override
    public double[] getParameters() {
        double[] parameters = new double[3 * this.coefficients.length];
        int index = 0;
        for (Normal normal : this.terms) {
            parameters[3 * index] = this.coefficients[index];
            parameters[3 * index + 1] = normal.getMean();
            parameters[3 * index + 2] = normal.getVariance();
            ++index;
        }
        return parameters;
    }

    /**
     * Replaces all components from a flat array of {@code (coefficient, mean, variance)} triples.
     * New {@link Normal} components are created over the current {@code var}.
     *
     * @param params parameter triples, one per component
     * @throws UnsupportedOperationException if {@code params.length} is not a multiple of 3
     */
    public void setParameters(double[] params) {
        if (params.length % 3 != 0) {
            throw new UnsupportedOperationException("The number of parameters for the Gaussian mixture is not valid");
        }
        int numTerms = params.length / 3;
        this.coefficients = new double[numTerms];
        this.terms = new ArrayList<Normal>(numTerms);
        for (int index = 0; index < numTerms; ++index) {
            this.coefficients[index] = params[3 * index];
            Normal aux = new Normal(this.var);
            aux.setMean(params[3 * index + 1]);
            aux.setVariance(params[3 * index + 2]);
            this.terms.add(aux);
        }
    }

    /**
     * Returns the number of parameters: three (coefficient, mean, variance) per component.
     *
     * @return {@code 3 * numberOfComponents}
     */
    @Override
    public int getNumberOfParameters() {
        return 3 * this.coefficients.length;
    }

    /**
     * Returns the variable this distribution is defined over.
     *
     * @return the variable
     */
    @Override
    public Variable getVariable() {
        return this.var;
    }

    /**
     * Returns the human-readable name of this distribution family.
     *
     * @return {@code "Gaussian Mixture"}
     */
    @Override
    public String label() {
        return "Gaussian Mixture";
    }

    /**
     * Replaces the mixture with {@code numTerms} randomly initialized components.
     *
     * <p>For each component, the draws happen in this order: an unnormalized weight uniform in
     * [0, 1), a mean of {@code 5 * N(0, 1)}, and a variance uniform in [0, 1). The weights are then
     * normalized to sum to 1.
     *
     * <p>NOTE(review): {@code nextDouble()} can return exactly 0.0, which would give a
     * zero-variance component or an all-zero weight vector. This is extremely unlikely but not
     * guarded against.
     *
     * @param random   source of randomness
     * @param numTerms number of components to create; must be at least 1
     * @throws IllegalArgumentException if {@code numTerms < 1}
     */
    public void randomInitialization(Random random, int numTerms) {
        if (numTerms < 1) {
            throw new IllegalArgumentException(
                    "A Gaussian mixture needs at least one term, got " + numTerms);
        }
        this.coefficients = new double[numTerms];
        this.terms = new ArrayList<Normal>(numTerms);
        for (int k = 0; k < numTerms; ++k) {
            this.coefficients[k] = random.nextDouble();
            Normal aux = new Normal(this.var);
            aux.setMean(5.0 * random.nextGaussian());
            aux.setVariance(random.nextDouble());
            this.terms.add(aux);
        }
        double total = Arrays.stream(this.coefficients).sum();
        for (int k = 0; k < numTerms; ++k) {
            this.coefficients[k] /= total;
        }
    }

    /**
     * Replaces the mixture with a random number of randomly initialized components (1 to 4
     * inclusive). At least one component is always created so the mixture remains usable.
     *
     * @param random source of randomness
     */
    @Override
    public void randomInitialization(Random random) {
        int numTerms = 1 + random.nextInt(4);
        this.randomInitialization(random, numTerms);
    }

    /**
     * Tests whether {@code dist} is a Gaussian mixture with the same number of components whose
     * coefficients and components all match this one's within {@code threshold}.
     *
     * <p>Components are compared position by position, so mixtures that differ only in component
     * order are reported as different.
     *
     * @param dist      the distribution to compare against
     * @param threshold maximum allowed absolute difference per compared value
     * @return {@code true} if both mixtures match component-wise
     */
    @Override
    public boolean equalDist(Distribution dist, double threshold) {
        if (!(dist instanceof GaussianMixture)) {
            return false;
        }
        GaussianMixture other = (GaussianMixture) dist;
        if (this.coefficients.length != other.coefficients.length
                || this.terms.size() != other.terms.size()) {
            return false;
        }
        for (int k = 0; k < this.coefficients.length; ++k) {
            if (Math.abs(this.coefficients[k] - other.coefficients[k]) > threshold) {
                return false;
            }
            if (!this.terms.get(k).equalDist(other.terms.get(k), threshold)) {
                return false;
            }
        }
        return true;
    }

    @Override
    public abstract <E extends EF_UnivariateDistribution> E toEFUnivariateDistribution();

    /**
     * Renders the mixture as {@code "c0 N0 + c1 N1 + ..."}. Each coefficient is formatted with
     * three decimals and followed by its component's {@code toString()}.
     *
     * @return the textual representation
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        for (int k = 0; k < this.coefficients.length; ++k) {
            if (k > 0) {
                text.append(" + ");
            }
            text.append(String.format("%.3f", this.coefficients[k]))
                .append(" ")
                .append(this.terms.get(k).toString());
        }
        return text.toString();
    }
}

