import java.util.ArrayList;

/**
 * A fully connected feed-forward neural network made of an input layer, zero
 * or more hidden layers and an output layer.
 *
 * <p>Non-input nodes use the logistic (sigmoid) activation. Training is done
 * by {@link #learnNetwork()} with incremental (per-instance) gradient
 * descent: after every training instance the error is back-propagated and
 * all weights are updated.
 *
 * <p>This class keeps mutable training state and performs no
 * synchronization.
 */
public class NeuralNetwork {
    /** Input layer nodes; the array slots are filled by {@link #learnNetwork()}. */
    public Node[] ipLayer;
    /** Output layer nodes; the array slots are filled by {@link #learnNetwork()}. */
    public Node[] opLayer;

    /** Hidden layers ordered from the one next to the input layer to the one next to the output layer. */
    private ArrayList<Node[]> hiddenLayers;
    /** Training instances supplied at construction time. */
    private Instance[] instances;
    /** Learning rate applied to every weight update. */
    private double learnRate;
    /** Currently unused. */
    private double[] learnRates;
    /**
     * One slot per training instance holding how many output nodes were
     * within tolerance the last time that instance was applied. Re-allocated
     * by {@link #learnNetwork()} to match the number of instances.
     */
    private int[] convergenceVector = new int[0];

    /** Weight given to every edge when random initialization is disabled. */
    private static final double INIT_WEIGHT = 0.0;
    private static final double INIT_LEARN_RATE = 1.2;
    private static final boolean USE_RANDOM_WEIGHTS = true;
    /** Currently unused. */
    private static final boolean USE_LEARN_VECTOR = false;
    /** Upper bound (exclusive) of randomly initialized weights. */
    private static final double WEIGHT_MAX = 1.0;
    /** Lower bound (inclusive) of randomly initialized weights. */
    private static final double WEIGHT_MIN = -1.0;
    /** Factor applied to the sum of squared output errors. */
    private static final double ERROR_CONST = 0.5;
    /** Tolerance used for output rounding in logs and for the convergence test. */
    private static final double ROUND_OFF_CONST = 0.01;
    /** Maximum number of passes over the training instances. */
    private static final int MAX_ITERATIONS = 5;
    /** Currently unused. */
    private static final double ERROR_TARGET = 0.0;

    /**
     * Creates the network skeleton. Nodes and edges are only created once
     * {@link #learnNetwork()} is called.
     *
     * @param ipLayerSize      number of input nodes
     * @param opLayerSize      number of output nodes
     * @param hiddenLayerSizes size of each hidden layer, ordered from the
     *                         input side to the output side; {@code null} is
     *                         treated as "no hidden layers"
     * @param dataInstances    training instances; each one is read through
     *                         its {@code ipVector} and {@code opVector}
     */
    public NeuralNetwork(int ipLayerSize, int opLayerSize,
                         int[] hiddenLayerSizes,
                         Instance[] dataInstances) {
        ipLayer = new Node[ipLayerSize];
        opLayer = new Node[opLayerSize];
        hiddenLayers = new ArrayList<Node[]>();
        instances = dataInstances;
        if (hiddenLayerSizes != null) {
            for (int size : hiddenLayerSizes) {
                hiddenLayers.add(new Node[size]);
            }
        }
    }

    /**
     * (Re)initializes the network and trains it on the instances given to
     * the constructor.
     *
     * <p>One iteration is a full pass over all instances. Training stops as
     * soon as every instance had all of its output nodes within
     * {@code ROUND_OFF_CONST} of their targets during the last pass, or after
     * {@code MAX_ITERATIONS} passes, whichever comes first. Progress is
     * written through {@code Log.write}.
     */
    public void learnNetwork() {
        initializeNetwork();
        // One convergence slot per training instance. A fixed-size vector
        // would overflow with more instances and could never report
        // convergence with fewer.
        convergenceVector = new int[instances.length];

        int iterationCount = 0;
        boolean converged;
        do {
            iterationCount++;
            Log.write("Iteration: " + iterationCount);
            int instanceCounter = 0;
            // incremental gradient descent: update after every instance
            for (Instance dataPoint : instances) {
                applyNetwork(dataPoint);
                // Called for its side effect: stores the error on each
                // output node, which the gradient computation reads.
                computeError();
                printOutputs();
                convergenceVector[instanceCounter] = getConvergenceCount();
                computeGradients();
                updateWeights();
                instanceCounter++;
            }
            updateLearnRate(); // update learn rate after each iteration
            printNetwork();
            converged = checkConvergence();
        } while (!converged && iterationCount < MAX_ITERATIONS);
    }

