import java.util.*;

public class RegressionTree{

    // Trains a regression tree on dataset and returns its root node.
    // dataset        : rows of values; the label is read from column labelColIDx
    // maxDepth       : nodes at this depth are never split further
    // possibleSplits : candidate splits; buildSubTree removes each split it uses
    //                  from this very list, so the caller's list shrinks
    // labelColIDx    : index of the label column inside each row
    public Node buildModel(double[][] dataset, int maxDepth, 
                           ArrayList<SplitOption> possibleSplits,
                           int labelColIDx){
        Log.write("Model generation started.");
        final Node root = buildRootNode(dataset, labelColIDx);
        Log.write("Root node built....");

        // the work list starts with only the root; the tree is grown from it
        final LinkedList<Node> pending = new LinkedList<Node>();
        pending.add(root);

        Log.write("Subtree generation started...");
        final int recordCount = dataset.length;
        buildSubTree(pending, maxDepth, recordCount, labelColIDx, possibleSplits);
        Log.write("Model generation complete.");
        return root;
    }
    
    // Creates the root node: it covers every row of dataset, sits at depth 0,
    // has no split attribute (attributeIDx = -1) and carries the prediction and
    // error computed over the full dataset.
    private Node buildRootNode(double[][] dataset, int labelColIDx){
        Node root = new Node();
        root.depth = 0;
        root.attributeIDx = -1;
        root.dataset = dataset;

        // the root owns all row indices 0 .. dataset.length-1
        ArrayList<Integer> allRows = new ArrayList<Integer>();
        int row = 0;
        while(row < dataset.length){
            allRows.add(row);
            row++;
        }
        root.dataPointIDXs = allRows;

        // prediction must be set before the error, which is measured against it
        setNodePrediction(root, labelColIDx);
        setNodeRMSE(root, labelColIDx);
        return root;
    }
    
    // Stores the node's error in currentNode.RMSE.
    // Despite the field name, the stored value is the plain sum of squared
    // errors  SUM over t in Xm of (yt - gm)^2 , where gm is the node prediction
    // and yt the label of row t; it is NOT divided by |Xm| and no root is taken.
    // A node without data points gets an error of 0.0.
    // Invariant:Assumes node properties dataPointIDXs,prediction,dataset are set
    public void setNodeRMSE(Node currentNode, int labelColIDx){
        if(currentNode.dataPointIDXs.size() == 0){
            currentNode.RMSE = 0.0; //no data, no error
            return;
        }
        double squaredErrorTotal = 0.0;
        for(int row : currentNode.dataPointIDXs){
            double residual = currentNode.prediction
                - currentNode.dataset[row][labelColIDx];
            squaredErrorTotal += residual * residual;
        }
        currentNode.RMSE = squaredErrorTotal;
    }

    // Stores the node's prediction: the mean label of its data points,
    //   gm = SUM over t in Xm of yt / |Xm| ; where yt is the label of row t.
    // A node without data points gets a prediction of 0.
    // Invariant: Assumes node properties dataPointIDXs, dataset are set
    public void setNodePrediction(Node currentNode, int labelColIDx){
        final int pointCount = currentNode.dataPointIDXs.size();
        if(pointCount == 0){
            currentNode.prediction = 0;
            return;
        }
        double labelTotal = 0.0;
        for(int position = 0; position < pointCount; position++){
            int row = currentNode.dataPointIDXs.get(position);
            labelTotal += currentNode.dataset[row][labelColIDx];
        }
        // pointCount > 0 here, so the division is safe
        currentNode.prediction = labelTotal / pointCount;
    }
    
