import java.util.*;
import java.math.*;

public class DecisionTree{

    /** Value predictLabel returns when a row matches none of a node's children. */
    private static final double UNSEEN_DATA_PREDICTION = -1000.00;

    /** ln(2); dividing a natural log by this gives a base-2 log. */
    private static final double LN_2 = Math.log(2);

    /**
     * Builds a decision tree over the whole dataset.
     *
     * A root node covering every row is created and then grown breadth-first:
     * each node is split on the option from possibleSplits that lowers the
     * summed (weighted) entropy of its children the most.
     *
     * @param dataset        training rows; column labelColIDx holds the class label
     * @param maxDepth       nodes at this depth are not split further
     * @param possibleSplits candidate splits; this list is not modified
     * @param labelColIDx    index of the class-label column
     * @return the root node of the built tree
     */
    public Node buildModel(double[][] dataset, int maxDepth,
                           ArrayList<SplitOption> possibleSplits,
                           int labelColIDx){
        Log.write("Model generation started.");
        Node rootNode = buildRootNode(dataset, labelColIDx);
        Log.write("Root node built....");
        LinkedList<Node> treeQueue = new LinkedList<Node>();
        treeQueue.add(rootNode); // seed the breadth-first growth
        Log.write("Subtree generation started...");
        buildSubTree(treeQueue, maxDepth, dataset.length, labelColIDx,
                     possibleSplits);
        Log.write("Model generation complete.");
        return rootNode;
    }

    /**
     * Creates the root node. It covers every row of the dataset, has depth 0
     * and attributeIDx -1 (no split attribute), and has its label
     * distribution, entropy and majority-label prediction computed.
     */
    private Node buildRootNode(double[][] dataset, int labelColIDx){
        Node rootNode = new Node();
        rootNode.dataset = dataset;
        rootNode.dataPointIDXs = new ArrayList<Integer>(dataset.length);
        for(int rowIdx = 0; rowIdx < dataset.length; rowIdx++){
            rootNode.dataPointIDXs.add(rowIdx);
        }
        rootNode.attributeIDx = -1;
        rootNode.depth = 0;
        setLabelsDist(rootNode, labelColIDx);
        setNodeEntropy(rootNode, labelColIDx, dataset.length);
        setNodePrediction(rootNode, labelColIDx);
        return rootNode;
    }

    /**
     * Counts how many of the node's rows carry each class label and stores
     * the result in nd.labelsDist (label -> count). Labels are read from
     * column labelColIDx and cast to int.
     * Invariant: nd.dataset and nd.dataPointIDXs are already set.
     */
    private void setLabelsDist(Node nd, int labelColIDx){
        HashMap<Integer, Integer> nodeClasses = new HashMap<Integer, Integer>();
        for(int index : nd.dataPointIDXs){
            int dataPointClass = (int) nd.dataset[index][labelColIDx];
            Integer currentLabelCount = nodeClasses.get(dataPointClass);
            nodeClasses.put(dataPointClass,
                            currentLabelCount == null ? 1 : currentLabelCount + 1);
        }
        nd.labelsDist = nodeClasses;
    }

    /**
     * Sets the node's entropy, weighted by the share of all rows that reach it:
     *   E = (n_node / totalRecordCount) * SUM_j[ P(Yj) * log2(1 / P(Yj)) ]
     * where P(Yj) is the fraction of the node's rows labelled j.
     * Because of the weighting, a split can be scored by simply summing its
     * children's entropies (see getSplitentropy). An empty node gets 0.
     * Invariant: dataPointIDXs and labelsDist are already set.
     *
     * @param labelColIDx      unused; kept for interface compatibility
     * @param totalRecordCount number of rows in the whole dataset
     */
    public void setNodeEntropy(Node currentNode, int labelColIDx,
                               int totalRecordCount){
        double dataPointCount = currentNode.dataPointIDXs.size();
        if(dataPointCount == 0){
            currentNode.entropy = 0.0; // no data, no impurity
            return;
        }
        double sumOfEntropyOfLabels = 0.0;
        for(int labelCount : currentNode.labelsDist.values()){
            if(labelCount > 0){ // a label with no rows contributes nothing
                double fraction = labelCount / dataPointCount;
                sumOfEntropyOfLabels += fraction * (Math.log(1 / fraction) / LN_2);
            }
        }
        // scale by the fraction of all records that reach this node
        currentNode.entropy = (dataPointCount / totalRecordCount) * sumOfEntropyOfLabels;
    }

