/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import com.carrotsearch.hppc.IntIntHashMap;
import com.carrotsearch.hppc.ObjectDoubleHashMap;
import com.carrotsearch.hppc.cursors.IntIntCursor;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;

public class HierarchicalLDA
implements Serializable {
    /** Training documents; assigned by {@code initialize} and read by the sampling and printing methods. */
    InstanceList instances;
    /** Held-out documents passed to {@code initialize}; stored here but not otherwise read in this class. */
    InstanceList testing;
    /** Root of the topic tree; created by {@code initialize} and shared by every document path. */
    NCRPNode rootNode;
    /** Scratch reference used by {@code initialize} while assigning tokens to path nodes. */
    NCRPNode node;
    /** Number of levels on every root-to-leaf path. */
    int numLevels;
    /** Number of training documents ({@code instances.size()} at initialization time). */
    int numDocuments;
    /** Vocabulary size (size of the data alphabet at initialization time). */
    int numTypes;
    /** Smoothing added to per-document level counts in {@code sampleTopics}; also the Dirichlet parameter in {@code empiricalLikelihood}. */
    double alpha = 10.0;
    /** Weight given to opening a new child when choosing among a node's children. */
    double gamma = 1.0;
    /** Smoothing added to each node's per-type word count. */
    double eta = 0.1;
    /** {@code eta * numTypes}; computed in {@code initialize} and used as the word-probability normalizer. */
    double etaSum;
    /** {@code levels[doc][token]} is the tree level currently assigned to that token. */
    int[][] levels;
    /** {@code documentLeaves[doc]} is the leaf node terminating that document's path. */
    NCRPNode[] documentLeaves;
    /** Count of nodes created so far; supplies the next {@code nodeID}. */
    int totalNodes = 0;
    /** File name written by the no-argument {@code printState()}. */
    String stateFile = "hlda.state";
    /** Random source supplied to {@code initialize}. */
    Randoms random;
    /** When true, {@code estimate} prints a progress dot per iteration. */
    boolean showProgress = true;
    /** {@code estimate} prints the tree whenever the iteration number is a multiple of this value. */
    int displayTopicsInterval = 50;
    /** Number of top words shown per node by {@code printNode}. */
    int numWordsToDisplay = 10;

    /**
     * Sets {@code alpha}, the smoothing added to per-document level counts when
     * token levels are resampled, and the Dirichlet parameter used by
     * {@code empiricalLikelihood}.
     *
     * @param alpha the new value
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Sets {@code gamma}, the weight given to opening a new child when a path
     * chooses among a node's children.
     *
     * @param gamma the new value
     */
    public void setGamma(double gamma) {
        this.gamma = gamma;
    }

    /**
     * Sets {@code eta}, the smoothing added to every node's per-type word count.
     *
     * <p>The cached normalizer {@code etaSum} ({@code eta * numTypes}) is refreshed
     * as well, so changing {@code eta} after {@code initialize} does not leave the
     * sampler using a stale denominator. Before initialization {@code numTypes} is
     * still 0, so {@code etaSum} stays 0 until {@code initialize} recomputes it.
     *
     * @param eta the new value
     */
    public void setEta(double eta) {
        this.eta = eta;
        this.etaSum = eta * (double)this.numTypes;
    }

    /**
     * Sets the file name that the no-argument {@code printState()} writes to.
     *
     * @param stateFile path of the state file
     */
    public void setStateFile(String stateFile) {
        this.stateFile = stateFile;
    }

    /**
     * Configures the periodic tree printout produced during {@code estimate}.
     *
     * @param interval the tree is printed whenever the iteration number is a multiple of this value
     * @param words    number of top words shown for each node
     */
    public void setTopicDisplay(int interval, int words) {
        this.displayTopicsInterval = interval;
        this.numWordsToDisplay = words;
    }

    /**
     * Enables or disables the per-iteration progress output of {@code estimate}.
     *
     * @param showProgress {@code true} to print progress to standard output
     */
    public void setProgressDisplay(boolean showProgress) {
        this.showProgress = showProgress;
    }

    /**
     * Prepares the sampler for a corpus: records the inputs, creates the root
     * node, draws an initial path through the tree for every document, and assigns
     * every token to a uniformly random level on its document's path.
     *
     * <p>All arguments are validated before any field is modified, so a rejected
     * call leaves the sampler untouched.
     *
     * @param instances training documents; each instance's data must be a {@code FeatureSequence}
     * @param testing   held-out documents; only stored
     * @param numLevels number of levels on each root-to-leaf path (at least 1)
     * @param random    random source used here and by all later sampling
     * @throws IllegalArgumentException if {@code numLevels < 1}, if {@code instances}
     *         is empty, or if the first instance does not hold a {@code FeatureSequence}
     */
    public void initialize(InstanceList instances, InstanceList testing, int numLevels, Randoms random) {
        if (numLevels < 1) {
            throw new IllegalArgumentException("numLevels must be at least 1, but was " + numLevels);
        }
        if (instances.size() == 0) {
            throw new IllegalArgumentException("Cannot initialize HierarchicalLDA with an empty instance list");
        }
        if (!(((Instance)instances.get(0)).getData() instanceof FeatureSequence)) {
            throw new IllegalArgumentException("Input must be a FeatureSequence, using the --feature-sequence option when importing data, for example");
        }

        this.instances = instances;
        this.testing = testing;
        this.numLevels = numLevels;
        this.random = random;
        this.numDocuments = instances.size();
        this.numTypes = instances.getDataAlphabet().size();
        this.etaSum = this.eta * (double)this.numTypes;

        NCRPNode[] path = new NCRPNode[numLevels];
        this.rootNode = new NCRPNode(this.numTypes);
        this.levels = new int[this.numDocuments][];
        this.documentLeaves = new NCRPNode[this.numDocuments];

        for (int doc = 0; doc < this.numDocuments; ++doc) {
            FeatureSequence fs = (FeatureSequence)((Instance)instances.get(doc)).getData();
            int seqLen = fs.getLength();

            // Draw the document's path top-down, counting it as a customer at each node.
            path[0] = this.rootNode;
            ++this.rootNode.customers;
            for (int level = 1; level < numLevels; ++level) {
                path[level] = path[level - 1].select();
                ++path[level].customers;
            }

            // The package-visible scratch field "node" is still written, as before.
            this.node = path[numLevels - 1];
            this.levels[doc] = new int[seqLen];
            this.documentLeaves[doc] = this.node;

            // Place every token at a uniformly random level of the path.
            for (int token = 0; token < seqLen; ++token) {
                int type = fs.getIndexAtPosition(token);
                this.levels[doc][token] = random.nextInt(numLevels);
                this.node = path[this.levels[doc][token]];
                ++this.node.totalTokens;
                ++this.node.typeCounts[type];
            }
        }
    }

    /**
     * Runs the sampler for the given number of iterations. Each iteration first
     * resamples the tree path of every document and then resamples the level of
     * every token.
     *
     * <p>When progress display is on, a dot is printed per iteration and the
     * iteration number every 50 iterations. The tree is printed whenever the
     * iteration number is a multiple of {@code displayTopicsInterval}; an interval
     * of 0 disables that printout instead of failing with a division by zero.
     *
     * @param numIterations number of sweeps over the corpus
     */
    public void estimate(int numIterations) {
        for (int iteration = 1; iteration <= numIterations; ++iteration) {
            for (int doc = 0; doc < this.numDocuments; ++doc) {
                this.samplePath(doc, iteration);
            }
            for (int doc = 0; doc < this.numDocuments; ++doc) {
                this.sampleTopics(doc);
            }

            if (this.showProgress) {
                System.out.print(".");
                if (iteration % 50 == 0) {
                    System.out.println(" " + iteration);
                }
            }

            // Guard the modulus: setTopicDisplay accepts 0, which would otherwise throw ArithmeticException.
            if (this.displayTopicsInterval != 0 && iteration % this.displayTopicsInterval == 0) {
                this.printNodes();
            }
        }
    }

    /**
     * Resamples the tree path of one document.
     *
     * <p>The document is first removed from its current path (customer counts and
     * word counts), then every node in the tree is scored with a log weight that
     * combines the tree prior from {@code calculateNCRP} and the word term from
     * {@code calculateWordLikelihood}. A node is drawn in proportion to the
     * exponentiated weights; if it is not a leaf, new nodes are created below it
     * down to leaf depth. Finally the document's counts are added to the new path.
     *
     * @param doc       index of the document in {@code instances}
     * @param iteration current iteration number; only forwarded to {@code calculateWordLikelihood}
     */
    public void samplePath(int doc, int iteration) {
        int i;
        int level;
        // Recover the current path by walking parent links up from the leaf.
        NCRPNode[] path = new NCRPNode[this.numLevels];
        NCRPNode node = this.documentLeaves[doc];
        for (level = this.numLevels - 1; level >= 0; --level) {
            path[level] = node;
            node = node.parent;
        }
        // Remove this document as a customer along its current path.
        this.documentLeaves[doc].dropPath();
        // Log prior weight for every node currently in the tree.
        ObjectDoubleHashMap<NCRPNode> nodeWeights = new ObjectDoubleHashMap<NCRPNode>();
        this.calculateNCRP(nodeWeights, this.rootNode, 0.0);
        // typeCounts[level] maps word type -> number of this document's tokens assigned to that level.
        IntIntHashMap[] typeCounts = new IntIntHashMap[this.numLevels];
        for (level = 0; level < this.numLevels; ++level) {
            typeCounts[level] = new IntIntHashMap();
        }
        int[] docLevels = this.levels[doc];
        FeatureSequence fs = (FeatureSequence)((Instance)this.instances.get(doc)).getData();
        // Tally the document's tokens per level while subtracting them from the old path's nodes.
        for (int token = 0; token < docLevels.length; ++token) {
            level = docLevels[token];
            int type = fs.getIndexAtPosition(token);
            if (!typeCounts[level].containsKey(type)) {
                typeCounts[level].put(type, 1);
            } else {
                typeCounts[level].addTo(type, 1);
            }
            int n = type;
            path[level].typeCounts[n] = path[level].typeCounts[n] - 1;
            assert (path[level].typeCounts[type] >= 0);
            --path[level].totalTokens;
            assert (path[level].totalTokens >= 0);
        }
        // newTopicWeights[level]: log weight of this document's level-`level` tokens under a node
        // with no existing counts. Level 0 is skipped because the root always exists.
        double[] newTopicWeights = new double[this.numLevels];
        for (level = 1; level < this.numLevels; ++level) {
            int totalTokens = 0;
            for (IntIntCursor keyVal : typeCounts[level]) {
                for (int i2 = 0; i2 < keyVal.value; ++i2) {
                    int n = level;
                    newTopicWeights[n] = newTopicWeights[n] + Math.log((this.eta + (double)i2) / (this.etaSum + (double)totalTokens));
                    ++totalTokens;
                }
            }
        }
        // Add the word term to every node's weight.
        this.calculateWordLikelihood(nodeWeights, this.rootNode, 0.0, typeCounts, newTopicWeights, 0, iteration);
        Object[] objectArray = nodeWeights.keys().toArray();
        NCRPNode[] nodes = (NCRPNode[])Arrays.copyOf(objectArray, objectArray.length, NCRPNode[].class);
        double[] weights = new double[nodes.length];
        double sum = 0.0;
        // Subtract the maximum log weight before exponentiating so the largest term becomes exp(0).
        double max = Double.NEGATIVE_INFINITY;
        for (i = 0; i < nodes.length; ++i) {
            if (!(nodeWeights.get(nodes[i]) > max)) continue;
            max = nodeWeights.get(nodes[i]);
        }
        for (i = 0; i < nodes.length; ++i) {
            weights[i] = Math.exp(nodeWeights.get(nodes[i]) - max);
            sum += weights[i];
        }
        // Draw a node; an internal node means "open a new branch below it".
        node = nodes[this.random.nextDiscrete(weights, sum)];
        if (!node.isLeaf()) {
            node = node.getNewLeaf();
        }
        // Register the document on the new path and restore its word counts, leaf to root.
        node.addPath();
        this.documentLeaves[doc] = node;
        for (level = this.numLevels - 1; level >= 0; --level) {
            for (IntIntCursor keyVal : typeCounts[level]) {
                int n = keyVal.key;
                node.typeCounts[n] = node.typeCounts[n] + keyVal.value;
                node.totalTokens += keyVal.value;
            }
            node = node.parent;
        }
    }

    /**
     * Recursively records a log prior weight for {@code node} and every node
     * beneath it.
     *
     * <p>Each child is visited with {@code weight} plus the log of its share of
     * customers, {@code child.customers / (node.customers + gamma)}. The entry
     * stored for {@code node} itself is {@code weight} plus
     * {@code log(gamma / (node.customers + gamma))}. Descendants are inserted into
     * the map before the node itself.
     *
     * @param nodeWeights map receiving one entry per visited node
     * @param node        subtree root to score
     * @param weight      log weight accumulated on the way down to {@code node}
     */
    public void calculateNCRP(ObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight) {
        final double normalizer = (double)node.customers + this.gamma;

        for (NCRPNode kid : node.children) {
            double stepLogProb = Math.log((double)kid.customers / normalizer);
            this.calculateNCRP(nodeWeights, kid, weight + stepLogProb);
        }

        double newBranchLogProb = Math.log(this.gamma / normalizer);
        nodeWeights.put(node, weight + newBranchLogProb);
    }

    /**
     * Recursively adds a word-likelihood term to the entry of {@code node} and of
     * every node beneath it.
     *
     * <p>The node's own term is the log weight of the document's tokens at
     * {@code level}, computed against the node's current counts with {@code eta}
     * smoothing. For levels below {@code level}, the precomputed
     * {@code newTopicWeights} are added on top before the total is added to the
     * node's map entry.
     *
     * <p>NOTE(review): {@code weight} (the ancestors' accumulated term) is forwarded
     * to children but never folded into the value added to {@code nodeWeights};
     * confirm this is intended. {@code iteration} is only passed through.
     *
     * @param nodeWeights     map whose entries are incremented
     * @param node            subtree root to score
     * @param weight          log term accumulated from the ancestors of {@code node}
     * @param typeCounts      per-level map of word type to token count for the document
     * @param newTopicWeights per-level log weight of the document's tokens under an empty node
     * @param level           tree level of {@code node}
     * @param iteration       current iteration number
     */
    public void calculateWordLikelihood(ObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight, IntIntHashMap[] typeCounts, double[] newTopicWeights, int level, int iteration) {
        double logLikelihood = 0.0;
        int placed = 0;
        for (IntIntCursor entry : typeCounts[level]) {
            for (int k = 0; k < entry.value; ++k) {
                double numerator = this.eta + (double)node.typeCounts[entry.key] + (double)k;
                double denominator = this.etaSum + (double)node.totalTokens + (double)placed;
                logLikelihood += Math.log(numerator / denominator);
                ++placed;
            }
        }

        int childLevel = level + 1;
        for (NCRPNode kid : node.children) {
            this.calculateWordLikelihood(nodeWeights, kid, weight + logLikelihood, typeCounts, newTopicWeights, childLevel, iteration);
        }

        // Account for the levels that would be filled by brand-new nodes below this one.
        for (int below = childLevel; below < this.numLevels; ++below) {
            logLikelihood += newTopicWeights[below];
        }
        nodeWeights.addTo(node, logLikelihood);
    }

    /**
     * Adds {@code weight} to the entry of {@code node} and of every descendant,
     * descending only through nodes that already have an entry in
     * {@code nodeWeights}. A node without an entry stops the recursion for its
     * whole subtree.
     *
     * @param nodeWeights map whose entries are incremented
     * @param node        subtree root
     * @param weight      amount added to each visited entry
     */
    public void propagateTopicWeight(ObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight) {
        if (nodeWeights.containsKey(node)) {
            for (NCRPNode kid : node.children) {
                this.propagateTopicWeight(nodeWeights, kid, weight);
            }
            nodeWeights.addTo(node, weight);
        }
    }

    /**
     * Resamples the level assignment of every token in one document, keeping the
     * document's path fixed.
     *
     * <p>For each token, its current assignment is removed from the per-level
     * counts and from the node's word counts. A level is then drawn with weight
     * {@code (alpha + tokensAtLevel) * (eta + nodeTypeCount) / (etaSum + nodeTotalTokens)}
     * and the counts are restored at the chosen level.
     *
     * @param doc index of the document in {@code instances}
     */
    public void sampleTopics(int doc) {
        FeatureSequence words = (FeatureSequence)((Instance)this.instances.get(doc)).getData();
        int length = words.getLength();
        int[] assigned = this.levels[doc];

        // Rebuild the root-to-leaf path from the stored leaf.
        NCRPNode[] path = new NCRPNode[this.numLevels];
        int[] tokensAtLevel = new int[this.numLevels];
        NCRPNode cursor = this.documentLeaves[doc];
        int depth = this.numLevels;
        while (--depth >= 0) {
            path[depth] = cursor;
            cursor = cursor.parent;
        }

        double[] weights = new double[this.numLevels];
        for (int position = 0; position < length; ++position) {
            ++tokensAtLevel[assigned[position]];
        }

        for (int position = 0; position < length; ++position) {
            int type = words.getIndexAtPosition(position);

            // Take the token out of its current level.
            int oldLevel = assigned[position];
            --tokensAtLevel[oldLevel];
            NCRPNode oldNode = path[oldLevel];
            --oldNode.typeCounts[type];
            --oldNode.totalTokens;

            // Score every level of the path for this token.
            double total = 0.0;
            for (int candidate = 0; candidate < this.numLevels; ++candidate) {
                NCRPNode topic = path[candidate];
                double docTerm = this.alpha + (double)tokensAtLevel[candidate];
                double wordTerm = this.eta + (double)topic.typeCounts[type];
                weights[candidate] = docTerm * wordTerm / (this.etaSum + (double)topic.totalTokens);
                total += weights[candidate];
            }

            // Draw the new level and put the token back.
            int newLevel = this.random.nextDiscrete(weights, total);
            assigned[position] = newLevel;
            ++tokensAtLevel[newLevel];
            NCRPNode newNode = path[newLevel];
            ++newNode.typeCounts[type];
            ++newNode.totalTokens;
        }
    }

    /**
     * Writes the sampler state to the file named by {@code stateFile}.
     *
     * <p>The writer is always closed, so buffered output reaches the file and the
     * file handle is released even if writing fails. Because {@code PrintWriter}
     * swallows I/O errors, its error flag is checked afterwards and reported as an
     * {@code IOException}.
     *
     * @throws IOException if the file cannot be opened or a write error occurred
     * @throws FileNotFoundException if the state file cannot be created
     */
    public void printState() throws IOException, FileNotFoundException {
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(this.stateFile)));
        try {
            this.printState(out);
            // checkError() flushes the stream and reports any error PrintWriter suppressed.
            if (out.checkError()) {
                throw new IOException("Error writing state file " + this.stateFile);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Writes one line per token of every training document to {@code out}.
     *
     * <p>Each line consists of the node IDs on the document's path, listed from the
     * leaf up to the root and each followed by a space, then the token's type
     * index, the alphabet entry for that type, and the token's assigned level,
     * separated by spaces with a trailing space. The writer is neither flushed nor
     * closed here.
     *
     * @param out destination writer
     * @throws IOException declared for callers; not thrown by this implementation
     */
    public void printState(PrintWriter out) throws IOException {
        Alphabet vocabulary = this.instances.getDataAlphabet();
        int doc = 0;
        for (Instance instance : this.instances) {
            FeatureSequence tokens = (FeatureSequence)instance.getData();
            int length = tokens.getLength();
            int[] assigned = this.levels[doc];

            // Node IDs along the path, leaf first.
            StringBuilder ids = new StringBuilder();
            NCRPNode cursor = this.documentLeaves[doc];
            for (int remaining = this.numLevels; remaining > 0; --remaining) {
                ids.append(cursor.nodeID).append(' ');
                cursor = cursor.parent;
            }
            String prefix = ids.toString();

            for (int position = 0; position < length; ++position) {
                int type = tokens.getIndexAtPosition(position);
                out.println(prefix + type + " " + vocabulary.lookupObject(type) + " " + assigned[position] + " ");
            }
            ++doc;
        }
    }

    /**
     * Prints the whole topic tree to standard output, starting at the root with
     * no indentation and without word weights.
     */
    public void printNodes() {
        this.printNode(this.rootNode, 0, false);
    }

    /**
     * Prints the whole topic tree to standard output, starting at the root with
     * no indentation.
     *
     * @param withWeight whether each top word is followed by its weight
     */
    public void printNodes(boolean withWeight) {
        this.printNode(this.rootNode, 0, withWeight);
    }

    /**
     * Prints {@code node} and then, recursively, each of its children to standard
     * output. A line is indented by two spaces per {@code indent} step and shows
     * {@code totalTokens/customers} followed by the node's top words.
     *
     * @param node       node to print
     * @param indent     nesting depth used for indentation
     * @param withWeight whether each top word is followed by its weight
     */
    public void printNode(NCRPNode node, int indent, boolean withWeight) {
        StringBuilder line = new StringBuilder();
        int remaining = indent;
        while (remaining-- > 0) {
            line.append("  ");
        }
        line.append(node.totalTokens).append("/").append(node.customers).append(" ");
        line.append(node.getTopWords(this.numWordsToDisplay, withWeight));
        System.out.println(line);

        int childIndent = indent + 1;
        for (NCRPNode kid : node.children) {
            this.printNode(kid, childIndent, withWeight);
        }
    }

    /**
     * Estimates the log likelihood of held-out documents by sampling.
     *
     * <p>For each sample, a path of existing nodes is drawn from the root, level
     * proportions are drawn from a Dirichlet, and the smoothed word distributions
     * of the path's nodes are mixed with those proportions into one multinomial.
     * Every document is scored under that multinomial. For each document the
     * per-sample likelihoods are averaged (in log space, shifted by the maximum),
     * and the resulting log values are summed over all documents. The sum is not
     * divided by the number of documents.
     *
     * @param numSamples number of sampled multinomials
     * @param testing    documents to score; each must hold a {@code FeatureSequence}
     * @return sum over documents of the log of the mean sampled likelihood
     */
    public double empiricalLikelihood(int numSamples, InstanceList testing) {
        NCRPNode[] path = new NCRPNode[this.numLevels];
        path[0] = this.rootNode;
        Dirichlet dirichlet = new Dirichlet(this.numLevels, this.alpha);
        double[] logWordProb = new double[this.numTypes];
        int numTestDocs = testing.size();
        double[][] likelihoods = new double[numTestDocs][numSamples];

        for (int sample = 0; sample < numSamples; ++sample) {
            // Draw a path through nodes that already exist.
            for (int level = 1; level < this.numLevels; ++level) {
                path[level] = path[level - 1].selectExisting();
            }
            double[] levelWeights = dirichlet.nextDistribution();

            // Mix the path's word distributions and store the log of each word's probability.
            for (int type = 0; type < this.numTypes; ++type) {
                double mixed = 0.0;
                for (int level = 0; level < this.numLevels; ++level) {
                    NCRPNode topic = path[level];
                    mixed += levelWeights[level] * (this.eta + (double)topic.typeCounts[type]) / (this.etaSum + (double)topic.totalTokens);
                }
                logWordProb[type] = Math.log(mixed);
            }

            // Score every held-out document under this sample.
            for (int doc = 0; doc < testing.size(); ++doc) {
                FeatureSequence words = (FeatureSequence)((Instance)testing.get(doc)).getData();
                int length = words.getLength();
                double docLogLikelihood = likelihoods[doc][sample];
                for (int position = 0; position < length; ++position) {
                    docLogLikelihood += logWordProb[words.getIndexAtPosition(position)];
                }
                likelihoods[doc][sample] = docLogLikelihood;
            }
        }

        double averageLogLikelihood = 0.0;
        double logNumSamples = Math.log(numSamples);
        for (int doc = 0; doc < testing.size(); ++doc) {
            double[] perSample = likelihoods[doc];

            // Shift by the maximum before exponentiating.
            double max = Double.NEGATIVE_INFINITY;
            for (int sample = 0; sample < numSamples; ++sample) {
                if (perSample[sample] > max) {
                    max = perSample[sample];
                }
            }
            double sum = 0.0;
            for (int sample = 0; sample < numSamples; ++sample) {
                sum += Math.exp(perSample[sample] - max);
            }
            averageLogLikelihood += Math.log(sum) + max - logNumSamples;
        }
        return averageLogLikelihood;
    }

    /**
     * Serializes this sampler to the given file using Java object serialization.
     *
     * <p>Failures are reported on standard error rather than thrown. The underlying
     * file stream is closed on every path, including when serialization fails
     * part-way through.
     *
     * @param serializedModelFile destination file
     */
    public void write(File serializedModelFile) {
        try {
            FileOutputStream fileOut = new FileOutputStream(serializedModelFile);
            try {
                ObjectOutputStream oos = new ObjectOutputStream(fileOut);
                oos.writeObject(this);
                // Closing the object stream flushes it and closes fileOut on the success path.
                oos.close();
            } finally {
                // Releases the handle when the constructor or writeObject threw; repeated close is harmless.
                fileOut.close();
            }
        }
        catch (IOException e) {
            System.err.println("Problem serializing HierarchicalLDA to file " + serializedModelFile + ": " + e);
        }
    }

    /**
     * Reads a sampler previously stored with {@code write}.
     *
     * <p>The file stream is closed on every path, including when the stream header
     * or object data cannot be read.
     *
     * <p>NOTE(security): this uses Java native deserialization; only read files
     * from trusted sources.
     *
     * @param f file containing a serialized {@code HierarchicalLDA}
     * @return the deserialized sampler
     * @throws Exception if the file cannot be opened, read, or deserialized
     */
    public static HierarchicalLDA read(File f) throws Exception {
        FileInputStream fileIn = new FileInputStream(f);
        try {
            ObjectInputStream ois = new ObjectInputStream(fileIn);
            HierarchicalLDA topicModel = (HierarchicalLDA)ois.readObject();
            ois.close();
            return topicModel;
        } finally {
            // Releases the handle when the constructor or readObject threw; repeated close is harmless.
            fileIn.close();
        }
    }

    /**
     * Command-line entry point: loads a training and a testing instance list,
     * initializes a sampler with 5 levels and a default random source, and runs
     * 250 iterations.
     *
     * @param args {@code args[0]} is the training instance list file and
     *             {@code args[1]} the testing instance list file
     */
    public static void main(String[] args) {
        // Report missing arguments as a usage error instead of an ArrayIndexOutOfBoundsException trace.
        if (args.length < 2) {
            System.err.println("Usage: HierarchicalLDA <training instance list> <testing instance list>");
            return;
        }
        try {
            InstanceList instances = InstanceList.load(new File(args[0]));
            InstanceList testing = InstanceList.load(new File(args[1]));
            HierarchicalLDA sampler = new HierarchicalLDA();
            sampler.initialize(instances, testing, 5, new Randoms());
            sampler.estimate(250);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * A node of the topic tree. Each node keeps the number of documents whose
     * path passes through it ({@code customers}), its children, and the word
     * counts of the tokens currently assigned to it.
     *
     * <p>This is an inner class: it reads {@code numLevels}, {@code numTypes},
     * {@code gamma}, {@code random}, {@code instances} and {@code totalNodes} from
     * the enclosing sampler.
     */
    class NCRPNode
    implements Serializable {
        /** Number of documents whose path includes this node. */
        int customers = 0;
        /** Child nodes, in creation order. */
        ArrayList<NCRPNode> children;
        /** Parent node, or {@code null} for the root. */
        NCRPNode parent;
        /** Depth of this node; the root is level 0. */
        int level;
        /** Total number of tokens assigned to this node. */
        int totalTokens;
        /** {@code typeCounts[type]} is the number of tokens of that type assigned to this node. */
        int[] typeCounts;
        /** Identifier taken from the enclosing sampler's {@code totalNodes} counter at creation. */
        public int nodeID;

        /**
         * Creates a node with no customers, no children and zeroed word counts.
         *
         * @param parent     parent node, or {@code null} for the root
         * @param dimensions length of the per-type count array
         * @param level      depth of the new node
         */
        public NCRPNode(NCRPNode parent, int dimensions, int level) {
            this.parent = parent;
            this.children = new ArrayList<NCRPNode>();
            this.level = level;
            this.totalTokens = 0;
            this.typeCounts = new int[dimensions];
            this.nodeID = HierarchicalLDA.this.totalNodes++;
        }

        /**
         * Creates a root node (no parent, level 0).
         *
         * @param dimensions length of the per-type count array
         */
        public NCRPNode(int dimensions) {
            this(null, dimensions, 0);
        }

        /**
         * Creates a new child one level below this node and appends it to
         * {@code children}.
         *
         * @return the new child
         */
        public NCRPNode addChild() {
            NCRPNode node = new NCRPNode(this, this.typeCounts.length, this.level + 1);
            this.children.add(node);
            return node;
        }

        /** @return {@code true} if this node sits at the deepest level ({@code numLevels - 1}) */
        public boolean isLeaf() {
            return this.level == HierarchicalLDA.this.numLevels - 1;
        }

        /**
         * Extends the tree below this node with a chain of new nodes down to leaf
         * depth.
         *
         * @return the new leaf, or this node if it is already at leaf depth
         */
        public NCRPNode getNewLeaf() {
            NCRPNode node = this;
            for (int l = this.level; l < HierarchicalLDA.this.numLevels - 1; ++l) {
                node = node.addChild();
            }
            return node;
        }

        /**
         * Removes one customer from this node and from each ancestor up to the
         * root, detaching any node whose customer count reaches zero from its
         * parent. Intended to be called on a leaf.
         *
         * <p>The root has no parent to detach from, so it is kept even when its
         * count reaches zero. Previously this case dereferenced a {@code null}
         * parent, which happened whenever the only document on the tree was
         * dropped.
         */
        public void dropPath() {
            NCRPNode node = this;
            for (int l = 0; l < HierarchicalLDA.this.numLevels; ++l) {
                --node.customers;
                if (node.customers == 0 && node.parent != null) {
                    node.parent.remove(node);
                }
                node = node.parent;
            }
        }

        /**
         * Removes {@code node} from this node's children.
         *
         * @param node child to detach
         */
        public void remove(NCRPNode node) {
            this.children.remove(node);
        }

        /**
         * Adds one customer to this node and to each of its {@code numLevels - 1}
         * ancestors. Intended to be called on a leaf.
         */
        public void addPath() {
            NCRPNode node = this;
            ++node.customers;
            for (int l = 1; l < HierarchicalLDA.this.numLevels; ++l) {
                node = node.parent;
                ++node.customers;
            }
        }

        /**
         * Draws one of the existing children, weighting each by
         * {@code child.customers / (gamma + customers)}.
         *
         * <p>NOTE(review): the weights are not renormalised after leaving out the
         * new-child term; confirm that {@code nextDiscrete} tolerates weights that
         * do not sum to one.
         *
         * @return the chosen child
         */
        public NCRPNode selectExisting() {
            double[] weights = new double[this.children.size()];
            int i = 0;
            for (NCRPNode child : this.children) {
                weights[i] = (double)child.customers / (HierarchicalLDA.this.gamma + (double)this.customers);
                ++i;
            }
            int choice = HierarchicalLDA.this.random.nextDiscrete(weights);
            return this.children.get(choice);
        }

        /**
         * Draws a child for a path passing through this node. Index 0 of the
         * weight vector stands for a new child, weighted
         * {@code gamma / (gamma + customers)}; each existing child is weighted
         * {@code child.customers / (gamma + customers)}.
         *
         * @return the chosen existing child, or a newly created one
         */
        public NCRPNode select() {
            double[] weights = new double[this.children.size() + 1];
            weights[0] = HierarchicalLDA.this.gamma / (HierarchicalLDA.this.gamma + (double)this.customers);
            int i = 1;
            for (NCRPNode child : this.children) {
                weights[i] = (double)child.customers / (HierarchicalLDA.this.gamma + (double)this.customers);
                ++i;
            }
            int choice = HierarchicalLDA.this.random.nextDiscrete(weights);
            if (choice == 0) {
                return this.addChild();
            }
            return this.children.get(choice - 1);
        }

        /**
         * Formats the leading word types of this node after sorting all types by
         * their count at this node.
         *
         * <p>At most {@code numTypes} words are returned; asking for more than the
         * vocabulary holds no longer runs past the end of the sorted array.
         *
         * @param numWords   maximum number of words to include
         * @param withWeight if {@code true}, each word is followed by {@code :} and its weight
         * @return the words, each followed by a single space
         */
        public String getTopWords(int numWords, boolean withWeight) {
            IDSorter[] sortedTypes = new IDSorter[HierarchicalLDA.this.numTypes];
            for (int type = 0; type < HierarchicalLDA.this.numTypes; ++type) {
                sortedTypes[type] = new IDSorter(type, this.typeCounts[type]);
            }
            Arrays.sort(sortedTypes);
            Alphabet alphabet = HierarchicalLDA.this.instances.getDataAlphabet();
            StringBuilder out = new StringBuilder();
            // Clamp to the vocabulary size so a small alphabet cannot cause an out-of-bounds read.
            int limit = Math.min(numWords, sortedTypes.length);
            for (int i = 0; i < limit; ++i) {
                if (withWeight) {
                    out.append(alphabet.lookupObject(sortedTypes[i].getID()) + ":" + sortedTypes[i].getWeight() + " ");
                    continue;
                }
                out.append(alphabet.lookupObject(sortedTypes[i].getID()) + " ");
            }
            return out.toString();
        }
    }
}