    // Grows the tree greedily (best-first): on every pass all nodes waiting in
    // the queue are tried against all remaining split options, and the single
    // (node, split) pair with the lowest total split error is applied.
    // The chosen node leaves the queue, its children join it, and the chosen
    // split is removed from possibleSplits so it cannot be reused anywhere.
    // Growth stops when the queue is empty or no admissible split is left.
    //
    // queue            : nodes that may still be split; must contain the root
    // maxDepth         : nodes at this depth are never split
    // totalRecordCount : kept for interface compatibility; not used here
    // labelColIDx      : index of the label column
    // possibleSplits   : candidate splits; MUTATED (used splits are removed)
    //
    // Invariant: The root node has been added to the queue
    // Invariant: root node's dataset,dataPointIDXs,RMSE,depth,prediction are set
    public void buildSubTree(LinkedList<Node> queue, int maxDepth, 
                             int totalRecordCount, int labelColIDx,
                             ArrayList<SplitOption> possibleSplits){
        // a loop instead of tail recursion: one pass per applied split,
        // without adding a stack frame for each of them
        while(!queue.isEmpty()){
            Node bestParent = null;
            ArrayList<Node> bestChildren = null;
            int bestSplitIndex = -1;
            // start from +infinity so that any finite split error can win;
            // a fixed cap would reject every split on data with large errors
            double bestSplitError = Double.POSITIVE_INFINITY;

            for(Node candidate : queue){
                // leaves and nodes at maxDepth stay in the queue but are skipped
                if( (candidate.depth == maxDepth)
                    || isNodeALeaf(candidate, labelColIDx) )
                    continue;
                Log.write("Working at tree level : " + candidate.depth);
                for(int index = 0; index < possibleSplits.size(); index++){
                    ArrayList<Node> children =
                        splitOptionToNodes(possibleSplits.get(index),
                                           candidate, labelColIDx);
                    // null means the split type is unknown; an empty list
                    // would turn the node into a childless interior node
                    if(children == null || children.isEmpty())
                        continue;
                    if(isUnfit(children, candidate))
                        continue;
                    double splitError = getSplitRMSE(children, labelColIDx);
                    if(splitError < bestSplitError){ // ties keep the earlier one
                        bestSplitError = splitError;
                        bestChildren = children;
                        bestParent = candidate;
                        bestSplitIndex = index;
                    }
                }
            }

            if(bestParent == null)
                return; //no more growth possible

            queue.remove(bestParent); // this node is now an interior node
            // bestSplitIndex is an int, so this removes by position
            possibleSplits.remove(bestSplitIndex);
            bestParent.subTree = bestChildren;
            queue.addAll(bestChildren); // children become split candidates
        }
    }

    // Grows the tree breadth-first: nodes are taken from the front of the queue
    // one at a time; for each, the remaining split with the lowest total split
    // error is applied, provided it is strictly lower than the node's own error.
    // The applied split is removed from possibleSplits and the new children are
    // appended to the queue. Leaves, nodes at maxDepth and nodes that no split
    // improves are simply dropped from the queue.
    //
    // queue            : nodes still to be processed; must contain the root
    // maxDepth         : nodes at this depth are never split
    // totalRecordCount : kept for interface compatibility; not used here
    // labelColIDx      : index of the label column
    // possibleSplits   : candidate splits; MUTATED (used splits are removed)
    //
    // Invariant: The root node has been added to the queue
    // Invariant: root node's dataset,dataPointIDXs,RMSE,depth,prediction are set
    public void buildSubTreeBFS(LinkedList<Node> queue, int maxDepth, 
                             int totalRecordCount, int labelColIDx,
                             ArrayList<SplitOption> possibleSplits){
        // a loop instead of tail recursion: one iteration per processed node,
        // without adding a stack frame for each of them
        while(!queue.isEmpty()){
            Node head = queue.pop();
            Log.write("Working at tree level : " + head.depth);
            if( (head.depth == maxDepth) || isNodeALeaf(head, labelColIDx) ){
                Log.write("Leaf node found..");
                continue;
            }

            // a split must beat the node's own error to be accepted
            double bestSplitError = head.RMSE;
            ArrayList<Node> bestChildren = null;
            int bestSplitIndex = -1;
            for(int index = 0; index < possibleSplits.size(); index++){
                ArrayList<Node> children =
                    splitOptionToNodes(possibleSplits.get(index),
                                       head, labelColIDx);
                if(children == null)
                    continue; // unknown split type, nothing to evaluate
                double splitError = getSplitRMSE(children, labelColIDx);
                if(splitError < bestSplitError){// if equal then ignore
                    bestSplitError = splitError;
                    bestChildren = children;
                    bestSplitIndex = index;
                }
            }

            // Test "was a split selected" directly instead of comparing the
            // errors again: with a NaN node error the comparison is false and
            // the code would go on to remove index -1.
            if(bestChildren == null){
                Log.write("No more RMSE improvement...");
                continue;
            }

            // bestSplitIndex is an int, so this removes by position
            possibleSplits.remove(bestSplitIndex);
            head.subTree = bestChildren;
            queue.addAll(bestChildren); // children are processed later
        }
    }

