// Copyright 2008-2010 by Mary McGlohon
// Carnegie Mellon University
// mmcgloho@cs.cmu.edu

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;

/**
 * Takes a list of edges and formats it into cascades, using an intermediate
 * file that holds the same edges labeled with their cascade id.
 *
 * <p>Each input line is comma-separated {@code srcid,destid,time,srcauthor}.
 * A {@code destid} of -1 marks {@code srcid} as a root. Lines whose source and
 * destination ids are equal are ignored, and blank lines are skipped.
 *
 * <p>A cascade is a set of nodes connected (ignoring direction) by the edges
 * read so far. The output file contains one line per cascade, each edge
 * rendered as {@code src:author:time dest } (note the trailing space).
 *
 * <p>All working state lives in unsynchronized static fields that are reset by
 * {@link #buildCascadesEfficient}. Readers and writers are closed with
 * try/finally so the file stays compilable at the language level it was
 * written for.
 */
public class BuildCascadesEfficient {

    /** Maps every node id seen so far to the id of the cascade it belongs to. */
    static HashMap<Long, Integer> nodesToCascades;
    /** Maps every live cascade id to the set of node ids it contains. */
    static HashMap<Integer, HashSet<Long>> cascadesToNodes;
    /** Per cascade id, its edges flattened as (src, dest, author, time) groups. */
    static HashMap<Integer, ArrayList<Long>> cascadeList;
    /** Next cascade id to hand out. */
    static int curr_cascade_id = 0;
    /** Field separator; passed to {@link String#split(String)}, so it is a regex. */
    static String delimiter = ",";
    // may want to assign this later.
    /** Maximum number of input lines each pass will read. */
    static int maxEdges = Integer.MAX_VALUE;

    /** Destination id that marks the source node as a cascade root. */
    private static final long ROOT_MARKER = -1;
    /** Number of values stored per edge in a {@link #cascadeList} entry. */
    private static final int FIELDS_PER_EDGE = 4;


    /**
     * Command-line entry point.
     *
     * <p>To run: {@code java BuildCascadesEfficient infile outfile [useAuthor]}.
     * The intermediate file is written next to the input as
     * {@code infile + "-temp"}. {@code useAuthor} is optional and defaults to
     * {@code false}.
     *
     * @param args {@code infile outfile [useAuthor]}
     */
    // TODO(marymc): Make it so user can have multiple infiles, and use -flags
    public static void main(String[] args) {
        if (args.length < 2) {
            System.err.println(
                "Usage: java BuildCascadesEfficient infile outfile [useAuthor]");
            System.exit(1);
        }
        String infile = args[0];
        String midfile = infile + "-temp";
        String outfile = args[1];
        boolean useAuthor = args.length > 2 && Boolean.parseBoolean(args[2]);
        buildCascadesEfficient(infile, midfile, outfile, useAuthor);
    }

    /**
     * Builds cascades by first finding components, then writing the output
     * files. Saves memory compared to building the cascades as edges are read.
     *
     * @param infile    edge list to read
     * @param midfile   intermediate file to write (edges prefixed by cascade id)
     * @param outfile   final file, one cascade per line
     * @param useAuthor currently unused; kept so existing callers still compile
     */
    public static void buildCascadesEfficient(String infile, String midfile,
                                              String outfile, boolean useAuthor) {
        nodesToCascades = new HashMap<Long, Integer>();
        cascadesToNodes = new HashMap<Integer, HashSet<Long>>();
        cascadeList = new HashMap<Integer, ArrayList<Long>>();
        curr_cascade_id = 1;
        assignCascades(infile);
        writeCascadeIds(infile, midfile);
        writeAuthoredCascades(midfile, outfile);
    }

    /** Returns whether {@code nodeid} already belongs to some cascade. */
    private static boolean assignedToCascade(long nodeid) {
        return nodesToCascades.containsKey(nodeid);
    }

    /** Registers a new, empty cascade and returns its id. */
    private static int startNewCascade() {
        int cascadeid = curr_cascade_id++;
        cascadesToNodes.put(cascadeid, new HashSet<Long>());
        return cascadeid;
    }

