// Copyright 2010 by Mary McGlohon
// Carnegie Mellon University
// mmcgloho@cs.cmu.edu

import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * A Markov model over cascade shapes.
 *
 * <p>Level {@code k} (0 through {@code maxSize}) holds the distinct, pairwise
 * non-isomorphic cascades of size {@code k} that are generated from the empty
 * cascade by attaching one node at a time. Each node links to the cascades one
 * node larger that it can grow into, and counts how often each such growth
 * step was observed via {@link #addCascadeGrowth}.
 *
 * <p>{@link #print()} and {@link #writeGraphViz} read transition
 * probabilities, so {@link #calcTransitionProbabilities()} must be called
 * before them.
 */
public class CascadeMarkovModel {

    /** Cascade size -> distinct cascades of that size (keys 0..maxSize). */
    HashMap<Integer, ArrayList<CascadeMarkovModelNode>> markovModel;
    /** Passed to every {@link Cascade} built for the empty/size-1 cascades. */
    boolean useAuthors;

    /**
     * Generates all cascades reachable from the empty cascade by repeatedly
     * attaching one node, up to the given size.
     *
     * @param maxsize largest cascade size to enumerate
     * @param useAuthors whether generated cascades carry author labels
     */
    public CascadeMarkovModel(int maxsize, boolean useAuthors) {
        this.useAuthors = useAuthors;
        buildMarkovModel(maxsize);
    }

    /**
     * Prints every level from size 1 to maxSize, node by node, to stdout.
     * Requires {@link #calcTransitionProbabilities()} to have run first.
     */
    public void print() {
        for (int level = 1; level < markovModel.size(); ++level) {
            System.out.println("\n\n---- Level " + level + "--------");
            for (CascadeMarkovModelNode node : markovModel.get(level)) {
                node.print();
            }
        }
    }

    /**
     * Creates levels 0..maxSize, seeds level 0 with the empty cascade, and
     * expands each level into the next.
     */
    private void buildMarkovModel(int maxSize) {
        markovModel = new HashMap<Integer, ArrayList<CascadeMarkovModelNode>>();
        for (int size = 1; size <= maxSize; ++size) {
            markovModel.put(size, new ArrayList<CascadeMarkovModelNode>());
        }
        markovModel.put(0, new ArrayList<CascadeMarkovModelNode>());
        CascadeMarkovModelNode nullCascade =
            new CascadeMarkovModelNode(new Cascade("", useAuthors));
        markovModel.get(0).add(nullCascade);
        if (maxSize < 1) {
            // There is no level 1 to grow into; generating children here
            // would dereference a missing level.
            return;
        }
        nullCascade.generateChildren();
        // Nodes at maxSize are deliberately left without children.
        for (int size = 1; size < maxSize; ++size) {
            // generateChildren() only appends to level size+1, so iterating
            // this level while expanding it is safe.
            for (CascadeMarkovModelNode node : markovModel.get(size)) {
                node.generateChildren();
            }
        }
    }

    /**
     * Returns the model node isomorphic to {@code c}, or {@code null} if the
     * model has no level of that size or no matching cascade on it.
     */
    private CascadeMarkovModelNode findCascade(Cascade c) {
        ArrayList<CascadeMarkovModelNode> level = markovModel.get(c.size());
        if (level == null) {
            return null;
        }
        for (CascadeMarkovModelNode node : level) {
            if (Cascade.isomorphic(c, node.cascade)) {
                return node;
            }
        }
        return null;
    }

    /**
     * Records one observed growth step from {@code oldCascade} to
     * {@code newCascade}. Observations whose new cascade is larger than
     * maxSize are ignored, because the model does not track them.
     *
     * @throws IllegalArgumentException if {@code oldCascade} is not in the model
     */
    public void addCascadeGrowth(Cascade oldCascade, Cascade newCascade) {
        // markovModel holds levels 0..maxSize, so size() == maxSize + 1.
        if (newCascade.size() >= markovModel.size()) {
            return;
        }
        CascadeMarkovModelNode oldnode = findCascade(oldCascade);
        if (oldnode == null) {
            throw new IllegalArgumentException(
                "cascade not in model: " + oldCascade.edgeString());
        }
        oldnode.incrementChild(newCascade);
    }

    /**
     * (Re)computes transition probabilities for every node on levels
     * 1..maxSize from the counts gathered so far.
     */
    public void calcTransitionProbabilities() {
        for (int level = 1; level < markovModel.size(); ++level) {
            for (CascadeMarkovModelNode node : markovModel.get(level)) {
                node.calcTransitionProbabilities();
            }
        }
    }

    /**
     * Reorders each level of size >= 2 by node frequency, using
     * {@code CascadeUtils.sortMapByValues}. Level 1 holds the single size-1
     * cascade and is left as-is.
     */
    public void sortLevelsByFrequency() {
        for (int size = 2; size < markovModel.size(); ++size) {
            HashMap<CascadeMarkovModelNode, Integer> levelmap =
                new HashMap<CascadeMarkovModelNode, Integer>();
            for (CascadeMarkovModelNode node : markovModel.get(size)) {
                levelmap.put(node, node.getFrequency());
            }
            ArrayList<Map.Entry<CascadeMarkovModelNode, Integer>> sortedLevel =
                CascadeUtils.sortMapByValues(levelmap);
            ArrayList<CascadeMarkovModelNode> newLevel =
                new ArrayList<CascadeMarkovModelNode>(sortedLevel.size());
            for (Map.Entry<CascadeMarkovModelNode, Integer> entry : sortedLevel) {
                newLevel.add(entry.getKey());
            }
            markovModel.put(size, newLevel);
        }
    }

    // TODO(mmcgloho): Move this stuff to CascadeViz class
    /**
     * Writes the model as a Graphviz digraph to {@code outfile}. Each cascade
     * is drawn as a cluster labelled with its count and stay probability.
     * Transitions are drawn as edges between clusters, labelled with their
     * probability.
     *
     * <p>This is best-effort: I/O errors are printed to stderr rather than
     * thrown. The writer is always closed.
     */
    public void writeGraphViz(String outfile) {
        FileWriter fw = null;
        try {
            fw = new FileWriter(outfile);
            fw.write("digraph G { \n");
            fw.write("compound=true;\n");
            fw.write("edge [dir = \"back\"];\n\n");
            writeSizeAnchors(fw);
            writeLevels(fw);
            writeTransitionEdges(fw);
            fw.write("}\n");
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (fw != null) {
                try {
                    fw.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Writes the invisible "sizeN" anchor nodes. Edges to these anchors are
     * meant to put each level on its own row.
     */
    private void writeSizeAnchors(FileWriter fw) throws IOException {
        for (int sz = 1; sz <= markovModel.size() + 1; ++sz) {
            // Anchor height grows as sz/2 from size 4 upward.
            double height = (sz >= 4) ? (double) sz / 2 : 0.5;
            fw.write("size" + sz
                     + " [shape=box,color=yellow,style=invis,fontsize=1,"
                     + "width=.01,height=" + height + "];\n");
        }
    }

    /**
     * Writes one cluster per cascade on levels 1..maxSize, plus the invisible
     * edges that tie each cluster to its level's anchor rows.
     */
    private void writeLevels(FileWriter fw) throws IOException {
        int topLevel = markovModel.size() - 1;
        for (int sz = 1; sz <= topLevel; ++sz) {
            ArrayList<CascadeMarkovModelNode> currlevel = markovModel.get(sz);
            // TODO: maybe sort these first
            for (int i = 0; i < currlevel.size(); ++i) {
                CascadeMarkovModelNode currnode = currlevel.get(i);
                // Growth out of the top level is never recorded (see
                // addCascadeGrowth), so a stay probability is meaningless there.
                String stayp = (sz < topLevel)
                    ? truncatedProbability(currnode.getChildProbability(null))
                    : "n/a";
                String label = "count = " + currnode.getFrequency()
                    + "\\n p_{stay} = " + stayp;
                fw.write(CascadeViz.graphVizString(
                    currnode.cascade, getClusterName(sz, i), label));
                fw.write("\n");
            }
            // Link each cluster to this level's anchor above and the next
            // level's anchor below, via the cluster's node 1 and node sz.
            for (int i = 0; i < currlevel.size(); ++i) {
                String cluster = getClusterName(sz, i);
                fw.write("size" + sz + "->" + cluster + "1 [style=invis]; \n");
                fw.write(cluster + sz + "->size" + (sz + 1) + " [style=invis];\n");
            }
        }
    }

    /**
     * Writes one labelled edge per parent-to-child transition, for parents on
     * levels 1..maxSize-1. Top-level nodes have no children.
     */
    private void writeTransitionEdges(FileWriter fw) throws IOException {
        for (int sz = 1; sz < markovModel.size() - 1; ++sz) {
            ArrayList<CascadeMarkovModelNode> currlevel = markovModel.get(sz);
            for (int i = 0; i < currlevel.size(); ++i) {
                CascadeMarkovModelNode currnode = currlevel.get(i);
                String srccluster = getClusterName(sz, i);
                for (CascadeMarkovModelNode child : currnode.children) {
                    String destcluster = getClusterName(child.cascade);
                    String probstr = truncatedProbability(
                        currnode.getChildProbability(child.cascade));
                    // Global edge direction is "back", so this edge is
                    // explicitly forward. It runs from the parent's node sz to
                    // the child's node 1, clipped at both cluster borders.
                    fw.write(srccluster + sz + "->" + destcluster + "1"
                             + "[lhead=cluster_" + destcluster + ","
                             + "ltail=cluster_" + srccluster + ","
                             + "label=\"" + probstr + "\","
                             + "dir=\"forward\"];\n");
                }
            }
        }
    }

    /**
     * Formats {@code p} with {@link Double#toString(double)} and keeps at most
     * the first 4 characters. This truncates rather than rounds, e.g.
     * 0.666... becomes "0.66".
     */
    private static String truncatedProbability(double p) {
        String s = Double.toString(p);
        return s.substring(0, Math.min(4, s.length()));
    }

    /** Graphviz name prefix for the {@code index}-th cascade of a level. */
    private String getClusterName(int size, int index) {
        return "size" + size + "type" + index + "n";
    }

    /**
     * Graphviz name prefix for the level entry isomorphic to {@code c}. The
     * index is -1 if there is none. Levels never hold two isomorphic entries
     * (see addChild), so the first match is the only match.
     */
    private String getClusterName(Cascade c) {
        int size = c.size();
        ArrayList<CascadeMarkovModelNode> level = markovModel.get(size);
        int index = -1;
        for (int i = 0; i < level.size(); ++i) {
            if (Cascade.isomorphic(c, level.get(i).cascade)) {
                index = i;
                break;
            }
        }
        return getClusterName(size, index);
    }



    //*************** MARKOV MODEL NODE ********************//

    /**
     * One state of the model: a cascade shape, links to the shapes it can
     * grow into, and the counts/probabilities of those transitions.
     *
     * <p>This is an inner (non-static) class because it reads the enclosing
     * model's {@code markovModel} levels and {@code useAuthors} flag.
     */
    class CascadeMarkovModelNode {
        Cascade cascade;
        // Each node appears at most once in children (and each parent at
        // most once in a child's parents). Multiple attachment points that
        // yield the same child are counted in transitionPaths instead.
        ArrayList<CascadeMarkovModelNode> children;
        ArrayList<CascadeMarkovModelNode> parents;
        // Parallel to children: observed growth counts.
        ArrayList<Integer> childrenCounts;
        // Parallel to children: number of generated one-node extensions
        // isomorphic to that child.
        ArrayList<Integer> transitionPaths;
        // Parallel to children, plus one trailing entry: the probability of
        // ending at this node.
        ArrayList<Double> transitionProbabilities;

        CascadeMarkovModelNode(Cascade c) {
            this.cascade = c;
            children = new ArrayList<CascadeMarkovModelNode>();
            parents = new ArrayList<CascadeMarkovModelNode>();
            childrenCounts = new ArrayList<Integer>();
            transitionPaths = new ArrayList<Integer>();
            transitionProbabilities = new ArrayList<Double>();
        }

        /**
         * Enumerates every cascade obtained by attaching one new node to this
         * cascade and links each one as a child. With authors enabled, one
         * variant is produced per existing author plus one new author.
         */
        void generateChildren() {
            if (this.cascade.size() == 0) {
                // The empty cascade grows only into the single-node cascade.
                addChild(new Cascade("1:1:0 -1", useAuthors));
                return;
            }

            int nextsize = this.cascade.size() + 1;
            String currCascadeString = cascade.edgeString();
            for (int i = 1; i <= cascade.size(); ++i) {
                // Append node `nextsize` with node i as its parent.
                if (!useAuthors) {
                    String newCascadeString = currCascadeString
                        + nextsize + ":0:0 " + i + " ";
                    addChild(new Cascade(newCascadeString,
                                         this.cascade.useAuthors()));
                } else {
                    for (int j = 1; j <= cascade.authorCount() + 1; ++j) {
                        String newCascadeString = currCascadeString
                            + nextsize + ":" + j + ":0 " + i + " ";
                        addChild(new Cascade(newCascadeString,
                                             this.cascade.useAuthors()));
                    }
                }
            }
        }

        /**
         * Links {@code cascadePrime} as a child. An existing isomorphic node
         * on its level is reused; otherwise a new node is added to that level.
         * A repeated child increments its transition-path count instead of
         * being linked again.
         */
        private void addChild(Cascade cascadePrime) {
            ArrayList<CascadeMarkovModelNode> currentLevel =
                markovModel.get(cascadePrime.size());
            CascadeMarkovModelNode childref = null;
            for (CascadeMarkovModelNode cand : currentLevel) {
                if (Cascade.isomorphic(cascadePrime, cand.cascade)) {
                    childref = cand;
                    break;
                }
            }
            if (childref == null) {
                childref = new CascadeMarkovModelNode(cascadePrime);
                currentLevel.add(childref);
            }
            int index = this.children.indexOf(childref);
            if (index < 0) {
                this.children.add(childref);
                this.transitionPaths.add(1);
                childref.parents.add(this);
                childrenCounts.add(0);
            } else {
                transitionPaths.set(index, transitionPaths.get(index) + 1);
            }
        }

        /**
         * Increments the observed count of the child isomorphic to
         * {@code child}.
         *
         * @return true if such a child exists and was incremented
         */
        boolean incrementChild(Cascade child) {
            for (int i = 0; i < children.size(); ++i) {
                if (Cascade.isomorphic(children.get(i).cascade, child)) {
                    childrenCounts.set(i, childrenCounts.get(i) + 1);
                    return true;
                }
            }
            return false;
        }

        /**
         * Returns the observed count for the child isomorphic to
         * {@code child}, or -1 if it is not a child of this node.
         */
        int getChildCount(Cascade child) {
            for (int i = 0; i < children.size(); ++i) {
                if (Cascade.isomorphic(children.get(i).cascade, child)) {
                    return childrenCounts.get(i);
                }
            }
            return -1;
        }

        /**
         * Returns the transition probability to {@code child}, or the
         * probability of ending here when {@code child} is null. Returns -1
         * for a non-child. Requires calcTransitionProbabilities() to have run.
         */
        double getChildProbability(Cascade child) {
            if (child == null) {
                return transitionProbabilities.get(
                    transitionProbabilities.size() - 1);
            }
            for (int i = 0; i < children.size(); ++i) {
                if (Cascade.isomorphic(children.get(i).cascade, child)) {
                    return transitionProbabilities.get(i);
                }
            }
            return -1;
        }

        /** Total observed transitions into this node, summed over parents. */
        int getFrequency() {
            int countTotal = 0;
            for (CascadeMarkovModelNode parent : parents) {
                countTotal += parent.getChildCount(this.cascade);
            }
            return countTotal;
        }

        /**
         * Recomputes transitionProbabilities from the current counts. Entry i
         * is childrenCounts[i] / frequency. The final entry is the share of
         * incoming observations that did not grow further. Each entry is NaN
         * when the frequency is 0.
         */
        void calcTransitionProbabilities() {
            // Clear first so repeated calls don't append stale values.
            transitionProbabilities.clear();
            int countTotal = getFrequency();
            int countEnd = countTotal;
            for (int count : childrenCounts) {
                transitionProbabilities.add((double) count / countTotal);
                countEnd -= count;
            }
            transitionProbabilities.add((double) countEnd / countTotal);
        }

        /**
         * Prints this cascade, its per-child counts and probabilities, and
         * the probability of ending. Requires calcTransitionProbabilities().
         */
        void print() {
            System.out.println(this.cascade.edgeString());
            for (int i = 0; i < children.size(); ++i) {
                System.out.println("\t" + childrenCounts.get(i)
                                   + " " + transitionProbabilities.get(i) + " of "
                                   + children.get(i).cascade.edgeString());
            }
            System.out.println("\t" + transitionProbabilities.get(children.size())
                               + " probability of ending");
        }
    }

}