    // Decides whether a candidate split must be rejected.
    // A split is unfit when at least one of its nodes
    //   - has an error that is not lower than the error of head, or
    //   - holds 5 or fewer data points.
    // Returns false for an empty list of nodes.
    private boolean isUnfit(ArrayList<Node> nodes, Node head){
        boolean unfit = false;
        for(Node child : nodes){
            boolean noErrorGain = child.RMSE >= head.RMSE;
            boolean tooFewPoints = child.dataPointIDXs.size() <= 5;
            unfit = unfit | noErrorGain | tooFewPoints;
        }
        return unfit;
    }

    // Tells whether a node cannot be split any further.
    // A node is a leaf if
    // ............1. its stored error is exactly zero, or
    // ............2. it holds no data points, or
    // ............3. all of its data points carry the same label
    private boolean isNodeALeaf(Node current, int labelColIDx){
        if(current.RMSE == 0.0 || current.dataPointIDXs.size() == 0)
            return true;

        // compare every remaining label against the label of the first point
        Iterator<Integer> rows = current.dataPointIDXs.iterator();
        double referenceLabel = current.dataset[rows.next()][labelColIDx];
        while(rows.hasNext()){
            double label = current.dataset[rows.next()][labelColIDx];
            if(label != referenceLabel)
                return false; // one differing label is enough
        }
        return true;
    }

    // Turns a split option into the child nodes it would create under parent.
    // Type 't' is handled as a numeric split, type 'c' as a nominal split;
    // any other type yields null.
    // Every returned child already has its data points, prediction and error
    // filled in.
    private ArrayList<Node> splitOptionToNodes(SplitOption so, Node parent,
                                               int labelColIDx){
        if(so.type != 't' && so.type != 'c')
            return null; // unknown split type

        ArrayList<Node> children = (so.type == 't')
            ? numericOptToNodes(so, parent)
            : nominalOptToNodes(so, parent);

        for(Node child : children){
            // order matters: the error needs the prediction, which needs the points
            setNodeDataPoints(child, parent);
            setNodePrediction(child, labelColIDx);
            setNodeRMSE(child, labelColIDx);
        }
        return children;
    }
    
    // Builds the two children of a numeric split on so.attributeIDx at so.value:
    // first the "l" node (attribute < value), then the "ge" node
    // (attribute >= value). Both sit one level below parent and share its
    // dataset; data points, prediction and error are not set here.
    private ArrayList<Node> numericOptToNodes(SplitOption so, Node parent){
        final String[] symbols = { "l", "ge" };
        ArrayList<Node> children = new ArrayList<Node>();
        for(String symbol : symbols){
            Node child = new Node();
            child.symbol = symbol;
            child.attributeIDx = so.attributeIDx;
            child.attributeValue = so.value;
            child.depth = parent.depth + 1;
            child.dataset = parent.dataset;
            children.add(child);
        }
        return children;
    }
    
    // Builds one "eq" child per class value listed in so.classes, in the same
    // order. Each child tests attribute so.attributeIDx for equality with its
    // class value, sits one level below parent and shares its dataset; data
    // points, prediction and error are not set here.
    private ArrayList<Node> nominalOptToNodes(SplitOption so, Node parent){
        final int childDepth = parent.depth + 1;
        ArrayList<Node> children = new ArrayList<Node>();
        for(double classValue : so.classes){
            Node child = new Node();
            child.symbol = "eq";
            child.attributeValue = classValue;
            child.attributeIDx = so.attributeIDx;
            child.dataset = parent.dataset;
            child.depth = childDepth;
            children.add(child);
        }
        return children;
    }
    // Assigns to nd the rows of its parent that satisfy nd's split criterion.
    // Only the parent's subset is scanned, never the whole dataset, and the
    // parent's row order is kept.
    // Invariant: Assumes dataset, attributeIDx, attributeValue, symbol are set
    private void setNodeDataPoints(Node nd, Node parent){
        ArrayList<Integer> matchingRows = new ArrayList<Integer>();
        for(Integer row : parent.dataPointIDXs){
            double attributeInRow = nd.dataset[row][nd.attributeIDx];
            if(!matchesNodeCriteria(nd, attributeInRow))
                continue;
            matchingRows.add(row);
        }
        nd.dataPointIDXs = matchingRows;
    }

