/*
 * Decompiled with CFR 0.152.
 */
package moa.streams.generators.multilabel;

import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Attribute;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import com.yahoo.labs.samoa.instances.SparseInstance;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Random;
import moa.core.FastVector;
import moa.core.InstanceExample;
import moa.core.MultilabelInstancesHeader;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.options.AbstractOptionHandler;
import moa.options.ClassOption;
import moa.streams.InstanceStream;
import moa.tasks.TaskMonitor;

/**
 * Multi-label stream generator built on top of a binary {@link InstanceStream}.
 *
 * <p>Label vectors are sampled from randomly generated per-label priors and pairwise
 * label conditionals. These are controlled by the label-cardinality and label-dependency
 * options. Each input attribute {@code a} of a generated example is copied from one of two
 * binary examples:
 * <ul>
 *   <li>from a class-1 example when the sampled labelset contains the labelset mapped to
 *       {@code a};</li>
 *   <li>otherwise from a class-0 example.</li>
 * </ul>
 * All meta-level randomness (priors, dependencies, labelset sampling, shuffles) is drawn
 * from a single {@link Random} seeded by {@code metaRandomSeed}.
 */
public class MetaMultilabelGenerator
extends AbstractOptionHandler
implements InstanceStream {
    private static final long serialVersionUID = 1L;
    public ClassOption binaryGeneratorOption = new ClassOption("binaryGenerator", 's', "Binary Generator (specify the number of attributes here, but only two classes!).", InstanceStream.class, "generators.RandomTreeGenerator");
    public IntOption metaRandomSeedOption = new IntOption("metaRandomSeed", 'm', "Random seed (for the meta process). Use two streams with the same seed and r > 0.0 in the second stream if you wish to introduce drift to the label dependencies without changing the underlying concept.", 1);
    public IntOption numLabelsOption = new IntOption("numLabels", 'c', "Number of labels.", 10, 2, Integer.MAX_VALUE);
    public IntOption skewOption = new IntOption("skew", 'k', "Skewed label distribution: 1 (default) = yes; 0 = no (relatively uniform) @NOTE: not currently implemented.", 1, 0, 1);
    public FloatOption labelCardinalityOption = new FloatOption("labelCardinality", 'z', "Desired label cardinality (average number of labels per example).", 1.5, 0.0, 2.147483647E9);
    public FloatOption labelCardinalityVarOption = new FloatOption("labelCardinalityVar", 'v', "Desired label cardinality variance (variance of z) @NOTE: not currently implemented.", 1.0, 0.0, 2.147483647E9);
    public FloatOption labelCardinalityRatioOption = new FloatOption("labelDependency", 'u', "Specifies how much label dependency from 0 (total independence) to 1 (full dependence).", 0.25, 0.0, 1.0);
    public FloatOption labelDependencyChangeRatioOption = new FloatOption("labelDependencyRatioChange", 'r', "Each label-pair dependency has a 'r' chance of being modified. Use this option on the second of two streams with the same random seed (-m) to introduce label-dependence drift.", 0.0, 0.0, 1.0);
    /** Header of the generated multi-label stream (labels first, then input attributes). */
    protected MultilabelInstancesHeader m_MultilabelInstancesHeader = null;
    /** Underlying binary stream that supplies input attribute values. */
    protected InstanceStream m_BinaryGenerator = null;
    /** Dataset template that every generated instance is attached to. */
    protected Instances multilabelStreamTemplate = null;
    /** Meta-level RNG, seeded from {@code metaRandomSeed} in {@link #restart()}. */
    protected Random m_MetaRandom = new Random();
    /** Number of labels. */
    protected int m_L = 0;
    /** Number of input attributes (binary header attributes minus the class). */
    protected int m_A = 0;
    /** Per-label marginal probabilities. */
    protected double[] priors = null;
    /** {@link #priors} normalised to sum to 1; used to pick the first active label. */
    protected double[] priors_norm = null;
    /** Pairwise label conditionals; {@code Conditional[k][j]} is the factor applied to label k when label j is on. */
    protected double[][] Conditional = null;
    /** For each input attribute, the labelset that selects a class-1 source example. */
    protected HashSet[] m_TopCombinations = null;
    /** Buffered binary examples, indexed by class value (0 or 1). */
    LinkedList<Instance>[] queue = null;

    @Override
    public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
        this.restart();
    }

    /**
     * Re-initialises the generator from the current option values.
     *
     * <p>It prepares the binary generator, builds the multi-label header, samples label
     * priors and the label dependency matrix (optionally perturbed by the
     * dependency-change ratio), and derives the attribute-to-labelset map.
     *
     * @throws IllegalArgumentException if the label cardinality is not in (0, numLabels]
     */
    @Override
    @SuppressWarnings("unchecked")
    public void restart() {
        this.m_L = this.numLabelsOption.getValue();
        double z = this.labelCardinalityOption.getValue();
        // Throw instead of System.exit so that host applications (GUI, task runners) can recover.
        if (z > (double)this.m_L) {
            throw new IllegalArgumentException("Error: Label cardinality (z) cannot be greater than the number of labels (c)!");
        }
        if (z <= 0.0) {
            // With z == 0 all priors become 0 and normalisation below would fail with an opaque error.
            throw new IllegalArgumentException("Error: Label cardinality (z) must be greater than 0!");
        }
        this.m_BinaryGenerator = (InstanceStream)this.getPreparedClassOption(this.binaryGeneratorOption);
        this.m_BinaryGenerator.restart();
        this.m_A = this.m_BinaryGenerator.getHeader().numAttributes() - 1;
        this.m_MetaRandom = new Random(this.metaRandomSeedOption.getValue());
        this.queue = new LinkedList[2];
        for (int i = 0; i < this.queue.length; ++i) {
            this.queue[i] = new LinkedList<Instance>();
        }
        this.m_MultilabelInstancesHeader = this.generateMultilabelHeader(this.m_BinaryGenerator.getHeader());
        this.priors = this.generatePriors(this.m_MetaRandom, this.m_L, z, this.skewOption.getValue() >= 1);
        boolean[][] DependencyMatrix = this.modifyDependencyMatrix(new boolean[this.m_L][this.m_L], this.labelCardinalityRatioOption.getValue(), this.m_MetaRandom);
        if (this.labelDependencyChangeRatioOption.getValue() > 0.0) {
            this.priors = this.modifyPriorVector(this.priors, this.labelDependencyChangeRatioOption.getValue(), this.m_MetaRandom, this.skewOption.getValue() >= 1);
            this.modifyDependencyMatrix(DependencyMatrix, this.labelDependencyChangeRatioOption.getValue(), this.m_MetaRandom);
        }
        this.Conditional = this.generateConditional(this.priors, DependencyMatrix);
        this.priors_norm = Arrays.copyOf(this.priors, this.priors.length);
        Utils.normalize(this.priors_norm);
        this.m_TopCombinations = this.getTopCombinations(this.m_A);
    }

    /**
     * Builds the multi-label header from the binary header.
     *
     * <p>The binary class attribute (assumed to be the last attribute) is dropped. Then
     * {@code m_L} nominal {0,1} label attributes named {@code class0..class(L-1)} are
     * inserted at the front. As a side effect, {@link #multilabelStreamTemplate} is set.
     *
     * @param si header of the binary stream
     * @return the multi-label header with {@code m_L} labels
     */
    protected MultilabelInstancesHeader generateMultilabelHeader(Instances si) {
        Instances mi = new Instances(si, 0, 0);
        mi.setClassIndex(-1);
        mi.deleteAttributeAt(mi.numAttributes() - 1);
        FastVector<String> bfv = new FastVector<String>();
        bfv.addElement("0");
        bfv.addElement("1");
        for (int i = 0; i < this.m_L; ++i) {
            mi.insertAttributeAt(new Attribute("class" + i, bfv), i);
        }
        this.multilabelStreamTemplate = mi;
        this.multilabelStreamTemplate.setRelationName("SYN_Z" + this.labelCardinalityOption.getValue() + "L" + this.m_L + "X" + this.m_A + "S" + this.metaRandomSeedOption.getValue() + ": -C " + this.m_L);
        this.multilabelStreamTemplate.setClassIndex(this.m_L);
        return new MultilabelInstancesHeader(this.multilabelStreamTemplate, this.m_L);
    }

    /**
     * Samples L random priors and rescales them until they sum to (at least) {@code z}.
     * Each prior is capped at 1.0.
     *
     * <p>Requires {@code 0 < z <= L}. If {@code z > L} the loop could never terminate;
     * {@link #restart()} enforces this bound.
     *
     * @param skew currently unused
     */
    private double[] generatePriors(Random r, int L, double z, boolean skew) {
        double[] P = new double[L];
        for (int i = 0; i < L; ++i) {
            P[i] = r.nextDouble();
        }
        do {
            double c = Utils.sum(P) / z;
            for (int i = 0; i < L; ++i) {
                P[i] = Math.min(1.0, P[i] / c);
            }
        } while (Utils.sum(P) < z);
        return P;
    }

    /**
     * Returns a binary example whose class equals {@code i}.
     *
     * <p>It first uses the buffer for class {@code i}. Otherwise it draws up to 1000
     * examples from the binary stream, buffering (up to 100 per class) those of the other
     * class.
     *
     * @throws IllegalStateException if no example of class {@code i} appears within 1000
     *     draws, or if the binary stream yields a class outside {0, 1}
     */
    private Instance getNextWithBinary(int i) {
        if (!this.queue[i].isEmpty()) {
            return this.queue[i].remove();
        }
        for (int attempt = 0; attempt < 1000; ++attempt) {
            Instance tinst = (Instance)this.m_BinaryGenerator.nextInstance().getData();
            int c = (int)Math.round(tinst.classValue());
            if (c == i) {
                return tinst;
            }
            if (c < 0 || c >= this.queue.length) {
                throw new IllegalStateException("The binary generator produced class value " + c + "; only classes 0 and 1 are supported.");
            }
            if (this.queue[c].size() < 100) {
                this.queue[c].add(tinst);
            }
        }
        throw new IllegalStateException("[Overflow] The binary stream is too skewed, could not get an example of class " + i + "");
    }

    /** Generates the next multi-label example: samples a labelset, then builds its instance. */
    @Override
    public InstanceExample nextInstance() {
        return new InstanceExample(this.generateMLInstance(this.generateSet()));
    }

    /**
     * Samples a labelset.
     *
     * <p>One label is picked from {@link #priors_norm} and switched on. Every other label,
     * visited in random order, is then switched on with probability
     * {@link #joint(int, int[])} given the labels already on.
     */
    private HashSet<Integer> generateSet() {
        int[] y = new int[this.m_L];
        int k = this.samplePMF(this.priors_norm);
        y[k] = 1;
        ArrayList<Integer> indices = this.getShuffledListToLWithoutK(this.m_L, k);
        for (int j : indices) {
            y[j] = this.joint(j, y) > this.m_MetaRandom.nextDouble() ? 1 : 0;
        }
        return this.vector2set(y);
    }

    /** Product of {@code Conditional[k][j]} over every other label j that is currently on in {@code y}. */
    private double joint(int k, int[] y) {
        double p = 1.0;
        for (int j = 0; j < y.length; ++j) {
            if (j == k || y[j] != 1) continue;
            p *= this.Conditional[k][j];
        }
        return p;
    }

    /**
     * Builds a multi-label instance for labelset {@code Y}.
     *
     * <p>Label attributes 0..L-1 are set to 0/1. Input attribute {@code a} (stored at
     * index {@code L + a}) is copied from a class-1 binary example if {@code Y} contains
     * the labelset mapped to {@code a}; otherwise it is copied from a class-0 example.
     */
    private Instance generateMLInstance(HashSet<Integer> Y) {
        SparseInstance x_ml = new SparseInstance((double)this.multilabelStreamTemplate.numAttributes());
        x_ml.setDataset(this.multilabelStreamTemplate);
        for (int j = 0; j < this.m_L; ++j) {
            x_ml.setValue(j, 0.0);
        }
        for (int l : Y) {
            x_ml.setValue(l, 1.0);
        }
        Instance x_0 = this.getNextWithBinary(0);
        Instance x_1 = this.getNextWithBinary(1);
        for (int a = 0; a < this.m_A; ++a) {
            Instance source = Y.containsAll(this.m_TopCombinations[a]) ? x_1 : x_0;
            x_ml.setValue(this.m_L + a, source.value(a));
        }
        return x_ml;
    }

    /**
     * Samples an index from the discrete distribution {@code p}.
     *
     * <p>If round-off leaves the cumulative sum just below the drawn value, it returns the
     * last index with positive mass instead of an invalid index. It returns -1 only if no
     * entry is positive.
     */
    private int samplePMF(double[] p) {
        double r = this.m_MetaRandom.nextDouble();
        double sum = 0.0;
        int fallback = -1;
        for (int i = 0; i < p.length; ++i) {
            sum += p[i];
            if (r < sum) {
                return i;
            }
            if (p[i] > 0.0) {
                fallback = i;
            }
        }
        return fallback;
    }

    /**
     * With probability {@code u}, replaces each prior in place by a fresh uniform draw.
     *
     * @param skew currently unused
     * @return the same (mutated) array {@code P}
     */
    protected double[] modifyPriorVector(double[] P, double u, Random r, boolean skew) {
        for (int j = 0; j < P.length; ++j) {
            if (!(r.nextDouble() < u)) continue;
            P[j] = r.nextDouble();
        }
        return P;
    }

    /**
     * Flips each upper-triangular entry {@code M[j][k]} ({@code k > j}) in place with
     * probability {@code u}.
     *
     * @return the same (mutated) matrix {@code M}
     */
    protected boolean[][] modifyDependencyMatrix(boolean[][] M, double u, Random r) {
        for (int j = 0; j < M.length; ++j) {
            for (int k = j + 1; k < M[j].length; ++k) {
                if (r.nextDouble() <= u) {
                    M[j][k] = !M[j][k];
                }
            }
        }
        return M;
    }

    /**
     * Builds the pairwise conditional matrix from priors {@code P} and dependency flags
     * {@code M}.
     *
     * <p>The diagonal holds the priors. For an independent pair, {@code Q[j][k] = P[j]}
     * and {@code Q[k][j] = Q[j][k] * P[k] / P[j]}. For a dependent pair, {@code Q[j][k]}
     * is randomly set by {@link #min} or {@link #max}, and
     * {@code Q[k][j] = Q[j][k] * P[j] / P[k]}.
     *
     * <p>NOTE(review): the dependent branch scales by P[j]/P[k] while the independent
     * branch scales by P[k]/P[j]. The orientation looks inconsistent; confirm intent
     * before changing, since it alters the generated label statistics.
     */
    protected double[][] generateConditional(double[] P, boolean[][] M) {
        int L = P.length;
        double[][] Q = new double[L][L];
        for (int j = 0; j < L; ++j) {
            Q[j][j] = P[j];
        }
        for (int j = 0; j < Q.length; ++j) {
            for (int k = j + 1; k < Q[j].length; ++k) {
                if (M[j][k]) {
                    Q[j][k] = this.m_MetaRandom.nextBoolean() ? this.min(P[j], P[k]) : this.max(P[j], P[k]);
                    Q[k][j] = Q[j][k] * Q[j][j] / Q[k][k];
                    continue;
                }
                Q[j][k] = P[j];
                Q[k][j] = Q[j][k] * P[k] / P[j];
            }
        }
        return Q;
    }

    /**
     * Estimates the most frequent labelsets (from 100000 samples) and assigns one to each
     * of the {@code n} input attributes.
     *
     * <p>Each labelset gets a number of slots roughly proportional to its frequency, with
     * at least one slot each. The assignment is shuffled with the meta RNG. It prints
     * summary statistics to stderr.
     *
     * @param n number of input attributes
     * @return an array of {@code n} non-null labelsets
     */
    private HashSet[] getTopCombinations(int n) {
        final HashMap<HashSet<Integer>, Integer> count = new HashMap<HashSet<Integer>, Integer>();
        int N = 100000;
        double lc = 0.0;
        for (int i = 0; i < N; ++i) {
            HashSet<Integer> Y = this.generateSet();
            lc += (double)Y.size();
            Integer seen = count.get(Y);
            count.put(Y, seen != null ? seen + 1 : 1);
        }
        lc /= (double)N;
        ArrayList<HashSet<Integer>> top_set = new ArrayList<HashSet<Integer>>(count.keySet());
        // Most frequent first.
        Collections.sort(top_set, new Comparator<HashSet<Integer>>(){

            @Override
            public int compare(HashSet<Integer> Y1, HashSet<Integer> Y2) {
                return count.get(Y2).compareTo(count.get(Y1));
            }
        });
        System.err.println("The most common labelsets (from which we will build the map) will likely be: ");
        HashSet[] map_set = new HashSet[n];
        double[] weights = new double[n];
        int idx = 0;
        for (HashSet<Integer> Y : top_set) {
            if (idx >= weights.length) break;
            System.err.println(" " + Y + " : " + (double)count.get(Y).intValue() * 100.0 / (double)N + "%");
            weights[idx++] = count.get(Y).intValue();
        }
        System.err.println("Estimated Label Cardinality:  " + lc + "\n\n");
        System.err.println("Estimated % Unique Labelsets: " + (double)count.size() * 100.0 / (double)N + "%\n\n");
        if (n == 0) {
            return map_set;
        }
        Utils.normalize(weights);
        int k = 0;
        for (int i = 0; i < top_set.size() && k < map_set.length; ++i) {
            int num = (int)Math.round(Math.max(weights[i] * (double)map_set.length, 1.0));
            for (int j = 0; j < num && k < map_set.length; ++j) {
                map_set[k++] = top_set.get(i);
            }
        }
        // Rounding can leave slots empty when there are fewer distinct labelsets than
        // attributes. A null slot would make generateMLInstance throw NPE, so fill cyclically.
        for (int i = 0; k < map_set.length; ++i) {
            map_set[k++] = top_set.get(i % top_set.size());
        }
        // Seeded shuffle keeps the attribute map reproducible for a given metaRandomSeed.
        Collections.shuffle(Arrays.asList(map_set), this.m_MetaRandom);
        return map_set;
    }

    @Override
    public InstancesHeader getHeader() {
        return this.m_MultilabelInstancesHeader;
    }

    @Override
    public String getPurposeString() {
        return "Generates a multi-label stream based on a binary random generator.";
    }

    /** The stream is unbounded, so the remaining count is unknown (-1). */
    @Override
    public long estimatedRemainingInstances() {
        return -1L;
    }

    @Override
    public boolean hasMoreInstances() {
        return true;
    }

    @Override
    public boolean isRestartable() {
        return true;
    }

    @Override
    public void getDescription(StringBuilder sb, int indent) {
    }

    /** Converts a labelset to a 0/1 vector of length {@code L}. */
    private int[] set2vector(HashSet<Integer> Y, int L) {
        int[] y = new int[L];
        for (int j : Y) {
            y[j] = 1;
        }
        return y;
    }

    /** Converts a 0/1 vector to the set of indices with a positive entry. */
    private HashSet<Integer> vector2set(int[] y) {
        HashSet<Integer> Y = new HashSet<Integer>();
        for (int j = 0; j < y.length; ++j) {
            if (y[j] > 0) {
                Y.add(j);
            }
        }
        return Y;
    }

    /** Returns {@code min(1, B / A)}; used as the "high dependence" conditional value. */
    private double max(double A, double B) {
        return Math.min(1.0, B / A);
    }

    /** Returns {@code max(0, A + B - 1)}, the lower Fréchet bound on P(A and B) given marginals A and B. */
    private double min(double A, double B) {
        return Math.max(0.0, -1.0 + A + B);
    }

    /** Returns the indices {@code 0..L-1} except {@code k}, shuffled with the meta RNG for reproducibility. */
    private ArrayList<Integer> getShuffledListToLWithoutK(int L, int k) {
        ArrayList<Integer> list = new ArrayList<Integer>(L - 1);
        for (int j = 0; j < L; ++j) {
            if (j == k) continue;
            list.add(j);
        }
        Collections.shuffle(list, this.m_MetaRandom);
        return list;
    }

    public static void main(String[] args) {
    }

    /** Debug helper: prints a matrix to stdout. */
    private void printMatrix(double[][] M) {
        System.out.println("--- MATRIX ---");
        for (int i = 0; i < M.length; ++i) {
            for (int j = 0; j < M[i].length; ++j) {
                System.out.print(" " + Utils.doubleToString(M[i][j], 5, 3));
            }
            System.out.println("");
        }
    }

    /** Debug helper: prints a vector to stdout. */
    private void printVector(double[] V) {
        System.out.println("--- VECTOR ---");
        for (int j = 0; j < V.length; ++j) {
            System.out.print(" " + Utils.doubleToString(V[j], 5, 3));
        }
        System.out.println("");
    }
}

