/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

public class MINND
extends Classifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -4512599203273864994L;
    protected int m_Neighbour = 1;
    protected double[][] m_Mean = null;
    protected double[][] m_Variance = null;
    protected int m_Dimension = 0;
    protected Instances m_Attributes;
    protected double[] m_Class = null;
    protected int m_NumClasses = 0;
    protected double[] m_Weights = null;
    private static double m_ZERO = 1.0E-45;
    protected double m_Rate = -1.0;
    private double[] m_MinArray = null;
    private double[] m_MaxArray = null;
    private double m_STOP = 1.0E-45;
    private double[][] m_Change = null;
    private double[][] m_NoiseM = null;
    private double[][] m_NoiseV = null;
    private double[][] m_ValidM = null;
    private double[][] m_ValidV = null;
    private int m_Select = 1;
    private int m_Choose = 1;
    private double m_Decay = 0.5;

    public String globalInfo() {
        return "Multiple-Instance Nearest Neighbour with Distribution learner.\n\nIt uses gradient descent to find the weight for each dimension of each exeamplar from the starting point of 1.0. In order to avoid overfitting, it uses mean-square function (i.e. the Euclidean distance) to search for the weights.\n It then uses the weights to cleanse the training data. After that it searches for the weights again from the starting points of the weights searched before.\n Finally it uses the most updated weights to cleanse the test exemplar and then finds the nearest neighbour of the test exemplar using partly-weighted Kullback distance. But the variances in the Kullback distance are the ones before cleansing.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.MISC);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Xin Xu");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2001");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "A nearest distribution approach to multiple-instance learning");
        technicalInformation.setValue(TechnicalInformation.Field.SCHOOL, "University of Waikato");
        technicalInformation.setValue(TechnicalInformation.Field.ADDRESS, "Hamilton, NZ");
        technicalInformation.setValue(TechnicalInformation.Field.NOTE, "0657.591B");
        return technicalInformation;
    }

    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return capabilities;
    }

    public Capabilities getMultiInstanceCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.disableAllClasses();
        capabilities.enable(Capabilities.Capability.NO_CLASS);
        return capabilities;
    }

    public void buildClassifier(Instances instances) throws Exception {
        int n;
        Instance instance;
        int n2;
        int n3;
        this.getCapabilities().testWithFail(instances);
        Instances instances2 = new Instances(instances);
        instances2.deleteWithMissingClass();
        int n4 = instances2.numInstances();
        this.m_Dimension = instances2.attribute(1).relation().numAttributes();
        this.m_Attributes = instances2.stringFreeStructure();
        this.m_Change = new double[n4][this.m_Dimension];
        this.m_NumClasses = instances.numClasses();
        this.m_Mean = new double[n4][this.m_Dimension];
        this.m_Variance = new double[n4][this.m_Dimension];
        this.m_Class = new double[n4];
        this.m_Weights = new double[n4];
        this.m_NoiseM = new double[n4][this.m_Dimension];
        this.m_NoiseV = new double[n4][this.m_Dimension];
        this.m_ValidM = new double[n4][this.m_Dimension];
        this.m_ValidV = new double[n4][this.m_Dimension];
        this.m_MinArray = new double[this.m_Dimension];
        this.m_MaxArray = new double[this.m_Dimension];
        for (n3 = 0; n3 < this.m_Dimension; ++n3) {
            this.m_MaxArray[n3] = Double.NaN;
            this.m_MinArray[n3] = Double.NaN;
        }
        for (n3 = 0; n3 < n4; ++n3) {
            this.updateMinMax(instances2.instance(n3));
        }
        Instances instances3 = this.m_Attributes;
        for (n2 = 0; n2 < n4; ++n2) {
            instance = instances2.instance(n2);
            instance = this.scale(instance);
            for (n = 0; n < this.m_Dimension; ++n) {
                this.m_Mean[n2][n] = instance.relationalValue(1).meanOrMode(n);
                this.m_Variance[n2][n] = instance.relationalValue(1).variance(n);
                if (Utils.eq(this.m_Variance[n2][n], 0.0)) {
                    this.m_Variance[n2][n] = m_ZERO;
                }
                this.m_Change[n2][n] = 1.0;
            }
            instances3.add(instance);
            this.m_Class[n2] = instance.classValue();
            this.m_Weights[n2] = instance.weight();
        }
        for (n2 = 0; n2 < n4; ++n2) {
            this.findWeights(n2, this.m_Mean);
        }
        for (n2 = 0; n2 < n4; ++n2) {
            instance = this.preprocess(instances3, n2);
            if (this.getDebug()) {
                System.out.println("???Exemplar " + n2 + " has been pre-processed:" + instances3.instance(n2).relationalValue(1).sumOfWeights() + "|" + instance.relationalValue(1).sumOfWeights() + "; class:" + this.m_Class[n2]);
            }
            if (Utils.gr(instance.relationalValue(1).sumOfWeights(), 0.0)) {
                for (n = 0; n < this.m_Dimension; ++n) {
                    this.m_ValidM[n2][n] = instance.relationalValue(1).meanOrMode(n);
                    this.m_ValidV[n2][n] = instance.relationalValue(1).variance(n);
                    if (!Utils.eq(this.m_ValidV[n2][n], 0.0)) continue;
                    this.m_ValidV[n2][n] = m_ZERO;
                }
                continue;
            }
            this.m_ValidM[n2] = null;
            this.m_ValidV[n2] = null;
        }
        for (n2 = 0; n2 < n4; ++n2) {
            if (this.m_ValidM[n2] == null) continue;
            this.findWeights(n2, this.m_ValidM);
        }
    }

    public Instance preprocess(Instances instances, int n) throws Exception {
        int n2;
        Instance instance = instances.instance(n);
        if ((int)instance.classValue() == 0) {
            this.m_NoiseM[n] = null;
            this.m_NoiseV[n] = null;
            return instance;
        }
        Instances instances2 = instance.attribute(1).relation().stringFreeStructure();
        Instances instances3 = instance.attribute(1).relation().stringFreeStructure();
        Instances instances4 = this.m_Attributes;
        Instance instance2 = new Instance(instance.numAttributes());
        Instance instance3 = new Instance(instance.numAttributes());
        instance2.setDataset(instances4);
        instance3.setDataset(instances4);
        for (n2 = 0; n2 < instance.relationalValue(1).numInstances(); ++n2) {
            int n3;
            Instance instance4 = instance.relationalValue(1).instance(n2);
            double[] dArray = new double[instances.numInstances()];
            for (int i = 0; i < instances.numInstances(); ++i) {
                dArray[i] = i != n ? this.distance(instance4, this.m_Mean[i], this.m_Variance[i], i) : Double.POSITIVE_INFINITY;
            }
            int[] nArray = new int[this.m_NumClasses];
            for (n3 = 0; n3 < nArray.length; ++n3) {
                nArray[n3] = 0;
            }
            for (n3 = 0; n3 < this.m_Select; ++n3) {
                int n4 = Utils.minIndex(dArray);
                int n5 = (int)this.m_Class[n4];
                nArray[n5] = nArray[n5] + 1;
                dArray[n4] = Double.POSITIVE_INFINITY;
            }
            n3 = Utils.maxIndex(nArray);
            if ((int)instance.classValue() != n3) {
                instances3.add(instance4);
                continue;
            }
            instances2.add(instance4);
        }
        n2 = instance3.attribute(1).addRelation(instances3);
        instance3.setValue(0, instance.value(0));
        instance3.setValue(1, (double)n2);
        instance3.setValue(2, instance.classValue());
        n2 = instance2.attribute(1).addRelation(instances2);
        instance2.setValue(0, instance.value(0));
        instance2.setValue(1, (double)n2);
        instance2.setValue(2, instance.classValue());
        if (Utils.gr(instance3.relationalValue(1).sumOfWeights(), 0.0)) {
            for (int i = 0; i < this.m_Dimension; ++i) {
                this.m_NoiseM[n][i] = instance3.relationalValue(1).meanOrMode(i);
                this.m_NoiseV[n][i] = instance3.relationalValue(1).variance(i);
                if (!Utils.eq(this.m_NoiseV[n][i], 0.0)) continue;
                this.m_NoiseV[n][i] = m_ZERO;
            }
        } else {
            this.m_NoiseM[n] = null;
            this.m_NoiseV[n] = null;
        }
        return instance2;
    }

    private double distance(Instance instance, double[] dArray, double[] dArray2, int n) {
        double d = 0.0;
        for (int i = 0; i < this.m_Dimension; ++i) {
            if (!instance.attribute(i).isNumeric()) continue;
            if (!instance.isMissing(i)) {
                double d2 = instance.value(i) - dArray[i];
                if (Utils.gr(dArray2[i], m_ZERO)) {
                    d += this.m_Change[n][i] * dArray2[i] * d2 * d2;
                    continue;
                }
                d += this.m_Change[n][i] * d2 * d2;
                continue;
            }
            if (Utils.gr(dArray2[i], m_ZERO)) {
                d += this.m_Change[n][i] * dArray2[i];
                continue;
            }
            d += this.m_Change[n][i] * 1.0;
        }
        return d;
    }

    private void updateMinMax(Instance instance) {
        Instances instances = instance.relationalValue(1);
        for (int i = 0; i < this.m_Dimension; ++i) {
            if (!instances.attribute(i).isNumeric()) continue;
            for (int j = 0; j < instances.numInstances(); ++j) {
                Instance instance2 = instances.instance(j);
                if (instance2.isMissing(i)) continue;
                if (Double.isNaN(this.m_MinArray[i])) {
                    this.m_MinArray[i] = instance2.value(i);
                    this.m_MaxArray[i] = instance2.value(i);
                    continue;
                }
                if (instance2.value(i) < this.m_MinArray[i]) {
                    this.m_MinArray[i] = instance2.value(i);
                    continue;
                }
                if (!(instance2.value(i) > this.m_MaxArray[i])) continue;
                this.m_MaxArray[i] = instance2.value(i);
            }
        }
    }

    private Instance scale(Instance instance) throws Exception {
        int n;
        Instances instances = instance.relationalValue(1).stringFreeStructure();
        Instance instance2 = new Instance(instance.numAttributes());
        instance2.setDataset(this.m_Attributes);
        for (n = 0; n < instance.relationalValue(1).numInstances(); ++n) {
            Instance instance3 = instance.relationalValue(1).instance(n);
            Instance instance4 = (Instance)instance3.copy();
            for (int i = 0; i < this.m_Dimension; ++i) {
                if (!instance.relationalValue(1).attribute(i).isNumeric()) continue;
                instance4.setValue(i, (instance3.value(i) - this.m_MinArray[i]) / (this.m_MaxArray[i] - this.m_MinArray[i]));
            }
            instances.add(instance4);
        }
        n = instance2.attribute(1).addRelation(instances);
        instance2.setValue(0, instance.value(0));
        instance2.setValue(1, (double)n);
        instance2.setValue(2, instance.value(2));
        return instance2;
    }

    public void findWeights(int n, double[][] dArray) {
        double[] dArray2 = new double[this.m_Dimension];
        double[] dArray3 = new double[this.m_Dimension];
        System.arraycopy(this.m_Change[n], 0, dArray2, 0, this.m_Dimension);
        double d = this.target(dArray2, dArray, n, this.m_Class);
        double d2 = Double.POSITIVE_INFINITY;
        double d3 = 0.05;
        if (this.m_Rate != -1.0) {
            d3 = this.m_Rate;
        }
        block0: while (Utils.gr(d2 - d, this.m_STOP)) {
            int n2;
            dArray3 = dArray2;
            dArray2 = new double[this.m_Dimension];
            double[] dArray4 = this.delta(dArray3, dArray, n, this.m_Class);
            for (n2 = 0; n2 < this.m_Dimension; ++n2) {
                if (!Utils.gr(this.m_Variance[n][n2], 0.0)) continue;
                dArray2[n2] = dArray3[n2] + d3 * dArray4[n2];
            }
            d2 = d;
            d = this.target(dArray2, dArray, n, this.m_Class);
            while (Utils.gr(d, d2)) {
                if (this.m_Rate == -1.0) {
                    d3 *= this.m_Decay;
                    for (n2 = 0; n2 < this.m_Dimension; ++n2) {
                        if (!Utils.gr(this.m_Variance[n][n2], 0.0)) continue;
                        dArray2[n2] = dArray3[n2] + d3 * dArray4[n2];
                    }
                    d = this.target(dArray2, dArray, n, this.m_Class);
                    continue;
                }
                for (n2 = 0; n2 < this.m_Dimension; ++n2) {
                    dArray2[n2] = dArray3[n2];
                }
                break block0;
            }
        }
        this.m_Change[n] = dArray2;
    }

    private double[] delta(double[] dArray, double[][] dArray2, int n, double[] dArray3) {
        int n2;
        double d = dArray3[n];
        double[] dArray4 = new double[this.m_Dimension];
        for (n2 = 0; n2 < this.m_Dimension; ++n2) {
            dArray4[n2] = 0.0;
        }
        for (n2 = 0; n2 < dArray2.length; ++n2) {
            int n3;
            if (n2 == n || dArray2[n2] == null) continue;
            double d2 = d == dArray3[n2] ? 0.0 : Math.sqrt((double)this.m_Dimension - 1.0);
            double d3 = 0.0;
            for (n3 = 0; n3 < this.m_Dimension; ++n3) {
                if (!Utils.gr(this.m_Variance[n][n3], 0.0)) continue;
                d3 += dArray[n3] * (dArray2[n][n3] - dArray2[n2][n3]) * (dArray2[n][n3] - dArray2[n2][n3]);
            }
            if ((d3 = Math.sqrt(d3)) == 0.0) continue;
            for (n3 = 0; n3 < this.m_Dimension; ++n3) {
                if (!(this.m_Variance[n][n3] > 0.0)) continue;
                int n4 = n3;
                dArray4[n4] = dArray4[n4] + (d2 / d3 - 1.0) * 0.5 * (dArray2[n][n3] - dArray2[n2][n3]) * (dArray2[n][n3] - dArray2[n2][n3]);
            }
        }
        return dArray4;
    }

    public double target(double[] dArray, double[][] dArray2, int n, double[] dArray3) {
        double d = dArray3[n];
        double d2 = 0.0;
        for (int i = 0; i < dArray2.length; ++i) {
            if (i == n || dArray2[i] == null) continue;
            double d3 = d == dArray3[i] ? 0.0 : Math.sqrt((double)this.m_Dimension - 1.0);
            double d4 = 0.0;
            for (int j = 0; j < this.m_Dimension; ++j) {
                if (!Utils.gr(this.m_Variance[n][j], 0.0)) continue;
                d4 += dArray[j] * (dArray2[n][j] - dArray2[i][j]) * (dArray2[n][j] - dArray2[i][j]);
            }
            if (Double.isInfinite(d4 = Math.sqrt(d4))) {
                System.exit(1);
            }
            d2 += 0.5 * (d4 - d3) * (d4 - d3);
        }
        return d2;
    }

    public double classifyInstance(Instance instance) throws Exception {
        int n;
        instance = this.scale(instance);
        double[] dArray = new double[this.m_Dimension];
        for (int i = 0; i < this.m_Dimension; ++i) {
            dArray[i] = instance.relationalValue(1).variance(i);
        }
        double[] dArray2 = new double[this.m_Class.length];
        double[] dArray3 = new double[this.m_NumClasses];
        for (int i = 0; i < dArray3.length; ++i) {
            dArray3[i] = 0.0;
        }
        if ((instance = this.cleanse(instance)).relationalValue(1).numInstances() == 0) {
            if (this.getDebug()) {
                System.out.println("???Whole exemplar falls into ambiguous area!");
            }
            return 1.0;
        }
        double[] dArray4 = new double[this.m_Dimension];
        for (n = 0; n < this.m_Dimension; ++n) {
            dArray4[n] = instance.relationalValue(1).meanOrMode(n);
        }
        for (n = 0; n < dArray.length; ++n) {
            if (!Utils.eq(dArray[n], 0.0)) continue;
            dArray[n] = m_ZERO;
        }
        for (n = 0; n < this.m_Class.length; ++n) {
            dArray2[n] = this.m_ValidM[n] != null ? this.kullback(dArray4, this.m_ValidM[n], dArray, this.m_Variance[n], n) : Double.POSITIVE_INFINITY;
        }
        for (n = 0; n < this.m_Neighbour; ++n) {
            int n2 = Utils.minIndex(dArray2);
            int n3 = (int)this.m_Class[n2];
            dArray3[n3] = dArray3[n3] + this.m_Weights[n2];
            dArray2[n2] = Double.POSITIVE_INFINITY;
        }
        if (this.getDebug()) {
            System.out.println("???There are still some unambiguous instances in this exemplar! Predicted as: " + Utils.maxIndex(dArray3));
        }
        return Utils.maxIndex(dArray3);
    }

    public Instance cleanse(Instance instance) throws Exception {
        Instances instances = instance.relationalValue(1).stringFreeStructure();
        Instance instance2 = new Instance(instance.numAttributes());
        instance2.setDataset(this.m_Attributes);
        for (int i = 0; i < instance.relationalValue(1).numInstances(); ++i) {
            int n;
            int n2;
            Instance instance3 = instance.relationalValue(1).instance(i);
            double[] dArray = new double[this.m_Choose];
            double[] dArray2 = new double[this.m_Choose];
            int n3 = 0;
            int n4 = 0;
            double[] dArray3 = new double[this.m_Mean.length];
            double[] dArray4 = new double[this.m_Mean.length];
            for (n2 = 0; n2 < this.m_Mean.length; ++n2) {
                dArray4[n2] = this.m_ValidM[n2] == null ? Double.POSITIVE_INFINITY : this.distance(instance3, this.m_ValidM[n2], this.m_ValidV[n2], n2);
                dArray3[n2] = this.m_NoiseM[n2] == null ? Double.POSITIVE_INFINITY : this.distance(instance3, this.m_NoiseM[n2], this.m_NoiseV[n2], n2);
            }
            for (n2 = 0; n2 < this.m_Choose; ++n2) {
                n = Utils.minIndex(dArray4);
                dArray2[n2] = dArray4[n];
                dArray4[n] = Double.POSITIVE_INFINITY;
                n = Utils.minIndex(dArray3);
                dArray[n2] = dArray3[n];
                dArray3[n] = Double.POSITIVE_INFINITY;
            }
            n2 = 0;
            n = 0;
            while (n2 + n < this.m_Choose) {
                if (dArray2[n2] <= dArray[n]) {
                    ++n4;
                    ++n2;
                    continue;
                }
                ++n3;
                ++n;
            }
            if (n2 < n) continue;
            instances.add(instance3);
        }
        instance2.setValue(0, instance.value(0));
        instance2.setValue(1, (double)instance2.attribute(1).addRelation(instances));
        instance2.setValue(2, instance.value(2));
        return instance2;
    }

    public double kullback(double[] dArray, double[] dArray2, double[] dArray3, double[] dArray4, int n) {
        int n2 = dArray.length;
        double d = 0.0;
        for (int i = 0; i < n2; ++i) {
            if (!Utils.gr(dArray3[i], 0.0) || !Utils.gr(dArray4[i], 0.0)) continue;
            d += Math.log(Math.sqrt(dArray4[i] / dArray3[i])) + dArray3[i] / (2.0 * dArray4[i]) + this.m_Change[n][i] * (dArray[i] - dArray2[i]) * (dArray[i] - dArray2[i]) / (2.0 * dArray4[i]) - 0.5;
        }
        return d;
    }

    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>();
        vector.addElement(new Option("\tSet number of nearest neighbour for prediction\n\t(default 1)", "K", 1, "-K <number of neighbours>"));
        vector.addElement(new Option("\tSet number of nearest neighbour for cleansing the training data\n\t(default 1)", "S", 1, "-S <number of neighbours>"));
        vector.addElement(new Option("\tSet number of nearest neighbour for cleansing the testing data\n\t(default 1)", "E", 1, "-E <number of neighbours>"));
        return vector.elements();
    }

    public void setOptions(String[] stringArray) throws Exception {
        this.setDebug(Utils.getFlag('D', stringArray));
        String string = Utils.getOption('K', stringArray);
        if (string.length() != 0) {
            this.setNumNeighbours(Integer.parseInt(string));
        } else {
            this.setNumNeighbours(1);
        }
        string = Utils.getOption('S', stringArray);
        if (string.length() != 0) {
            this.setNumTrainingNoises(Integer.parseInt(string));
        } else {
            this.setNumTrainingNoises(1);
        }
        string = Utils.getOption('E', stringArray);
        if (string.length() != 0) {
            this.setNumTestingNoises(Integer.parseInt(string));
        } else {
            this.setNumTestingNoises(1);
        }
    }

    public String[] getOptions() {
        Vector<String> vector = new Vector<String>();
        if (this.getDebug()) {
            vector.add("-D");
        }
        vector.add("-K");
        vector.add("" + this.getNumNeighbours());
        vector.add("-S");
        vector.add("" + this.getNumTrainingNoises());
        vector.add("-E");
        vector.add("" + this.getNumTestingNoises());
        return vector.toArray(new String[vector.size()]);
    }

    public String numNeighboursTipText() {
        return "The number of nearest neighbours to the estimate the class prediction of test bags.";
    }

    public void setNumNeighbours(int n) {
        this.m_Neighbour = n;
    }

    public int getNumNeighbours() {
        return this.m_Neighbour;
    }

    public String numTrainingNoisesTipText() {
        return "The number of nearest neighbour instances in the selection of noises in the training data.";
    }

    public void setNumTrainingNoises(int n) {
        this.m_Select = n;
    }

    public int getNumTrainingNoises() {
        return this.m_Select;
    }

    public String numTestingNoisesTipText() {
        return "The number of nearest neighbour instances in the selection of noises in the test data.";
    }

    public int getNumTestingNoises() {
        return this.m_Choose;
    }

    public void setNumTestingNoises(int n) {
        this.m_Choose = n;
    }

    public static void main(String[] stringArray) {
        MINND.runClassifier(new MINND(), stringArray);
    }
}