    /**
     * Sets the node's prediction to its majority label. Ties go to the label
     * met first in labelsDist iteration order (strict '>' comparison).
     * An empty node is left unchanged.
     * Invariant: dataPointIDXs and labelsDist are already set.
     *
     * @param labelColIDx unused; kept for interface compatibility
     */
    public void setNodePrediction(Node currentNode, int labelColIDx){
        if(currentNode.dataPointIDXs.isEmpty()) return;
        int maxCount = 0;
        int maxValueKey = -1; // always replaced: a non-empty node has a label count >= 1
        for(Map.Entry<Integer, Integer> entry : currentNode.labelsDist.entrySet()){
            if(entry.getValue() > maxCount){
                maxCount = entry.getValue();
                maxValueKey = entry.getKey();
            }
        }
        currentNode.prediction = maxValueKey;
    }

    /**
     * Grows the tree breadth-first from the nodes in queue, draining it.
     *
     * A popped node stays a leaf when its depth equals maxDepth (a negative
     * maxDepth never matches, so depth is then unbounded), when isNodeALeaf
     * holds, or when no split lowers the summed child entropy strictly below
     * the node's own. Otherwise the best split's children become its subTree
     * and are queued.
     *
     * A split chosen for a node is not offered to that node's descendants.
     * Other branches keep it, and possibleSplits itself is never modified.
     * Runs as a loop rather than recursion, so tree size cannot overflow the
     * call stack.
     *
     * Invariant: queued nodes have dataset, dataPointIDXs, labelsDist,
     * entropy, depth and prediction set (as buildRootNode does).
     *
     * @param totalRecordCount unused here; kept for interface compatibility
     */
    public void buildSubTree(LinkedList<Node> queue, int maxDepth,
                             int totalRecordCount, int labelColIDx,
                             ArrayList<SplitOption> possibleSplits){
        // Splits still open to each queued child. Nodes missing from the map
        // (those supplied by the caller) may use every option in possibleSplits.
        Map<Node, List<SplitOption>> splitsForNode =
            new IdentityHashMap<Node, List<SplitOption>>();

        while(!queue.isEmpty()){
            Node head = queue.pop();
            List<SplitOption> nodeSplits = splitsForNode.remove(head);
            if(nodeSplits == null){
                nodeSplits = possibleSplits;
            }
            Log.write("Working at tree level : " + head.depth);
            if( (head.depth == maxDepth) || isNodeALeaf(head, labelColIDx) ){
                Log.write("Leaf node found..");
                continue;
            }

            double minEntropy = head.entropy;
            ArrayList<Node> bestNodes = null;
            int bestIndex = -1;
            for(int index = 0; index < nodeSplits.size(); index++){
                ArrayList<Node> nodes =
                    splitOptionToNodes(nodeSplits.get(index), head, labelColIDx);
                // Skip unknown split types, and splits that drop rows: their
                // entropy would be computed over an incomplete set of rows.
                if(nodes == null || !coversAllDataPoints(nodes, head)){
                    continue;
                }
                double splitEntropy = getSplitentropy(nodes);
                if(splitEntropy < minEntropy){ // ties keep the earlier choice
                    minEntropy = splitEntropy;
                    bestNodes = nodes;
                    bestIndex = index;
                }
            }
            // no split improves information gain: keep this node as a leaf
            if(bestNodes == null){
                Log.write("No more entropy improvement...");
                continue;
            }

            // children inherit this node's splits minus the one just used
            List<SplitOption> childSplits = new ArrayList<SplitOption>(nodeSplits);
            childSplits.remove(bestIndex);
            head.subTree = bestNodes;
            for(Node child : bestNodes){
                splitsForNode.put(child, childSplits);
            }
            queue.addAll(bestNodes);
        }
    }

    /**
     * Returns true when the children's row counts sum to the parent's row
     * count. A split fails this when it drops rows, e.g. a nominal value
     * present at the node but missing from the option's classes, or a NaN
     * attribute value (which matches no comparison).
     */
    private boolean coversAllDataPoints(List<Node> children, Node parent){
        int childRows = 0;
        for(Node child : children){
            childRows += child.dataPointIDXs.size();
        }
        return childRows == parent.dataPointIDXs.size();
    }

    /**
     * Returns true if the node should not be split further, i.e. if any of
     * these hold:
     *   1. its entropy is zero,
     *   2. it has no data points, or
     *   3. all its data points have the same label.
     */
    private boolean isNodeALeaf(Node current, int labelColIDx){
        if(current.entropy == 0.0) return true;
        if(current.dataPointIDXs.isEmpty()) return true;
        double firstItemLabel =
            current.dataset[current.dataPointIDXs.get(0)][labelColIDx];
        for(int counter = 1; counter < current.dataPointIDXs.size(); counter++){
            int index = current.dataPointIDXs.get(counter);
            if(current.dataset[index][labelColIDx] != firstItemLabel){
                return false; // not a leaf if any label differs
            }
        }
        return true; // every label matched the first
    }

