import java.util.*;
import Jama.*;

/**
 * Command-line driver for the regression experiments.
 *
 * <p>The first program argument selects the dataset: {@code 0} runs the
 * housing experiment (fit on the training file, report MSE on the training
 * and test files); any other integer runs a 10-fold cross-validation on the
 * spambase file and reports per-fold MSE / ROC statistics plus the average
 * train and test MSE.
 */
public class EntryPoint{

    //required objects
    static DataReader reader = null;

    //Housing dataset constants
    static final String HOUSING_DATASET_PATH = "/home/rahul/data/src/ml/dataset/hw1/housing/housing_train.txt";
    static final String HOUSING_TESTSET_PATH = "/home/rahul/data/src/ml/dataset/hw1/housing/housing_test.txt";
    static final int HOUSING_ROW_COUNT = 433;
    static final int HOUSING_TEST_ROW_COUNT = 74;
    static final int HOUSING_COL_COUNT = 14;
    static final String HOUSING_DATA_TYPE = "SSV";
    static final int[] HOUSING_NUMERIC_COL_IDXS = {0,1,2,4,5,6,7,8,9,10,11,12};// 0 based
    static final int[] HOUSING_NOMINAL_COL_IDXS = {3};// 0 based
    static final int HOUSING_LABEL_COL_IDX = 13;// 0 based

    //Spambase dataset constants
    static final String SPAMBASE_DATASET_PATH = "/home/rahul/data/src/ml/dataset/hw1/spam/spambase.data";
    static final int SPAMBASE_ROW_COUNT = 4601;
    static final int SPAMBASE_COL_COUNT = 58;
    static final String SPAMBASE_DATA_TYPE = "CSV";
    static final int[] SPAMBASE_NUMERIC_COL_IDXS = {0,1,2,3,4,5,6,7,8,9,10,
            11,12,13,14,15,16,17,18,19,20,
            21,22,23,24,25,26,27,28,29,30,
            31,32,33,34,35,36,37,38,39,40,
            41,42,43,44,45,46,47,48,49,50,
            51,52,53,54,55,56};// 0 based
    static final int[] SPAMBASE_NOMINAL_COL_IDXS = {};// 0 based
    static final int SPAMBASE_LABEL_COL_IDX = 57;// 0 based

    /** Number of contiguous folds used for the spambase cross-validation. */
    private static final int SPAMBASE_FOLD_COUNT = 10;

    /**
     * Program entry point.
     *
     * @param args {@code args[0]} is the dataset selector: {@code 0} for the
     *             housing dataset, any other integer for spambase
     * @throws Exception any failure is printed to standard output and then
     *                   rethrown unchanged
     */
    public static void main(String[] args) throws Exception {
        try {
            if (args.length == 0) {
                // Fail with an actionable message instead of a bare
                // ArrayIndexOutOfBoundsException.
                throw new IllegalArgumentException(
                        "Missing dataset selector argument: 0 = housing, any other integer = spambase");
            }
            int dbSelector = Integer.parseInt(args[0]);
            boolean useHousing = (dbSelector == 0);

            reader = new DataReader();
            System.out.println("");

            double[][] data;
            if (useHousing) {
                data = reader.readFile(HOUSING_DATASET_PATH,
                        HOUSING_DATA_TYPE,
                        HOUSING_ROW_COUNT,
                        HOUSING_COL_COUNT);
            } else {
                data = reader.readFile(SPAMBASE_DATASET_PATH,
                        SPAMBASE_DATA_TYPE,
                        SPAMBASE_ROW_COUNT,
                        SPAMBASE_COL_COUNT);
            }

            DataProcessor processor = new DataProcessor();
            // NOTE: numeric columns are not normalized before fitting.

            if (useHousing) {
                evaluateHousing(processor, data);
            } else {
                crossValidateSpambase(processor, data);
            }

            System.out.println("");
            System.out.println("");
        } catch(Exception e){
            System.out.println("Runtime Error occurred: ");
            System.out.println(e);
            throw e;
        }
    }

    /**
     * Fits a weight vector on the housing training rows and logs the MSE on
     * both the training rows and the separately loaded housing test file.
     *
     * @param processor helper used to build matrices, fit and score
     * @param data      housing training rows, label in column
     *                  {@link #HOUSING_LABEL_COL_IDX}
     * @throws Exception if reading the test file or any processing step fails
     */
    private static void evaluateHousing(DataProcessor processor,
                                        double[][] data) throws Exception {
        Matrix X = processor.getFeatureMatrix(data, HOUSING_LABEL_COL_IDX);
        Matrix Y = processor.getLabelVector(data, HOUSING_LABEL_COL_IDX);
        Matrix W = processor.computeWeightVector(X, Y);

        double[][] preds = processor.computePrediction(X, W);
        double trainMSE = processor.computeMSE(preds,
                processor.getLabelArray(data, HOUSING_LABEL_COL_IDX));
        Log.write("MSE for training set is: "+ trainMSE);

        double[][] testdata = reader.readFile(HOUSING_TESTSET_PATH,
                HOUSING_DATA_TYPE,
                HOUSING_TEST_ROW_COUNT,
                HOUSING_COL_COUNT);

        // The test rows are scored with the weights learned on the training rows.
        Matrix XTest = processor.getFeatureMatrix(testdata, HOUSING_LABEL_COL_IDX);
        double[][] predsTest = processor.computePrediction(XTest, W);
        double testMSE = processor.computeMSE(predsTest,
                processor.getLabelArray(testdata, HOUSING_LABEL_COL_IDX));
        Log.write("MSE for test set is: "+ testMSE);
    }