    /**
     * Creates all nodes and edges. The order matters: a layer's downstream
     * nodes must already exist when its edges are created, so layers are
     * built from the output side towards the input side.
     */
    private void initializeNetwork() {
        initializeOutputLayer();
        initializeHiddenLayers();
        initializeInputLayer();
        initializeLearnRate();
    }

    /** Loads one instance into the network and runs the forward pass. */
    private void applyNetwork(Instance dataPoint) {
        setIPLayerInputs(dataPoint.ipVector);
        setOPLayerTargets(dataPoint.opVector);
        computeHiddenLayerOutputs();
        computeOpLayerOutputs();
    }

    /**
     * Stores {@code target - actual} on every output node and returns
     * {@code ERROR_CONST} times the sum of the squared errors.
     */
    private double computeError() {
        double squaredErrorSum = 0.0;
        for (Node nd : opLayer) {
            nd.error = nd.targetOutput - nd.actualOutput;
            squaredErrorSum += (nd.error * nd.error);
        }
        return (ERROR_CONST * squaredErrorSum);
    }

    /**
     * Logs the current outputs; values within {@code ROUND_OFF_CONST} of 1.0
     * or 0.0 are printed as exactly "1.0" or "0.0".
     */
    private void printOutputs() {
        StringBuilder outputs = new StringBuilder("( ");
        for (Node nd : opLayer) {
            if (nd.actualOutput > (1.0 - ROUND_OFF_CONST)) {
                outputs.append("1.0, ");
            } else if (nd.actualOutput < (0.0 + ROUND_OFF_CONST)) {
                outputs.append("0.0, ");
            } else {
                outputs.append(nd.actualOutput).append(", ");
            }
        }
        Log.write("O/P s: " + outputs);
    }

    /**
     * Counts the output nodes whose output is within {@code ROUND_OFF_CONST}
     * of the target, and logs that count.
     */
    private int getConvergenceCount() {
        int convergenceCounter = 0;
        for (Node nd : opLayer) {
            // Absolute difference: an output that overshoots its target must
            // not be counted as converged.
            if (Math.abs(nd.targetOutput - nd.actualOutput) < ROUND_OFF_CONST) {
                convergenceCounter++;
            }
        }
        Log.write("Nodes converged " + convergenceCounter);
        return convergenceCounter;
    }

    /**
     * Returns {@code true} only if, for every instance, all output nodes were
     * counted as converged during the last pass.
     */
    private boolean checkConvergence() {
        for (int count : convergenceVector) {
            if (count != opLayer.length) {
                return false; // even one non-converged instance is enough
            }
        }
        return true;
    }

    /** Back-propagates gradients: output layer first, then hidden layers. */
    private void computeGradients() {
        computeOpLayerGradients();
        computeHiddenLayersGradients();
    }

    /** Applies the weight update to every edge feeding a non-input node. */
    private void updateWeights() {
        updateOpLayerWeights();
        updateHiddenLayerWeights();
    }

    private void updateOpLayerWeights() {
        updateLayerWeights(opLayer);
    }

    private void updateHiddenLayerWeights() {
        for (Node[] layer : hiddenLayers) {
            updateLayerWeights(layer);
        }
    }

    private void updateLayerWeights(Node[] layer) {
        for (Node nd : layer) {
            updateInputsWeights(nd);
        }
    }

    /**
     * Adds {@code learnRate * gradient(nd) * activation(upstream node)} to
     * the weight of every edge that feeds {@code nd}.
     */
    private void updateInputsWeights(Node nd) {
        for (Edge e : nd.inputs) {
            double weightDelta = (learnRate * nd.gradient * activationOf(e.head));
            e.weight += weightDelta;
        }
    }

    /**
     * Returns the value a node feeds to its downstream nodes: the raw input
     * value for an input node, the logistic output for any other node.
     */
    private double activationOf(Node nd) {
        return (nd.type == 'i') ? nd.inputValue : nd.actualOutput;
    }

    /** Output-layer gradient = error * output * (1 - output). */
    private void computeOpLayerGradients() {
        for (Node nd : opLayer) {
            nd.gradient = nd.error * nd.actualOutput * (1 - nd.actualOutput);
        }
    }

