// Copyright 2010 by Mary McGlohon
// Carnegie Mellon University
// mmcgloho@cs.cmu.edu

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;

/**
 * Command-line driver that reads cascades from a file, feeds each cascade's link-by-link
 * growth into a {@link CascadeMarkovModel}, computes transition probabilities, and hands the
 * output path to {@code CascadeMarkovModel.writeGraphViz}.
 *
 * <p>Usage: {@code LearnCascadeMarkovModel <cascadeFile> <outFile> <useAuthors> <maxSize>}
 */
public class LearnCascadeMarkovModel {
    /** Input files, one cascade per line; {@link #setUp} populates exactly one entry. */
    static String[] cascadeFiles;
    /** Model being learned; created in {@link #setUp}. */
    static CascadeMarkovModel markovModel;
    /** Path passed to {@code writeGraphViz} when the model is written out. */
    static String outFile;

    private static final String USAGE =
            "Usage: LearnCascadeMarkovModel <cascadeFile> <outFile> <useAuthors> <maxSize>";

    public static void main(String[] args) {
        setUp(args);
        analyzeCascades(cascadeFiles);
        writeOutput(outFile);
    }

    /**
     * Parses command-line arguments and creates the model.
     *
     * @param args {@code [cascadeFile, outFile, useAuthors, maxSize]}; {@code useAuthors} is
     *     true only for the (case-insensitive) string "true"
     * @throws IllegalArgumentException if fewer than four arguments are supplied
     * @throws NumberFormatException if {@code maxSize} is not an integer
     */
    static void setUp(String[] args) {
        if (args.length < 4) {
            throw new IllegalArgumentException(USAGE);
        }
        // TODO(mmcgloho): Make it so multiple files can be used, use -flags
        cascadeFiles = new String[] {args[0]};
        outFile = args[1];
        // parseBoolean/parseInt parse exactly like the deprecated Boolean(String) and
        // Integer(String) constructors they replace.
        boolean useAuthors = Boolean.parseBoolean(args[2]);
        int maxSize = Integer.parseInt(args[3]);
        markovModel = new CascadeMarkovModel(maxSize, useAuthors);
    }

    /**
     * Writes the learned model by passing {@code outFile} to
     * {@code CascadeMarkovModel.writeGraphViz}.
     *
     * @param outFile destination path
     */
    static void writeOutput(String outFile) {
        markovModel.writeGraphViz(outFile);
    }

    /**
     * Reads every cascade line of each file, records its growth in the model, then computes
     * transition probabilities and prints the model.
     *
     * <p>Blank lines, and lines that {@link #readCascade} rejects (returns {@code ""}), are
     * skipped. On an {@link IOException} the stack trace is printed and processing stops; in
     * that case {@code calcTransitionProbabilities()} and {@code print()} are not called.
     *
     * @param cascadeFiles paths of the input files
     */
    static void analyzeCascades(String[] cascadeFiles) {
        try {
            for (String file : cascadeFiles) {
                System.out.println(file);
                BufferedReader br = new BufferedReader(new FileReader(file));
                try {
                    String line;
                    while ((line = br.readLine()) != null) {
                        if (line.trim().isEmpty()) {
                            // Blank lines (e.g. a trailing newline) carry no links and
                            // cannot be parsed as a cascade.
                            continue;
                        }
                        String cstring = readCascade(line);
                        if (!cstring.isEmpty()) {
                            analyzeLinksByCascade(cstring);
                        }
                    }
                } finally {
                    // The original never closed the reader.
                    br.close();
                }
            }
            markovModel.calcTransitionProbabilities();
            markovModel.print();
        } catch (IOException e) {
            // NOTE(review): best-effort as in the original; main() still calls writeOutput
            // afterwards even though probabilities were never computed. Consider aborting.
            e.printStackTrace();
        }
    }

