/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;
import java.io.Serializable;

/**
 * Classifier produced by AdaBoost.M2: a weighted vote over a sequence of weak classifiers.
 *
 * <p>When classifying, each weak classifier adds its weight ({@code alphas[i]}) to the label it
 * ranks best. The accumulated votes are then normalized by the total weight used, so the
 * resulting {@link LabelVector} sums to 1 whenever that total is non-zero.
 */
public class AdaBoostM2
extends Classifier
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Weak classifiers in boosting-round order; parallel to {@link #alphas}. */
    Classifier[] weakClassifiers;
    /** Vote weight of each weak classifier; {@code alphas[i]} belongs to {@code weakClassifiers[i]}. */
    double[] alphas;

    /**
     * Creates a boosted classifier from parallel arrays of weak classifiers and their vote weights.
     *
     * @param instancePipe pipe handed to the {@link Classifier} superclass
     * @param weakClassifiers the weak classifiers; the array is stored by reference, not copied
     * @param alphas the vote weight of each weak classifier; stored by reference, not copied
     * @throws IllegalArgumentException if the two arrays have different lengths
     */
    public AdaBoostM2(Pipe instancePipe, Classifier[] weakClassifiers, double[] alphas) {
        super(instancePipe);
        // The two arrays are indexed in lock-step everywhere below; a mismatch would otherwise
        // surface later as an ArrayIndexOutOfBoundsException or a silently ignored weight.
        if (weakClassifiers.length != alphas.length) {
            throw new IllegalArgumentException("number of weak classifiers (" + weakClassifiers.length
                    + ") does not match number of alphas (" + alphas.length + ")");
        }
        this.weakClassifiers = weakClassifiers;
        this.alphas = alphas;
    }

    /**
     * Returns the number of weak classifiers (boosting rounds) in this ensemble.
     *
     * @return the number of weak classifiers
     */
    public int getNumWeakClassifiers() {
        return this.alphas.length;
    }

    /**
     * Returns a new ensemble that keeps only the first {@code numWeakClassifiersToUse} weak
     * classifiers and their weights. This instance is not modified.
     *
     * @param numWeakClassifiersToUse how many leading weak classifiers to keep
     * @return a new {@code AdaBoostM2} sharing this instance's pipe
     * @throws IllegalArgumentException if the count is not in {@code [1, getNumWeakClassifiers()]}
     */
    public AdaBoostM2 getTrimmedClassifier(int numWeakClassifiersToUse) {
        if (numWeakClassifiersToUse <= 0 || numWeakClassifiersToUse > this.weakClassifiers.length) {
            throw new IllegalArgumentException("number of weak learners to use out of range:" + numWeakClassifiersToUse);
        }
        Classifier[] newWeakClassifiers = new Classifier[numWeakClassifiersToUse];
        System.arraycopy(this.weakClassifiers, 0, newWeakClassifiers, 0, numWeakClassifiersToUse);
        double[] newAlphas = new double[numWeakClassifiersToUse];
        System.arraycopy(this.alphas, 0, newAlphas, 0, numWeakClassifiersToUse);
        return new AdaBoostM2(this.instancePipe, newWeakClassifiers, newAlphas);
    }

    /**
     * Classifies an instance using every weak classifier in the ensemble.
     *
     * @param inst the instance to classify; its data must be a {@link FeatureVector}
     * @return the normalized weighted vote over labels
     * @throws IllegalArgumentException if the ensemble contains no weak classifiers
     */
    @Override
    public Classification classify(Instance inst) {
        return this.classify(inst, this.weakClassifiers.length);
    }

    /**
     * Classifies an instance using only the first {@code numWeakClassifiersToUse} weak classifiers.
     *
     * <p>Each weak classifier votes for its best label with weight {@code alphas[round]}. The votes
     * are divided by the total weight used, so the resulting scores sum to 1. If that total is
     * exactly zero, the raw (unnormalized) votes are returned rather than NaN values.
     *
     * @param inst the instance to classify; its data must be a {@link FeatureVector}
     * @param numWeakClassifiersToUse how many leading weak classifiers take part in the vote
     * @return the weighted vote over labels
     * @throws IllegalArgumentException if the count is not in {@code [1, getNumWeakClassifiers()]}
     */
    public Classification classify(Instance inst, int numWeakClassifiersToUse) {
        if (numWeakClassifiersToUse <= 0 || numWeakClassifiersToUse > this.weakClassifiers.length) {
            throw new IllegalArgumentException("number of weak learners to use out of range:" + numWeakClassifiersToUse);
        }
        FeatureVector fv = (FeatureVector)inst.getData();
        assert (this.instancePipe == null || fv.getAlphabet() == this.instancePipe.getDataAlphabet());
        int numClasses = this.getLabelAlphabet().size();
        double[] scores = new double[numClasses];
        double totalWeight = 0.0;
        for (int round = 0; round < numWeakClassifiersToUse; ++round) {
            int bestIndex = this.weakClassifiers[round].classify(inst).getLabeling().getBestIndex();
            scores[bestIndex] += this.alphas[round];
            // Accumulate the vote weight itself, not the label's running score: the total of all
            // scores equals the total of the alphas used, which makes the normalized vote sum to 1.
            totalWeight += this.alphas[round];
        }
        // Guard against 0/0 producing NaN scores when every vote weight is zero.
        if (totalWeight != 0.0) {
            for (int i = 0; i < scores.length; ++i) {
                scores[i] /= totalWeight;
            }
        }
        return new Classification(inst, this, new LabelVector(this.getLabelAlphabet(), scores));
    }
}