    /**
     * Computes hidden-layer gradients starting from the hidden layer closest
     * to the output and moving towards the input, so that the downstream
     * gradients each layer needs are already available.
     */
    private void computeHiddenLayersGradients() {
        for (int idx = hiddenLayers.size() - 1; idx >= 0; idx--) {
            computeHiddenLayerGradients(hiddenLayers.get(idx));
        }
    }

    /**
     * Hidden-node gradient = output * (1 - output) *
     * (weighted sum of the gradients of all downstream nodes).
     */
    private void computeHiddenLayerGradients(Node[] layer) {
        for (Node nd : layer) {
            nd.gradient = nd.actualOutput
                    * (1 - nd.actualOutput)
                    * getWeightedDownstreamGradientSum(nd);
        }
    }

    /** Sums {@code gradient(downstream node) * edge weight} over all outgoing edges. */
    private double getWeightedDownstreamGradientSum(Node nd) {
        double weightedSum = 0.0;
        for (Edge e : nd.downStream) {
            weightedSum += (e.tail.gradient * e.weight);
        }
        return weightedSum;
    }

    /** Copies the instance's input vector onto the input nodes. */
    private void setIPLayerInputs(double[] ipVector) {
        for (int idx = 0; idx < ipLayer.length; idx++) {
            ipLayer[idx].inputValue = ipVector[idx];
        }
    }

    /** Copies the instance's target vector onto the output nodes. */
    private void setOPLayerTargets(double[] targets) {
        for (int idx = 0; idx < opLayer.length; idx++) {
            opLayer[idx].targetOutput = targets[idx];
        }
    }

    /** Forward pass through the hidden layers, from the input side to the output side. */
    private void computeHiddenLayerOutputs() {
        for (Node[] nodes : hiddenLayers) {
            computeLayerOutput(nodes);
        }
    }

    private void computeOpLayerOutputs() {
        computeLayerOutput(opLayer);
    }

    /**
     * For every node of the layer, stores the weighted sum of its inputs in
     * {@code inputValue} and the logistic transformation of that sum in
     * {@code actualOutput}.
     */
    private void computeLayerOutput(Node[] nodes) {
        for (Node nd : nodes) {
            double weightedSum = getWeightedInputsSum(nd);
            nd.inputValue = weightedSum;
            nd.actualOutput = applyLogisticTransformation(weightedSum);
        }
    }

    /**
     * Dot product of the incoming edge weights with the activations of the
     * upstream nodes (see {@link #activationOf(Node)}).
     */
    private double getWeightedInputsSum(Node nd) {
        double weightedSum = 0.0;
        for (Edge input : nd.inputs) {
            weightedSum += (activationOf(input.head) * input.weight);
        }
        return weightedSum;
    }

    /**
     * Creates the input nodes and connects each one to every node of the
     * first hidden layer, or of the output layer when there is no hidden
     * layer.
     */
    private void initializeInputLayer() {
        Node[] downStreamNodes = opLayer; // default downstream layer
        if (hiddenLayers.size() > 0) {
            downStreamNodes = hiddenLayers.get(0);
        }
        for (int idx = 0; idx < ipLayer.length; idx++) {
            Node nd = new Node();
            nd.type = 'i';
            nd.level = 0;
            nd.id = idx;
            nd.downStream = getDownStreamEdges(downStreamNodes, nd);
            ipLayer[idx] = nd;
        }
    }

    /**
     * Creates the hidden nodes, working from the layer next to the output
     * back towards the input so that each layer's downstream nodes already
     * exist when its edges are created.
     */
    private void initializeHiddenLayers() {
        int lastIndex = hiddenLayers.size() - 1;
        for (int hlIndex = lastIndex; hlIndex >= 0; hlIndex--) {
            Node[] hiddenLayer = hiddenLayers.get(hlIndex);
            Node[] downStreamNodes = opLayer; // default downstream layer
            if (hlIndex != lastIndex) {
                downStreamNodes = hiddenLayers.get(hlIndex + 1);
            }
            for (int idx = 0; idx < hiddenLayer.length; idx++) {
                Node nd = new Node();
                nd.type = 'h';
                nd.level = hlIndex + 1; // +1 because of input layer
                nd.id = idx;
                nd.downStream = getDownStreamEdges(downStreamNodes, nd);
                hiddenLayer[idx] = nd;
            }
        }
    }