    /**
     * Replays a cascade one link at a time, recording each growth step in the model.
     *
     * <p>{@code cstring} is split into space-separated pairs ({@code "<id>:... <dest>"}) and
     * indexed by node id. Starting from the empty cascade, the pair for id {@code i} is
     * appended for i = 1..size, and each (previous, extended) cascade pair is passed to
     * {@code markovModel.addCascadeGrowth}. Assumes ids run 1..size after
     * {@code remapByIds()} (TODO confirm against {@code Cascade.edgeString()}).
     *
     * @param cstring remapped cascade edge string produced by {@link #readCascade}
     */
    static void analyzeLinksByCascade(String cstring) {
        // Index each "node dest" pair by the node's id.
        HashMap<Integer, String> linkmap = new HashMap<Integer, String>();
        String[] parts = cstring.split(" ");
        for (int i = 0; i < parts.length; i += 2) {
            String[] nodeparts = parts[i].split(":");
            linkmap.put(Integer.parseInt(nodeparts[0]), parts[i] + " " + parts[i + 1]);
        }

        // Grow the cascade from empty, one link at a time, in id order.
        String cascadeString = "";
        Cascade cOld = new Cascade(cascadeString, markovModel.useAuthors);
        int cascadeSize = new Cascade(cstring, markovModel.useAuthors).size();
        for (int i = 1; i <= cascadeSize; ++i) {
            // NOTE(review): a missing id would append the text "null" here.
            cascadeString = cascadeString + " " + linkmap.get(i);
            Cascade cNew = new Cascade(cascadeString, markovModel.useAuthors);
            markovModel.addCascadeGrowth(cOld, cNew);
            cOld = cNew;
        }
    }

    /**
     * Parses one input line into a cascade and renumbers its nodes from 1 to n.
     *
     * <p>Only id-based ordering is currently implemented. If {@link #checkCascadeLinks}
     * reports that node ids are usable (0), the cascade is remapped by id. Otherwise the line
     * is rejected.
     *
     * @param line one raw cascade line
     * @return the remapped cascade's {@code edgeString()}, or {@code ""} if the line is rejected
     */
    static String readCascade(String line) {
        Cascade c = new Cascade(line, markovModel.useAuthors);
        int toUse = checkCascadeLinks(line);
        if (toUse != 0) {
            // toUse == 1 means timestamps are usable, but remapping by time is not implemented.
            // TODO(mmcgloho): add a method to Cascade to remap by timestamps.
            return "";
        }
        c.remapByIds();
        c.remapAuthors();
        return c.edgeString();
    }

    /**
     * Decides which ordering, if any, can be used to renumber a cascade line.
     *
     * <p>The line is read as space-separated pairs. Each pair is a node token followed by the
     * id of the node it links to. Field 0 of the node token (split on ':') is the node id and
     * field 2 is its timestamp.
     *
     * <ul>
     *   <li>Ids are usable when no node links to a larger id.</li>
     *   <li>Timestamps are usable when every link with a positive destination id points to a
     *       node listed on the line whose timestamp is not later than the source's.</li>
     * </ul>
     *
     * @param line one raw cascade line
     * @return 0 to order by ids, 1 to order by timestamps, -1 if neither works or the line is
     *     not made of complete pairs
     * @throws NumberFormatException if an id or timestamp is not numeric
     */
    static int checkCascadeLinks(String line) {
        String[] parts = line.split(" ");
        if (parts.length % 2 != 0) {
            // A node token without a destination; the original indexed past the array end.
            return -1;
        }

        // Timestamp of every node listed on the line, keyed by node id.
        HashMap<Integer, Long> times = new HashMap<Integer, Long>();
        for (int i = 0; i < parts.length; i += 2) {
            String[] nodeparts = parts[i].split(":");
            times.put(Integer.parseInt(nodeparts[0]), Long.parseLong(nodeparts[2]));
        }

        boolean idError = false;
        boolean timeError = false;
        for (int i = 0; i < parts.length; i += 2) {
            int src = Integer.parseInt(parts[i].split(":")[0]);
            int dest = Integer.parseInt(parts[i + 1]);
            if (src < dest) {
                idError = true;
            }
            // NOTE(review): a non-positive dest appears to mean "no parent"; confirm.
            if (dest > 0) {
                Long destTime = times.get(dest);
                // src is always present in times: every source was inserted above.
                if (destTime == null || times.get(src) < destTime) {
                    timeError = true;
                }
            }
        }

        if (!idError) {
            return 0;
        }
        return timeError ? -1 : 1;
    }

}