    /**
     * Builds the child nodes a split option would create under parent. Each
     * child has its rows, label distribution, entropy and prediction set.
     * so.type 't' gives a numeric threshold split; 'c' gives a nominal
     * (equality) split.
     * An empty child predicts the parent's label instead of keeping the field
     * default 0.0, which may not be a real label.
     *
     * @return the children, or null if so.type is not recognised
     */
    private ArrayList<Node> splitOptionToNodes(SplitOption so, Node parent,
                                               int labelColIDx){
        ArrayList<Node> nodes;
        if(so.type == 't'){
            nodes = numericOptToNodes(so, parent);
        } else if(so.type == 'c'){
            nodes = nominalOptToNodes(so, parent);
        } else {
            return null; // unknown split type
        }
        for(Node nd : nodes){
            setNodeDataPoints(nd, parent);
            setLabelsDist(nd, labelColIDx);
            setNodeEntropy(nd, labelColIDx, parent.dataset.length);
            setNodePrediction(nd, labelColIDx);
            if(nd.dataPointIDXs.isEmpty()){
                nd.prediction = parent.prediction;
            }
        }
        return nodes;
    }

    /**
     * Creates the two children of a numeric threshold split on
     * so.attributeIDx at so.value. The first child takes rows with
     * value < threshold ("l"); the second takes value >= threshold ("ge").
     * Rows are not assigned here.
     */
    private ArrayList<Node> numericOptToNodes(SplitOption so, Node parent){
        ArrayList<Node> nodes = new ArrayList<Node>();
        Node lessThan = new Node();
        lessThan.symbol = "l";
        Node greaterOrEqual = new Node();
        greaterOrEqual.symbol = "ge";
        nodes.add(lessThan);
        nodes.add(greaterOrEqual);
        for(Node nd : nodes){
            nd.attributeIDx = so.attributeIDx;
            nd.attributeValue = so.value;
            nd.depth = parent.depth + 1;
            nd.dataset = parent.dataset;
        }
        return nodes;
    }

    /**
     * Creates one "eq" child per value in so.classes for a nominal split on
     * so.attributeIDx. Rows are not assigned here.
     */
    private ArrayList<Node> nominalOptToNodes(SplitOption so, Node parent){
        ArrayList<Node> nodes = new ArrayList<Node>();
        for(double classVal : so.classes){
            Node nd = new Node();
            nd.attributeIDx = so.attributeIDx;
            nd.attributeValue = classVal;
            nd.depth = parent.depth + 1;
            nd.dataset = parent.dataset;
            nd.symbol = "eq";
            nodes.add(nd);
        }
        return nodes;
    }

    /**
     * Sets nd.dataPointIDXs to those of the parent's rows that satisfy nd's
     * criterion (only the parent's subset is scanned, not the whole dataset).
     * Invariant: dataset, attributeIDx, attributeValue and symbol are set.
     */
    private void setNodeDataPoints(Node nd, Node parent){
        ArrayList<Integer> dataPointIDXs = new ArrayList<Integer>();
        for(int counter : parent.dataPointIDXs){
            double atrValueInDataset = nd.dataset[counter][nd.attributeIDx];
            if(matchesNodeCriteria(nd, atrValueInDataset)){
                dataPointIDXs.add(counter);
            }
        }
        nd.dataPointIDXs = dataPointIDXs;
    }

    /**
     * Tests an attribute value against the node's comparison symbol and
     * attributeValue. Returns false for an unknown or null symbol (as on the
     * root). Strings are compared with equals(), not reference identity.
     */
    private boolean matchesNodeCriteria(Node nd, double atrValueInDataset){
        String symbol = nd.symbol;
        if("le".equals(symbol)){
            return atrValueInDataset <= nd.attributeValue;
        }
        if("l".equals(symbol)){
            return atrValueInDataset < nd.attributeValue;
        }
        if("g".equals(symbol)){
            return atrValueInDataset > nd.attributeValue;
        }
        if("ge".equals(symbol)){
            return atrValueInDataset >= nd.attributeValue;
        }
        if("eq".equals(symbol)){
            return atrValueInDataset == nd.attributeValue;
        }
        return false; // default value
    }

    /**
     * Returns the total entropy of a split: the sum of its nodes' (already
     * weighted) entropies.
     */
    private double getSplitentropy(ArrayList<Node> splitNodes){
        double splitEntropy = 0.0;
        for(Node nd : splitNodes){
            splitEntropy += nd.entropy;
        }
        return splitEntropy;
    }

    //-------------------------------------------------------------------------
    // prediction functions