    /** Records both endpoints of an edge as members of {@code cascadeid}. */
    private static void addEdgeToCascade(long srcid, long destid, int cascadeid) {
        HashSet<Long> members = cascadesToNodes.get(cascadeid);
        members.add(destid);
        members.add(srcid);
        nodesToCascades.put(srcid, cascadeid);
        nodesToCascades.put(destid, cascadeid);
    }

    /**
     * Merges two cascades and returns the id of the surviving one.
     *
     * <p>The cascade with fewer nodes is folded into the other so fewer node
     * entries have to be re-pointed; on a tie {@code cascade2} survives.
     *
     * @return the surviving cascade id, or {@code cascade1} if both ids are equal
     */
    private static int mergeCascades(int cascade1, int cascade2) {
        if (cascade1 == cascade2) {
            return cascade1;
        }
        int size1 = cascadesToNodes.get(cascade1).size();
        int size2 = cascadesToNodes.get(cascade2).size();
        int larger = (size1 > size2) ? cascade1 : cascade2;
        int smaller = (size1 > size2) ? cascade2 : cascade1;

        HashSet<Long> absorbed = cascadesToNodes.remove(smaller);
        cascadesToNodes.get(larger).addAll(absorbed);
        for (Long nodeid : absorbed) {
            nodesToCascades.put(nodeid, larger);
        }
        return larger;
    }

    /**
     * First pass: reads the edge list and assigns every node to a cascade,
     * merging cascades whenever an edge connects two of them.
     *
     * <p>An {@link IOException} is reported on stderr and otherwise ignored,
     * leaving whatever assignments were made before the failure.
     */
    private static void assignCascades(String infile) {
        try {
            BufferedReader br = new BufferedReader(new FileReader(infile));
            try {
                String line;
                int numlines = 0;
                while ((line = br.readLine()) != null && numlines < maxEdges) {
                    ++numlines;
                    if (numlines % 1000000 == 0) {
                        // Progress indicator for very large inputs.
                        System.out.println(numlines);
                    }
                    line = line.trim();
                    if (line.isEmpty()) {
                        continue;
                    }
                    String[] parts = line.split(delimiter);
                    long srcnode = Long.parseLong(parts[0]);
                    long destnode = Long.parseLong(parts[1]);
                    if (srcnode != destnode) {
                        assignEdge(srcnode, destnode);
                    }
                }
            } finally {
                br.close();
            }
        } catch (IOException e) {
            System.err.println("Failed while reading " + infile);
            e.printStackTrace();
        }
    }

    /** Places one edge (or one root marker) into the right cascade. */
    private static void assignEdge(long srcnode, long destnode) {
        boolean srcAssigned = assignedToCascade(srcnode);

        if (destnode == ROOT_MARKER) {
            // Root line: only the source is a real node.
            int cascadeid = srcAssigned
                ? nodesToCascades.get(srcnode)
                : startNewCascade();
            addEdgeToCascade(srcnode, srcnode, cascadeid);
            return;
        }

        boolean destAssigned = assignedToCascade(destnode);
        int cascadeid;
        if (srcAssigned && destAssigned) {
            // The edge bridges two (possibly identical) cascades.
            cascadeid = mergeCascades(nodesToCascades.get(srcnode),
                                      nodesToCascades.get(destnode));
        } else if (srcAssigned) {
            cascadeid = nodesToCascades.get(srcnode);
        } else if (destAssigned) {
            cascadeid = nodesToCascades.get(destnode);
        } else {
            cascadeid = startNewCascade();
        }
        addEdgeToCascade(srcnode, destnode, cascadeid);
    }