    /**
     * Runs a {@link #SPAMBASE_FOLD_COUNT}-fold cross-validation over contiguous
     * slices of {@code data}. Each fold holds {@code data.length / foldCount}
     * rows; the last fold additionally absorbs the remainder. For every fold
     * the train/test MSE and ROC statistics are reported, followed by the
     * average train and test MSE over all folds.
     *
     * @param processor helper used to split, build matrices, fit and score
     * @param data      spambase rows, label in column
     *                  {@link #SPAMBASE_LABEL_COL_IDX}
     * @throws Exception if any processing step fails
     */
    private static void crossValidateSpambase(DataProcessor processor,
                                              double[][] data) throws Exception {
        int foldCount = SPAMBASE_FOLD_COUNT;
        int foldRecordCount = data.length / foldCount;
        double testMSESum = 0.0;
        double trainMSESum = 0.0;

        for(int k = 0; k < foldCount; k++){
            int testStartIDx = foldRecordCount * k;
            // End index is passed to splitTrainTest as the first row AFTER the
            // test slice (presumably exclusive - the non-final folds rely on it).
            int testEndIDx = (k == foldCount - 1)
                    ? data.length
                    : testStartIDx + foldRecordCount;

            Log.write("");
            Log.write("Test set Start Index: "+ testStartIDx);
            // Always log the inclusive index of the last test row.
            Log.write("Test set End Index: "+ (testEndIDx - 1));

            List<double[][]> folds = processor.splitTrainTest(data, testStartIDx, testEndIDx);
            double[][] trainset = folds.get(0);
            double[][] testset = folds.get(1);

            Matrix X = processor.getFeatureMatrix(trainset, SPAMBASE_LABEL_COL_IDX);
            Matrix Y = processor.getLabelVector(trainset, SPAMBASE_LABEL_COL_IDX);
            Matrix W = processor.computeWeightVector(X, Y);

            // Training predictions must be compared with the labels of the
            // training rows only; using the labels of the full dataset pairs
            // each prediction with the label of a different record.
            double[][] trainpreds = processor.computePrediction(X, W);
            double trainMSE = processor.computeMSE(trainpreds,
                    processor.getLabelArray(trainset, SPAMBASE_LABEL_COL_IDX));
            Log.write("MSE for training set is: "+ trainMSE);

            double[][] trainPredsForROC = transformPredsForROC(trainpreds,
                    processor.getLabelArray(trainset, SPAMBASE_LABEL_COL_IDX));
            PredictorStatsCalculator.computeROCStats
                    (trainPredsForROC,
                            null,
                            new int[]{trainpreds.length},
                            false, true, false);

            Matrix XTest = processor.getFeatureMatrix(testset, SPAMBASE_LABEL_COL_IDX);
            double[][] predsTest = processor.computePrediction(XTest, W);
            double testMSE = processor.computeMSE(predsTest,
                    processor.getLabelArray(testset, SPAMBASE_LABEL_COL_IDX));
            Log.write("MSE for test set is: "+ testMSE);

            double[][] testPredsForROC = transformPredsForROC(predsTest,
                    processor.getLabelArray(testset, SPAMBASE_LABEL_COL_IDX));
            PredictorStatsCalculator.computeROCStats
                    (testPredsForROC,
                            null,
                            new int[]{predsTest.length},
                            false, true, false);

            testMSESum += testMSE;
            trainMSESum += trainMSE;
        }

        Log.write("");
        Log.write("");
        Log.write("Average test MSE of "+ foldCount + " runs is = "+ (testMSESum/ foldCount));
        Log.write("Average train MSE of "+ foldCount + " runs is = "+ (trainMSESum/ foldCount));
    }

    /**
     * Pairs every prediction with its label in the two-column layout handed
     * to {@code PredictorStatsCalculator.computeROCStats}.
     *
     * @param preds      predictions, one row per record; only column 0 is read
     * @param labelArray labels, indexed like the rows of {@code preds}
     * @return a {@code preds.length x 2} array whose column 0 holds the label
     *         and column 1 the prediction of the same row
     */
    private static double[][] transformPredsForROC(double[][] preds,
                                                   double[] labelArray){
        double[][] output = new double[preds.length][2];
        for(int rowCounter = 0; rowCounter < preds.length; rowCounter++){
            output[rowCounter][0] = labelArray[rowCounter];
            output[rowCounter][1] = preds[rowCounter][0];
        }
        return output;
    }

}

