/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class SPegasos
extends AbstractClassifier
implements TechnicalInformationHandler,
UpdateableClassifier,
OptionHandler {
    private static final long serialVersionUID = -3732968666673530290L;
    protected ReplaceMissingValues m_replaceMissing;
    protected NominalToBinary m_nominalToBinary;
    protected Normalize m_normalize;
    protected double m_lambda = 1.0E-4;
    protected double[] m_weights;
    protected double m_t;
    protected int m_epochs = 500;
    protected boolean m_dontNormalize = false;
    protected boolean m_dontReplaceMissing = false;
    protected Instances m_data;
    protected static final int HINGE = 0;
    protected static final int LOGLOSS = 1;
    protected int m_loss = 0;
    public static final Tag[] TAGS_SELECTION = new Tag[]{new Tag(0, "Hinge loss (SVM)"), new Tag(1, "Log loss (logistic regression)")};

    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.BINARY_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    public String lambdaTipText() {
        return "The regularization constant. (default = 0.0001)";
    }

    public void setLambda(double lambda) {
        this.m_lambda = lambda;
    }

    public double getLambda() {
        return this.m_lambda;
    }

    public String epochsTipText() {
        return "The number of epochs to perform (batch learning). The total number of iterations is epochs * num instances.";
    }

    public void setEpochs(int e) {
        this.m_epochs = e;
    }

    public int getEpochs() {
        return this.m_epochs;
    }

    public void setDontNormalize(boolean m) {
        this.m_dontNormalize = m;
    }

    public boolean getDontNormalize() {
        return this.m_dontNormalize;
    }

    public String dontNormalizeTipText() {
        return "Turn normalization off";
    }

    public void setDontReplaceMissing(boolean m) {
        this.m_dontReplaceMissing = m;
    }

    public boolean getDontReplaceMissing() {
        return this.m_dontReplaceMissing;
    }

    public String dontReplaceMissingTipText() {
        return "Turn off global replacement of missing values";
    }

    public void setLossFunction(SelectedTag function) {
        if (function.getTags() == TAGS_SELECTION) {
            this.m_loss = function.getSelectedTag().getID();
        }
    }

    public SelectedTag getLossFunction() {
        return new SelectedTag(this.m_loss, TAGS_SELECTION);
    }

    public String lossFunctionTipText() {
        return "The loss function to use. Hinge loss (SVM) or log loss (logistic regression).";
    }

    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.add(new Option("\tSet the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression).\n\t(default = 0)", "F", 1, "-F"));
        newVector.add(new Option("\tThe lambda regularization constant (default = 0.0001)", "L", 1, "-L <double>"));
        newVector.add(new Option("\tThe number of epochs to perform (batch learning only, default = 500)", "E", 1, "-E <integer>"));
        newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
        newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));
        return newVector.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        String epochsString;
        this.reset();
        String lossString = Utils.getOption('F', options);
        if (lossString.length() != 0) {
            this.setLossFunction(new SelectedTag(Integer.parseInt(lossString), TAGS_SELECTION));
        } else {
            this.setLossFunction(new SelectedTag(0, TAGS_SELECTION));
        }
        String lambdaString = Utils.getOption('L', options);
        if (lambdaString.length() > 0) {
            this.setLambda(Double.parseDouble(lambdaString));
        }
        if ((epochsString = Utils.getOption("E", options)).length() > 0) {
            this.setEpochs(Integer.parseInt(epochsString));
        }
        this.setDontNormalize(Utils.getFlag("N", options));
        this.setDontReplaceMissing(Utils.getFlag('M', options));
    }

    @Override
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<String>();
        options.add("-F");
        options.add("" + this.getLossFunction().getSelectedTag().getID());
        options.add("-L");
        options.add("" + this.getLambda());
        options.add("-E");
        options.add("" + this.getEpochs());
        if (this.getDontNormalize()) {
            options.add("-N");
        }
        if (this.getDontReplaceMissing()) {
            options.add("-M");
        }
        return options.toArray(new String[1]);
    }

    public String globalInfo() {
        return "Implements the stochastic variant of the Pegasos (Primal Estimated sub-GrAdient SOlver for SVM) method of Shalev-Shwartz et al. (2007). This implementation globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data. Can either minimize the hinge loss (SVM) or log loss (logistic regression). For more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "S. Shalev-Shwartz and Y. Singer and N. Srebro");
        result.setValue(TechnicalInformation.Field.YEAR, "2007");
        result.setValue(TechnicalInformation.Field.TITLE, "Pegasos: Primal Estimated sub-GrAdient SOlver for SVM");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "24th International Conference on MachineLearning");
        result.setValue(TechnicalInformation.Field.PAGES, "807-814");
        return result;
    }

    public void reset() {
        this.m_t = 1.0;
        this.m_weights = null;
        this.m_normalize = null;
        this.m_replaceMissing = null;
        this.m_nominalToBinary = null;
    }

    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.reset();
        this.getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        if (data.numInstances() > 0 && !this.m_dontReplaceMissing) {
            this.m_replaceMissing = new ReplaceMissingValues();
            this.m_replaceMissing.setInputFormat(data);
            data = Filter.useFilter(data, this.m_replaceMissing);
        }
        boolean onlyNumeric = true;
        for (int i = 0; i < data.numAttributes(); ++i) {
            if (i == data.classIndex() || data.attribute(i).isNumeric()) continue;
            onlyNumeric = false;
            break;
        }
        if (!onlyNumeric) {
            this.m_nominalToBinary = new NominalToBinary();
            this.m_nominalToBinary.setInputFormat(data);
            data = Filter.useFilter(data, this.m_nominalToBinary);
        }
        if (!this.m_dontNormalize && data.numInstances() > 0) {
            this.m_normalize = new Normalize();
            this.m_normalize.setInputFormat(data);
            data = Filter.useFilter(data, this.m_normalize);
        }
        this.m_weights = new double[data.numAttributes() + 1];
        this.m_data = new Instances(data, 0);
        if (data.numInstances() > 0) {
            this.train(data);
        }
    }

    protected double dloss(double z) {
        if (this.m_loss == 0) {
            return z < 1.0 ? 1.0 : 0.0;
        }
        if (z < 0.0) {
            return 1.0 / (Math.exp(z) + 1.0);
        }
        double t = Math.exp(-z);
        return t / (t + 1.0);
    }

    private void train(Instances data) {
        for (int e = 0; e < this.m_epochs; ++e) {
            for (int i = 0; i < data.numInstances(); ++i) {
                Instance instance = data.instance(i);
                double learningRate = 1.0 / (this.m_lambda * this.m_t);
                double scale = 1.0 - 1.0 / this.m_t;
                double y = instance.classValue() == 0.0 ? -1.0 : 1.0;
                double wx = SPegasos.dotProd(instance, this.m_weights, instance.classIndex());
                double z = y * (wx + this.m_weights[this.m_weights.length - 1]);
                if (this.m_loss == 1 || z < 1.0) {
                    double delta = learningRate * this.dloss(z);
                    int n1 = instance.numValues();
                    int n2 = data.numAttributes();
                    int p1 = 0;
                    for (int p2 = 0; p2 < n2; ++p2) {
                        int indS = 0;
                        indS = p1 < n1 ? instance.index(p1) : indS;
                        int indP = p2;
                        if (indP != data.classIndex()) {
                            int n = indP;
                            this.m_weights[n] = this.m_weights[n] * scale;
                        }
                        if (indS != indP) continue;
                        if (indS != data.classIndex() && !instance.isMissingSparse(p1)) {
                            double m = delta * (instance.valueSparse(p1) * y);
                            int n = indS;
                            this.m_weights[n] = this.m_weights[n] + m;
                        }
                        ++p1;
                    }
                    int n = this.m_weights.length - 1;
                    this.m_weights[n] = this.m_weights[n] + delta * y;
                    double norm = 0.0;
                    for (int k = 0; k < this.m_weights.length; ++k) {
                        if (k == data.classIndex()) continue;
                        norm += this.m_weights[k] * this.m_weights[k];
                    }
                    norm = Math.sqrt(norm);
                    double scale2 = Math.min(1.0, 1.0 / (Math.sqrt(this.m_lambda) * norm));
                    if (scale2 < 1.0) {
                        int j = 0;
                        while (j < this.m_weights.length) {
                            int n3 = j++;
                            this.m_weights[n3] = this.m_weights[n3] * scale2;
                        }
                    }
                }
                this.m_t += 1.0;
            }
        }
    }

    protected static double dotProd(Instance inst1, double[] weights, int classIndex) {
        double result = 0.0;
        int n1 = inst1.numValues();
        int n2 = weights.length - 1;
        int p1 = 0;
        int p2 = 0;
        while (p1 < n1 && p2 < n2) {
            int ind2;
            int ind1 = inst1.index(p1);
            if (ind1 == (ind2 = p2++)) {
                if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
                    result += inst1.valueSparse(p1) * weights[p2];
                }
                ++p1;
                ++p2;
                continue;
            }
            if (ind1 > ind2) continue;
            ++p1;
        }
        return result;
    }

    @Override
    public void updateClassifier(Instance instance) throws Exception {
        if (!instance.classIsMissing()) {
            double learningRate = 1.0 / (this.m_lambda * this.m_t);
            double scale = 1.0 - 1.0 / this.m_t;
            double y = instance.classValue() == 0.0 ? -1.0 : 1.0;
            double wx = SPegasos.dotProd(instance, this.m_weights, instance.classIndex());
            double z = y * (wx + this.m_weights[this.m_weights.length - 1]);
            int j = 0;
            while (j < this.m_weights.length) {
                int n = j++;
                this.m_weights[n] = this.m_weights[n] * scale;
            }
            if (this.m_loss == 1 || z < 1.0) {
                double delta = learningRate * this.dloss(z);
                int n1 = instance.numValues();
                int n2 = instance.numAttributes();
                int p1 = 0;
                for (int p2 = 0; p2 < n2; ++p2) {
                    int indS = 0;
                    indS = p1 < n1 ? instance.index(p1) : indS;
                    int indP = p2;
                    if (indP != instance.classIndex()) {
                        int n = indP;
                        this.m_weights[n] = this.m_weights[n] * scale;
                    }
                    if (indS != indP) continue;
                    if (indS != instance.classIndex() && !instance.isMissingSparse(p1)) {
                        double m = delta * (instance.valueSparse(p1) * y);
                        int n = indS;
                        this.m_weights[n] = this.m_weights[n] + m;
                    }
                    ++p1;
                }
                int n = this.m_weights.length - 1;
                this.m_weights[n] = this.m_weights[n] + delta * y;
                double norm = 0.0;
                for (int k = 0; k < this.m_weights.length; ++k) {
                    if (k == instance.classIndex()) continue;
                    norm += this.m_weights[k] * this.m_weights[k];
                }
                norm = Math.sqrt(norm);
                double scale2 = Math.min(1.0, 1.0 / (Math.sqrt(this.m_lambda) * norm));
                if (scale2 < 1.0) {
                    int j2 = 0;
                    while (j2 < this.m_weights.length) {
                        int n3 = j2++;
                        this.m_weights[n3] = this.m_weights[n3] * scale2;
                    }
                }
            }
            this.m_t += 1.0;
        }
    }

    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        double wx;
        double z;
        double[] result = new double[2];
        if (this.m_replaceMissing != null) {
            this.m_replaceMissing.input(inst);
            inst = this.m_replaceMissing.output();
        }
        if (this.m_nominalToBinary != null) {
            this.m_nominalToBinary.input(inst);
            inst = this.m_nominalToBinary.output();
        }
        if (this.m_normalize != null) {
            this.m_normalize.input(inst);
            inst = this.m_normalize.output();
        }
        if ((z = (wx = SPegasos.dotProd(inst, this.m_weights, inst.classIndex())) + this.m_weights[this.m_weights.length - 1]) <= 0.0) {
            if (this.m_loss == 1) {
                result[0] = 1.0 / (1.0 + Math.exp(z));
                result[1] = 1.0 - result[0];
            } else {
                result[0] = 1.0;
            }
        } else if (this.m_loss == 1) {
            result[1] = 1.0 / (1.0 + Math.exp(-z));
            result[0] = 1.0 - result[1];
        } else {
            result[1] = 1.0;
        }
        return result;
    }

    public String toString() {
        if (this.m_weights == null) {
            return "SPegasos: No model built yet.\n";
        }
        StringBuffer buff = new StringBuffer();
        buff.append("Loss function: ");
        if (this.m_loss == 0) {
            buff.append("Hinge loss (SVM)\n\n");
        } else {
            buff.append("Log loss (logistic regression)\n\n");
        }
        int printed = 0;
        for (int i = 0; i < this.m_weights.length - 1; ++i) {
            if (i == this.m_data.classIndex()) continue;
            if (printed > 0) {
                buff.append(" + ");
            } else {
                buff.append("   ");
            }
            buff.append(Utils.doubleToString(this.m_weights[i], 12, 4) + " " + (this.m_normalize != null ? "(normalized) " : "") + this.m_data.attribute(i).name() + "\n");
            ++printed;
        }
        if (this.m_weights[this.m_weights.length - 1] > 0.0) {
            buff.append(" + " + Utils.doubleToString(this.m_weights[this.m_weights.length - 1], 12, 4));
        } else {
            buff.append(" - " + Utils.doubleToString(-this.m_weights[this.m_weights.length - 1], 12, 4));
        }
        return buff.toString();
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 6105 $");
    }

    public static void main(String[] args) {
        SPegasos.runClassifier(new SPegasos(), args);
    }
}