    /**
     * Second pass: writes to an intermediate file the edges prefixed with
     * their cascade id ({@code cascadeid,original line}).
     *
     * <p>Relies on {@link #nodesToCascades} having been filled from the same
     * input. An {@link IOException} is reported on stderr and otherwise ignored.
     */
    static void writeCascadeIds(String infile, String midfile) {
        try {
            BufferedReader br = new BufferedReader(new FileReader(infile));
            try {
                FileWriter fw = new FileWriter(midfile);
                try {
                    String line;
                    int numlines = 0;
                    while ((line = br.readLine()) != null && numlines < maxEdges) {
                        ++numlines;
                        line = line.trim();
                        if (line.isEmpty()) {
                            continue;
                        }
                        String[] parts = line.split(delimiter);
                        long srcnode = Long.parseLong(parts[0]);
                        long destnode = Long.parseLong(parts[1]);
                        if (srcnode == destnode) {
                            continue;
                        }
                        int casc1 = nodesToCascades.get(srcnode);
                        int casc2 = (destnode != ROOT_MARKER)
                            ? nodesToCascades.get(destnode)
                            : casc1;
                        if (casc1 != casc2) {
                            // Sanity check: both endpoints should share a cascade.
                            System.out.println(line);
                            System.out.println("OH NOES");
                        }
                        fw.write(casc1 + delimiter + line + "\n");
                    }
                } finally {
                    fw.close();
                }
            } finally {
                br.close();
            }
        } catch (IOException e) {
            System.err.println("Failed while labeling " + infile
                               + " into " + midfile);
            e.printStackTrace();
        }
    }


    // TODO(marymc): Find some way to put the authors in here
    /**
     * Third pass: takes the intermediate cascade-labeled edges and writes one
     * line per cascade to {@code outfile}.
     *
     * <p>Clears {@link #nodesToCascades} and {@link #cascadesToNodes} before
     * loading the edges, so only {@link #cascadeList} is populated afterwards.
     * Every line of {@code midfile} must have five fields
     * ({@code cascadeid,srcid,destid,time,srcauthor}). An {@link IOException}
     * is reported on stderr and otherwise ignored.
     */
    static void writeAuthoredCascades(String midfile, String outfile) {
        // Make sure every cascade id referenced by a node has a member set.
        for (Long nodeid : nodesToCascades.keySet()) {
            Integer cascadeid = nodesToCascades.get(nodeid);
            HashSet<Long> members = cascadesToNodes.get(cascadeid);
            if (members == null) {
                members = new HashSet<Long>();
                cascadesToNodes.put(cascadeid, members);
            }
            members.add(nodeid);
        }
        nodesToCascades.clear();

        // Reserve an edge list for every cascade that is large enough to keep.
        int minCascadeSize = 1;
        for (Integer cascid : cascadesToNodes.keySet()) {
            if (cascadesToNodes.get(cascid).size() >= minCascadeSize) {
                cascadeList.put(cascid, new ArrayList<Long>());
            }
        }
        cascadesToNodes.clear();

        try {
            loadCascadeEdges(midfile);
            writeCascadeLines(outfile);
        } catch (IOException e) {
            System.err.println("Failed while writing cascades from " + midfile
                               + " to " + outfile);
            e.printStackTrace();
        }
    }

    /** Reads the labeled edges into the lists reserved in {@link #cascadeList}. */
    private static void loadCascadeEdges(String midfile) throws IOException {
        BufferedReader br = new BufferedReader(new FileReader(midfile));
        try {
            String line;
            int numlines = 0;
            while ((line = br.readLine()) != null && numlines < maxEdges) {
                ++numlines;
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] parts = line.split(delimiter);
                ArrayList<Long> edges = cascadeList.get(Integer.parseInt(parts[0]));
                if (edges == null) {
                    // Cascade was filtered out; drop its edges.
                    continue;
                }
                long srcid = Long.parseLong(parts[1]);
                long destid = Long.parseLong(parts[2]);
                long time = Long.parseLong(parts[3]);
                long srcauthor = Long.parseLong(parts[4]);
                // Stored order (src, dest, author, time) is what
                // writeCascadeLines indexes into.
                edges.add(srcid);
                edges.add(destid);
                edges.add(srcauthor);
                edges.add(time);
            }
        } finally {
            br.close();
        }
    }

    /** Writes each cascade in {@link #cascadeList} as one line of edges. */
    private static void writeCascadeLines(String outfile) throws IOException {
        FileWriter fw = new FileWriter(outfile);
        try {
            for (ArrayList<Long> myedges : cascadeList.values()) {
                for (int i = 0; i < myedges.size(); i += FIELDS_PER_EDGE) {
                    fw.write(myedges.get(i) + ":" + myedges.get(i + 2) + ":"
                             + myedges.get(i + 3) + " " + myedges.get(i + 1) + " ");
                }
                fw.write("\n");
            }
        } finally {
            fw.close();
        }
    }

}
