import java.util.Arrays;

/**
 * Computes ROC-curve data and related statistics for a binary predictor.
 *
 * <p>Predictions are supplied as {@code {actualValue, predictedValue}} rows. A
 * prediction is labelled positive when its predicted value is greater than or
 * equal to a threshold. Actual values are expected to be exactly {@code 1.0}
 * (positive) or {@code 0.0} (negative); other actual values are not counted.
 * All members are static.
 */
public class PredictorStatsCalculator {

    /** Maximum number of thresholds produced when thresholds are auto-generated. */
    private static final int THRESHOLD_COUNT = 100;
    /** Number of sub-sample sizes produced when sizes are auto-generated. */
    private static final int MIN_SUB_SAMPLES_COUNT = 50;

    /**
     * Computes the true-positive rate (TPR) and false-positive rate (FPR) for
     * every combination of threshold and sub-sample size. The results are
     * printed to standard output: a header row, then one comma-separated row
     * per combination.
     *
     * <p>Threshold selection works as follows:
     * <ul>
     *   <li>If {@code allThresholds} is true, one threshold is used per
     *       prediction value. This takes precedence over {@code autoThreshold}.</li>
     *   <li>Otherwise, if {@code autoThreshold} is true, at most
     *       {@code THRESHOLD_COUNT} thresholds are sampled at evenly spaced
     *       positions of the sorted predictions.</li>
     *   <li>Otherwise {@code thresholds} is used as given.</li>
     * </ul>
     *
     * @param predictions        rows of {@code {actualValue, predictedValue}};
     *                           every row must have at least two elements
     * @param thresholds         thresholds to evaluate; ignored (and may be
     *                           {@code null}) when {@code autoThreshold} or
     *                           {@code allThresholds} is true
     * @param subSampleSizes     sub-sample sizes to evaluate; ignored (and may be
     *                           {@code null}) when {@code autoSizeSubSamples} is
     *                           true. A size larger than the number of
     *                           predictions is evaluated on all predictions.
     * @param autoThreshold      generate a bounded set of thresholds from the data
     * @param allThresholds      generate one threshold per prediction value
     * @param autoSizeSubSamples generate the sub-sample sizes from the data
     * @throws NullPointerException if {@code predictions} (or one of its rows)
     *                              is null, or if an array argument that is
     *                              not auto-generated is null
     */
    public static void computeROCStats(double[][] predictions,
                                       double[] thresholds,
                                       int[] subSampleSizes,
                                       boolean autoThreshold,
                                       boolean allThresholds,
                                       boolean autoSizeSubSamples) {
        Prediction[] preds = doubleToPredictions(predictions);
        // generateThresholds relies on this ordering. It also decides which
        // predictions fall into each sub-sample.
        sortPredictions(preds);
        if (autoThreshold) {
            thresholds = generateThresholds(preds);
        }
        // Applied after autoThreshold on purpose so that it takes precedence.
        if (allThresholds) {
            thresholds = generateAllThresholds(preds);
        }
        if (autoSizeSubSamples) {
            subSampleSizes = generateSubSampleSizes(preds.length);
        }

        // A StringBuilder avoids quadratic copying. With allThresholds there
        // is one row per prediction per sub-sample size.
        // NOTE(review): the header separators ("\t,") differ from the row
        // separators (", "). They are left unchanged so that existing
        // consumers of this output are unaffected.
        StringBuilder output = new StringBuilder();
        output.append("Threshold\t,K\t,TPR\t,FPR\n");
        for (double threshold : thresholds) {
            for (int sampleSize : subSampleSizes) {
                ConfusionMatrix cMat =
                        computeConfusionMatrix(preds, threshold, sampleSize);
                double tpRate = computeTruePositiveRate(cMat);
                double fpRate = computeFalsePositiveRate(cMat);
                output.append(threshold).append(", ")
                        .append(sampleSize).append(", ")
                        .append(tpRate).append(", ")
                        .append(fpRate).append('\n');
            }
        }
        System.out.println(output);
        System.out.println("");
    }

    /**
     * Converts raw prediction rows into {@link Prediction} objects.
     *
     * @param predictions rows where index 0 holds the actual value and index 1
     *                    holds the predicted value; extra columns are ignored
     * @return a new array with one {@code Prediction} per row, in input order
     */
    private static Prediction[] doubleToPredictions(double[][] predictions) {
        Prediction[] output = new Prediction[predictions.length];
        for (int idx = 0; idx < predictions.length; idx++) {
            output[idx] = new Prediction();
            output[idx].actualValue = predictions[idx][0];
            output[idx].predictedValue = predictions[idx][1];
        }
        return output;
    }

    /**
     * Sorts the predictions in place using their natural ordering (ascending
     * by predicted value).
     */
    private static void sortPredictions(Prediction[] predictions) {
        Arrays.sort(predictions);
    }

    /**
     * Samples at most {@code THRESHOLD_COUNT} thresholds from the predicted
     * values.
     *
     * <p>Thresholds are taken at evenly spaced indices starting at 0. When
     * {@code preds} is sorted ascending, the result is therefore non-decreasing.
     *
     * @param preds predictions, expected to be sorted ascending by predicted value
     * @return the sampled thresholds; empty when {@code preds} is empty
     */
    private static double[] generateThresholds(Prediction[] preds) {
        int actualThresholdCount = Math.min(THRESHOLD_COUNT, preds.length);
        double[] thresholds = new double[actualThresholdCount];
        if (actualThresholdCount == 0) {
            // No data: return early instead of dividing by zero below.
            return thresholds;
        }
        int thresholdInterval = preds.length / actualThresholdCount;
        for (int idx = 0; idx < actualThresholdCount; idx++) {
            // Largest index used is (count - 1) * (length / count) < length.
            thresholds[idx] = preds[idx * thresholdInterval].predictedValue;
        }
        return thresholds;
    }