    /**
     * Predicts a label for every row of testset.
     *
     * NOTE(review): normalization is currently not applied, so numericColIDXs,
     * maxValues and minValues are unused. To enable it, call
     * normalizeRecord(row, numericColIDXs, maxValues, minValues) before
     * predicting. Confirm whether testset is expected to be pre-normalized.
     *
     * @return one row per test row: column 0 is the actual value
     *         row[labelColIDx], column 1 is the predicted value
     */
    public double[][] predict(double[][] testset, int labelColIDx,
                              int[] numericColIDXs, Node modelRootNode,
                              double[] maxValues, double[] minValues){
        double[][] predictions = new double[testset.length][2];
        for(int rowCounter = 0; rowCounter < testset.length; rowCounter++){
            double[] row = testset[rowCounter];
            predictions[rowCounter][0] = row[labelColIDx];
            predictions[rowCounter][1] = predictLabel(row, modelRootNode);
        }
        return predictions;
    }

    /**
     * Delegates to DataProcessor.normalizeRecord. Not called from predict at
     * present.
     */
    private void normalizeRecord(double[] record, int[] numericColIDXs,
                                 double[] maxValues, double[] minValues){
        DataProcessor processor = new DataProcessor();
        processor.normalizeRecord(record, numericColIDXs, maxValues, minValues);
    }

    /**
     * Walks from nd to a leaf, following the first child whose criterion the
     * row satisfies, and returns that leaf's prediction. Returns
     * UNSEEN_DATA_PREDICTION (-1000.0) if no child of some internal node
     * matches, e.g. a nominal value not seen in training.
     */
    private double predictLabel(double[] row, Node nd){
        if(nd.subTree == null) return nd.prediction;
        for(Node child : nd.subTree){
            if(matchesNodeCriteria(child, row[child.attributeIDx])){
                return predictLabel(row, child); // recurse on child
            }
        }
        return UNSEEN_DATA_PREDICTION;
    }

    /**
     * Mean squared difference between column 0 (actual) and column 1
     * (predicted) of predictions. Returns 0.0 for an empty array.
     */
    public double computeMSE(double[][] predictions){
        if(predictions.length == 0) return 0.0; // no data, no error
        double sumOfSquareDiff = 0.0;
        for(double[] pair : predictions){
            double diff = pair[0] - pair[1];
            sumOfSquareDiff += diff * diff;
        }
        return sumOfSquareDiff / predictions.length;
    }

    //------------------------------------------------------------------------
    // sub classes

    /**
     * A node of the decision tree. Every node shares the same dataset array;
     * dataPointIDXs selects the rows that reach this node.
     *
     * NOTE(review): this is a non-static inner class, so each node holds a
     * reference to its DecisionTree. It could be made static if no caller
     * uses the outer.new Node() syntax.
     */
    class Node{

        public double[][] dataset; // reference to the whole dataset
        public ArrayList<Integer> dataPointIDXs; // subset of rows at this node
        public double entropy; // entropy weighted by this node's share of all rows
        public double prediction; // majority label at this node
        public HashMap<Integer, Integer> labelsDist; // class label -> row count
        public int depth; // 0 for the root
        public int attributeIDx; // -1 for the root node
        public double attributeValue; // threshold or nominal value being tested
        public String symbol; // one of le(<=), l(<), g(>), ge(>=), eq(=); null for the root
        public ArrayList<Node> subTree; // children; null for a leaf

        /**
         * Prints this node on one line, indented by depth with "|\t", then
         * prints its children recursively. Leaves also show their label.
         */
        public void printNodeAndSubTree(){
            StringBuilder line = new StringBuilder();
            for(int level = 0; level < depth; level++){
                line.append("|\t");
            }
            line.append(appendSymbol(symbol, "( X[" + attributeIDx + "]"));
            line.append(attributeValue)
                .append(" ; entropy = ").append(entropy)
                .append(" ; Count = ").append(dataPointIDXs.size());
            if(subTree == null){
                line.append(" ; Label = ").append(prediction);
            }
            line.append(" )");
            System.out.println(line);
            if(subTree != null){
                for(Node nd : subTree){
                    nd.printNodeAndSubTree();
                }
            }
        }

        /**
         * Appends the printable operator for symbol to text. Unknown or null
         * symbols append nothing. Strings are compared with equals().
         */
        private String appendSymbol(String symbol, String text){
            if("le".equals(symbol)){
                return text + " <= ";
            }
            if("l".equals(symbol)){
                return text + " < ";
            }
            if("g".equals(symbol)){
                return text + " > ";
            }
            if("ge".equals(symbol)){
                return text + " >= ";
            }
            if("eq".equals(symbol)){
                return text + " = ";
            }
            return text; // don't append if anything else
        }
    }

}