    // Tests an attribute value against the node's split criterion.
    // nd.symbol selects the comparison with nd.attributeValue:
    //   "le" : value <= attributeValue      "l"  : value <  attributeValue
    //   "g"  : value >  attributeValue      "ge" : value >= attributeValue
    //   "eq" : value == attributeValue (exact double comparison)
    // Any other symbol, including null, matches nothing.
    private boolean matchesNodeCriteria(Node nd, double atrValueInDataset){
        // Strings are compared with equals(): == only checks object identity
        // and works solely while every symbol is an interned literal.
        // Calling equals on the literal also keeps a null symbol harmless.
        final String symbol = nd.symbol;
        final double threshold = nd.attributeValue;
        if("le".equals(symbol))
            return (atrValueInDataset <= threshold);
        if("l".equals(symbol))
            return (atrValueInDataset < threshold);
        if("g".equals(symbol))
            return (atrValueInDataset > threshold);
        if("ge".equals(symbol))
            return (atrValueInDataset >= threshold);
        if("eq".equals(symbol))
            return (atrValueInDataset == threshold);
        return false; //default value
    }

    // Returns the total error of a split: the sum, over every node of the
    // split and every data point of that node, of the squared difference
    // between the node's prediction and the point's label.
    // The sum is NOT divided by the number of records.
    // A split without any data point has an error of 0.
    private double getSplitRMSE(ArrayList<Node> splitNodes, int labelColIDx){
        double squaredErrorTotal = 0.0;
        int recordsSeen = 0;
        for(Node child : splitNodes){
            final ArrayList<Integer> rows = child.dataPointIDXs;
            recordsSeen += rows.size();
            for(int position = 0; position < rows.size(); position++){
                int row = rows.get(position);
                double residual = child.prediction
                    - child.dataset[row][labelColIDx];
                squaredErrorTotal += residual * residual;
            }
        }
        // no data, no error
        return (recordsSeen == 0) ? 0 : squaredErrorTotal;
    }


    //-------------------------------------------------------------------------
    // prediction functions

    // Predicts the label of every row of testset with the given model.
    // Each row is first passed to normalizeRecord together with numericColIDXs,
    // maxValues and minValues, then routed through the tree.
    // NOTE(review): normalizeRecord receives the row array itself, so testset
    // is presumably modified in place - confirm against DataProcessor.
    // Returns a 2D array with one line per test row and the columns:
    //   [0] actual value (row[labelColIDx], read after normalization)
    //   [1] predicted value
    // A predicted value of -1000.00 is the marker predictLabel returns when no
    // branch of the tree matches the row; such rows are reported in the log.
    public double[][] predict(double[][] testset, int labelColIDx, 
                              int[] numericColIDXs, Node modelRootNode,
                              double[] maxValues, double[] minValues){
        double[][] predictions = new double[testset.length][2];
        for(int rowCounter = 0; rowCounter < testset.length; rowCounter++){
            double[] row = testset[rowCounter];
            normalizeRecord(row, numericColIDXs, maxValues, minValues);
            double prediction = predictLabel(row, modelRootNode);
            if(prediction == -1000.00){
                // a space now separates the row number from the next label;
                // the two used to run together in the log line
                Log.write("Noise/ unseen values. Row no.: " + rowCounter +
                          " Default Prediction :"
                          + prediction);
            }
            predictions[rowCounter][0] = row[labelColIDx];
            predictions[rowCounter][1] = prediction;
        }
        return predictions;
    }

    // Hands the record to a freshly created DataProcessor for normalization,
    // passing the numeric column indices and the max/min arrays through unchanged.
    private void normalizeRecord(double[] record, int[] numericColIDXs, 
                                 double[] maxValues, double[] minValues){
        new DataProcessor().normalizeRecord(record, numericColIDXs,
                                            maxValues, minValues);
    }
    
    // Routes a row down the tree starting at nd and returns the prediction of
    // the leaf it ends in. At every interior node the first child whose
    // criterion matches the row is followed. If no child matches, the error
    // value -1000.00 is returned (unseen data).
    private double predictLabel(double[] row, Node nd){
        Node current = nd;
        while(current.subTree != null){
            Node next = null;
            for(Node child : current.subTree){
                if(matchesNodeCriteria(child, row[child.attributeIDx])){
                    next = child; // first match wins
                    break;
                }
            }
            if(next == null)
                return -1000.00;
            current = next;
        }
        return current.prediction;
    }