    /**
     * Uses every predicted value as a threshold, in the order of {@code preds}.
     * Duplicate predicted values produce duplicate thresholds.
     *
     * @param preds predictions to take thresholds from
     * @return one threshold per prediction
     */
    private static double[] generateAllThresholds(Prediction[] preds) {
        double[] thresholds = new double[preds.length];
        for (int idx = 0; idx < thresholds.length; idx++) {
            thresholds[idx] = preds[idx].predictedValue;
        }
        return thresholds;
    }

    /**
     * Generates {@code MIN_SUB_SAMPLES_COUNT} sub-sample sizes for a data set
     * of {@code recordCount} records.
     *
     * <p>The sizes are consecutive integers starting at
     * {@code max(1, recordCount / MIN_SUB_SAMPLES_COUNT)}, each capped at
     * {@code recordCount}. All sizes are 0 when {@code recordCount} is 0.
     *
     * <p>NOTE(review): the sizes grow by 1 rather than by the initial step, so
     * for large data sets they cover only a narrow range. Confirm this is
     * intended.
     *
     * @param recordCount number of available predictions
     * @return the generated sub-sample sizes
     */
    private static int[] generateSubSampleSizes(int recordCount) {
        int subSampleSize = recordCount / MIN_SUB_SAMPLES_COUNT;
        int[] size = new int[MIN_SUB_SAMPLES_COUNT];
        for (int counter = 0; counter < MIN_SUB_SAMPLES_COUNT; counter++) {
            if (subSampleSize < 1) {
                subSampleSize = 1; // minimum size of a sample
            }
            if (subSampleSize > recordCount) {
                subSampleSize = recordCount;
            }
            size[counter] = subSampleSize;
            subSampleSize++;
        }
        return size;
    }

    /**
     * Builds a confusion matrix from the first {@code sampleSize} predictions.
     *
     * <p>A prediction is labelled positive (1) when its predicted value is
     * greater than or equal to {@code threshold}, and negative (0) otherwise.
     * Predictions whose actual value is neither exactly 1.0 nor exactly 0.0 are
     * not counted in any cell.
     *
     * <p>When called from {@code computeROCStats}, the array is sorted ascending
     * by predicted value. The sub-sample is therefore the {@code sampleSize}
     * lowest-scored predictions. NOTE(review): confirm this, rather than the
     * highest-scored predictions, is intended.
     *
     * @param predictions predictions to evaluate
     * @param threshold   decision threshold
     * @param sampleSize  number of leading predictions to use. Values larger
     *                    than {@code predictions.length} are clamped; values
     *                    below 1 yield an all-zero matrix.
     * @return the confusion-matrix counts
     */
    private static ConfusionMatrix computeConfusionMatrix
    (Prediction[] predictions, double threshold, int sampleSize) {
        ConfusionMatrix output = new ConfusionMatrix();
        // Clamp so that caller-supplied sizes beyond the data cannot overrun the array.
        int limit = Math.min(sampleSize, predictions.length);
        for (int idx = 0; idx < limit; idx++) {
            boolean predictedPositive = predictions[idx].predictedValue >= threshold;
            double actual = predictions[idx].actualValue;
            // At most one cell is incremented per prediction.
            if (actual == 1.0) {
                if (predictedPositive) {
                    output.TPCOUNT++;
                } else {
                    output.FNCOUNT++;
                }
            } else if (actual == 0.0) {
                if (predictedPositive) {
                    output.FPCOUNT++;
                } else {
                    output.TNCOUNT++;
                }
            }
        }
        return output;
    }

    /**
     * Returns the true-positive rate TP / (TP + FN).
     *
     * @return the rate, or 0.0 when there are no actual positives
     */
    private static double computeTruePositiveRate(ConfusionMatrix input) {
        double denom = input.TPCOUNT + input.FNCOUNT;
        if (denom == 0) {
            return 0.0;
        }
        return input.TPCOUNT / denom;
    }

    /**
     * Returns the false-positive rate FP / (FP + TN).
     *
     * @return the rate, or 0.0 when there are no actual negatives
     */
    private static double computeFalsePositiveRate(ConfusionMatrix input) {
        double denom = input.FPCOUNT + input.TNCOUNT;
        if (denom == 0) {
            return 0.0;
        }
        return input.FPCOUNT / denom;
    }
}



/**
 * Mutable tally of the four outcomes of a binary classification.
 *
 * <p>Every count starts at zero (the Java default for {@code int} fields) and
 * is incremented directly by the code that fills the matrix.
 */
class ConfusionMatrix {
    /** True positives: predicted positive, actually positive. */
    public int TPCOUNT;
    /** False positives: predicted positive, actually negative. */
    public int FPCOUNT;
    /** True negatives: predicted negative, actually negative. */
    public int TNCOUNT;
    /** False negatives: predicted negative, actually positive. */
    public int FNCOUNT;
}

class Prediction implements Comparable<Prediction> {
    public double actualValue;
    public double predictedValue;

    @Override
    public int compareTo(Prediction other){
        Double tmp = this.predictedValue;
        return tmp.compareTo(other.predictedValue);
    }

}