    /** Creates the output nodes; they have no downstream edges. */
    private void initializeOutputLayer() {
        int opLayerLevel = hiddenLayers.size() + 1; // hidden + input layers
        for (int idx = 0; idx < opLayer.length; idx++) {
            Node nd = new Node();
            nd.type = 'o';
            nd.id = idx;
            nd.level = opLayerLevel;
            opLayer[idx] = nd;
        }
    }

    private void initializeLearnRate() {
        learnRate = INIT_LEARN_RATE;
    }

    /**
     * Creates one edge from {@code head} to every node in {@code nodes},
     * registers each edge in the target node's {@code inputs} list and
     * returns the edges.
     */
    private Edge[] getDownStreamEdges(Node[] nodes, Node head) {
        Edge[] edges = new Edge[nodes.length];
        for (int index = 0; index < edges.length; index++) {
            Edge ed = new Edge();
            ed.head = head;
            ed.tail = nodes[index];
            ed.weight = getWeight();
            edges[index] = ed;
            ed.tail.inputs.add(ed);
        }
        return edges;
    }

    /**
     * Returns an initial edge weight: a random value in
     * [{@code WEIGHT_MIN}, {@code WEIGHT_MAX}) when random weights are
     * enabled, {@code INIT_WEIGHT} otherwise.
     */
    private double getWeight() {
        if (USE_RANDOM_WEIGHTS) {
            return WEIGHT_MIN + (Math.random() * (WEIGHT_MAX - WEIGHT_MIN));
        } else {
            return INIT_WEIGHT;
        }
    }

    /** Called after each iteration; currently keeps the learn rate constant. */
    private void updateLearnRate() {
        learnRate = INIT_LEARN_RATE;
    }

    /** Returns 1 / (1 + e^(-input)). */
    private double applyLogisticTransformation(double input) {
        return 1 / (1 + Math.exp(-input));
    }

    /**
     * Logs every layer, from input to output, one node per line. The nodes
     * must have been created by {@link #learnNetwork()} first.
     */
    public void printNetwork() {
        printLayer(ipLayer);
        for (Node[] hl : hiddenLayers) {
            printLayer(hl);
        }
        printLayer(opLayer);
    }

    /**
     * Logs one line per node of the layer: type, level and id; for non-input
     * nodes also output, gradient and error; for non-output nodes also the
     * weights of the outgoing edges. A separator line follows the layer.
     */
    private void printLayer(Node[] layer) {
        for (Node nd : layer) {
            StringBuilder layerText = new StringBuilder();
            layerText.append("( ").append(nd.type).append("-")
                    .append(nd.level).append(nd.id);
            if (nd.type != 'i') {
                layerText.append(" ,a:").append(nd.actualOutput);
                layerText.append(" ,g:").append(nd.gradient);
                layerText.append(" ,e:").append(nd.error);
            }
            if (nd.type != 'o') {
                for (Edge e : nd.downStream) {
                    layerText.append(" ,w").append(e.head.id).append(e.tail.id)
                            .append(":").append(e.weight);
                }
            }
            layerText.append(" )\t");
            Log.write(layerText.toString());
        }
        Log.write("---------------------------------------------------------");
        Log.write("");
    }

    /** A node of the network. */
    class Node {

        /** One of 'i', 'h', 'o' for input, hidden, output. */
        public char type;
        /** Level of the node in the network; the input layer is level 0. */
        public int level;
        /** Index of the node within its layer, starting from 0. */
        public int id;
        /** Edges arriving at this node from the previous layer. */
        public ArrayList<Edge> inputs = new ArrayList<Edge>();
        /** Edges leaving this node towards the nodes that receive its output. */
        public Edge[] downStream;
        /**
         * For input nodes, the value taken from the instance's input vector;
         * for other nodes, the weighted sum of their inputs.
         */
        public double inputValue;
        /** Logistic output of the node (not set for input nodes). */
        public double actualOutput;
        /** Target value; set on output nodes only. */
        public double targetOutput;
        /** Back-propagated gradient of the node. */
        public double gradient;
        /** {@code targetOutput - actualOutput}; set on output nodes only. */
        public double error;
    }

    /** A weighted, directed connection from {@code head} to {@code tail}. */
    class Edge {
        /** Upstream (source) node. */
        public Node head;
        /** Downstream (destination) node. */
        public Node tail;
        public double weight;
    }


}