    // Variant of predictLabel that follows EVERY matching child instead of only
    // the first one and returns the smallest prediction collected.
    // row   : the record to predict
    // nd    : node to start from; a node without subTree returns its own
    //         prediction immediately
    // preds : accumulator for collected predictions; pass null on the first call
    // Returns 1000.00 if nothing was collected (no child matched).
    // Note: the same preds list is handed to every recursive call, so the list
    // sorted and read at any level also holds the values gathered in branches
    // visited earlier, not only those below nd.
    // Note: the error value here is +1000.00, while predictLabel uses -1000.00.
    // NOTE(review): no caller is visible in this file except a commented-out
    // call in predict - confirm whether this method is still needed.
    private double predictMinLabel(double[] row, Node nd, ArrayList<Double> preds){
	if(nd.subTree == null) return nd.prediction;
	if(preds == null)
	    preds = new ArrayList<Double>();
	for(Node child : nd.subTree){
	    double pred = 1000.0;// placeholder; overwritten before use when the child matches
	    if(matchesNodeCriteria(child, row[child.attributeIDx])){
		  pred = predictMinLabel(row, child, preds);//recurse on child
		  preds.add(pred);// the result of the recursion is recorded as well
	    }
	}
	// preds cannot be null at this point; only the emptiness check can trigger
	if((preds == null) || (preds.size() == 0))
	    return 1000.00;// return error value for unseen data
	Collections.sort(preds);// ascending, so the minimum ends up at index 0
	//Log.write("Predictions for current point: "+ preds.size());
	Log.writeList(preds);
	return preds.get(0); // return the minimum prediction
    }

    // Computes the mean squared error of a prediction table.
    // predictions : one line per record; column 0 and column 1 hold the two
    //               values to compare (actual and predicted, as produced by
    //               predict)
    // Returns the average of (column0 - column1)^2 over all lines,
    // or 0.0 for an empty table.
    public double computeMSE(double[][] predictions){
        final int rowCount = predictions.length;
        if(rowCount == 0)
            return 0.0; // no data, no error
        double squaredErrorTotal = 0.0;
        for(double[] pair : predictions){
            double residual = pair[0] - pair[1];
            squaredErrorTotal += residual * residual;
        }
        return squaredErrorTotal / rowCount;
    }
    //------------------------------------------------------------------------
    // sub classes

// One node of the regression tree. A node describes the test that leads to it
// (attribute, comparison symbol, value), the subset of rows that reach it and
// the prediction/error computed on that subset. A node without subTree is a leaf.
class Node{

    public double[][] dataset; // pointer to whole dataset
    public ArrayList<Integer> dataPointIDXs; // subset of data points at node
    public double RMSE; // error of the node as stored by setNodeRMSE
    public double prediction; // value predicted for rows ending in this node
    public int depth;//depth needs to be a node's property
    public int attributeIDx;//-1 for root node
    public double attributeValue; // value the attribute is compared against
    public String symbol; //could be one of le(<=), l(<), g(>), ge(>=) or eq(=)
    public ArrayList<Node> subTree; // children; null for a leaf

    // Prints this node and, recursively, all nodes below it to standard output.
    // Each node takes one line, indented by one "|<tab>" per level of depth:
    //   ( X[attr] <op> value ; MSE = error ; Count = points [; Label = prediction] )
    // The Label part is printed for leaves only.
    // Invariant: Assumes dataPointIDXs is set
    public void printNodeAndSubTree(){
        StringBuilder indentation = new StringBuilder();
        for(int level = 0; level < depth; level++){
            indentation.append("|\t");
        }

        String nodeText = "( X[" + attributeIDx + "]";
        nodeText = appendSymbol(symbol, nodeText);
        nodeText = nodeText + attributeValue + " ; MSE = " + RMSE;
        nodeText = nodeText + " ; Count = " + dataPointIDXs.size();
        if(subTree == null){
            nodeText = nodeText + " ; Label = " + prediction;
        }
        nodeText = nodeText + " )";
        System.out.println(indentation.toString() + nodeText);

        if(subTree != null){
            for(Node nd : subTree){
                nd.printNodeAndSubTree();
            }
        }
    }

    // Appends the printable form of a comparison symbol to text.
    // Unknown symbols and null (the root has no symbol) append nothing.
    private String appendSymbol(String symbol, String text){
        // equals() instead of ==: == compares object identity and only works
        // while every symbol is an interned literal. Calling equals on the
        // literal keeps a null symbol harmless.
        if("le".equals(symbol))
            return text + " <= ";
        if("l".equals(symbol))
            return text + " < ";
        if("g".equals(symbol))
            return text + " > ";
        if("ge".equals(symbol))
            return text + " >= ";
        if("eq".equals(symbol))
            return text + " = ";
        return text; //don't append if anything else
    }
}

}

