package lipfd.pca;

import lipfd.commons.*;

import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Scalar;
import org.opencv.highgui.*;
import org.opencv.imgproc.Imgproc;
import org.opencv.core.Rect;
import org.opencv.core.Point;
import org.opencv.core.Size;
import org.opencv.core.TermCriteria;
import org.opencv.ml.CvKNearest;

import javax.swing.JFileChooser;
import javax.swing.filechooser.FileNameExtensionFilter;
import javax.swing.JFrame;

import java.awt.event.WindowEvent;

import java.util.*;
import java.text.SimpleDateFormat;
import java.text.DateFormat;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.lang.Math;
import java.io.File;

/**
 * Crater / non-crater classifier built on PCA features.
 *
 * <p>{@link #run} fits a PCA basis on crater template images, projects crater and non-crater
 * templates onto it, clusters them with k-means, reassigns samples so that each cluster holds
 * only its dominant class, and finally labels each candidate crater either by its nearest
 * cluster centroid or by a k-nearest-neighbour vote over the clustered templates.
 */
public class PCA {
	/** Side length in pixels of the square templates and candidate patches. */
	private static final int PATCH_SIZE = 28;
	/** Side length of the square filters applied by {@link #convolveWithFilters}. */
	private static final int FILTER_DIM = 8;
	/** Factor applied to a candidate's half-extent before its patch is cropped. */
	private static final double CROP_RATIO = 2;
	/** Default upper bound on k accepted by the kNN model (raised if numNeighbors is larger). */
	private static final int DEFAULT_MAX_K = 1001;
	/** Class id of crater samples. */
	private static final int CRATER = 1;
	/** Class id of non-crater samples. */
	private static final int NON_CRATER = 0;
	/** Sentinel for {@link #nearestCluster}: accept clusters of any dominant class. */
	private static final int ANY_CLASS = -1;

	/**
	 * Classifies every candidate in {@code cratersConf} as crater ({@code conf = 1}) or
	 * non-crater ({@code conf = 0}).
	 *
	 * <p>Only {@value #PATCH_SIZE}x{@value #PATCH_SIZE} grayscale template images are used; any
	 * other file in the template folders is skipped. Pixels are converted to 32-bit float before
	 * their mean is subtracted so the zero-mean patches keep their negative values.
	 *
	 * @param cratersConf candidate craters. Each candidate's {@code enclosingRect},
	 *     {@code centerX} and {@code centerY} select the patch to classify, and its {@code conf}
	 *     field receives the result. A candidate whose patch is not
	 *     {@value #PATCH_SIZE}x{@value #PATCH_SIZE} after resizing is left unchanged.
	 * @param inputImage image the candidate patches are cropped from
	 * @param cratersFolder directory of crater template images
	 * @param nonCratersFolder directory of non-crater template images
	 * @param m number of usable crater templates used to fit the PCA basis (capped at the number
	 *     of usable crater templates)
	 * @param mPrime maximum number of principal components kept
	 * @param useConvolution if true, features are filter-bank responses
	 *     ({@link #convolveWithFilters}); otherwise the raw zero-mean pixels
	 * @param numClusters number of k-means clusters
	 * @param numNeighbors k for the kNN vote; only used when {@code useNearestMean} is false
	 * @param useNearestMean if true, classify by nearest centroid; otherwise by kNN vote
	 * @throws IllegalArgumentException if {@code m < 1} or a template folder cannot be listed
	 * @throws IllegalStateException if no usable crater template was found
	 */
	public static void run(List<Crater> cratersConf, Image inputImage,
		String cratersFolder, String nonCratersFolder, Integer m, Integer mPrime,
		Boolean useConvolution, Integer numClusters,
		Integer numNeighbors, Boolean useNearestMean){

		System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

		if (m < 1)
			throw new IllegalArgumentException("m must be at least 1, got " + m);
		List<File> craterFiles = listTemplateFiles(cratersFolder);
		List<File> noncraterFiles = listTemplateFiles(nonCratersFolder);

		System.out.println("Running the clustering Algorithm");

		// Crater rows are stacked first, so "row index < numCraters" identifies a crater sample.
		// numCraters counts rows actually loaded, not files, so skipped files cannot shift it.
		System.out.print("Loading crater images: ");
		Mat craterRows = loadTemplates(craterFiles, useConvolution, "Loading crater images");
		System.out.print("\nLoading noncrater images: ");
		Mat nonCraterRows = loadTemplates(noncraterFiles, useConvolution, "Loading noncrater images");
		System.out.println("");

		int numCraters = craterRows.rows();
		if (numCraters == 0)
			throw new IllegalStateException("No usable " + PATCH_SIZE + "x" + PATCH_SIZE
				+ " crater templates in " + cratersFolder);

		// Fit the PCA basis on the first m usable crater templates only.
		Mat mean = new Mat();
		Mat eig = new Mat();
		Core.PCACompute(craterRows.rowRange(0, Math.min(m, numCraters)), mean, eig, mPrime);

		Mat data = new Mat();
		data.push_back(craterRows);
		data.push_back(nonCraterRows); // no-op when there are no usable non-crater templates
		Mat projectedData = new Mat();
		Core.PCAProject(data, mean, eig, projectedData);

		// kmeans writes one cluster index per row into labels (initial values are not used).
		Mat labels = new Mat();
		Mat centroids = new Mat();
		Core.kmeans(projectedData, numClusters, labels, new TermCriteria(TermCriteria.COUNT + TermCriteria.EPS,
			1000, 0.000001), 1, Core.KMEANS_PP_CENTERS, centroids);

		int[] dominantClass = computeDominantClasses(labels, numCraters, numClusters);
		relabelMinorityMembers(projectedData, labels, centroids, dominantClass, numCraters, numClusters);

		// The kNN model is only needed for the vote-based classifier.
		CvKNearest knn = null;
		if (!useNearestMean)
			knn = new CvKNearest(projectedData, labels, new Mat(), false, Math.max(DEFAULT_MAX_K, numNeighbors));

		// Rows of newTestData correspond one-to-one with the entries of `projected`.
		Mat newTestData = new Mat();
		List<Crater> projected = new ArrayList<Crater>();
		System.out.print("Loading crater candidates:");
		for (int i = 0; i < cratersConf.size(); i++) {
			Crater candidate = cratersConf.get(i);
			Mat patch = cropCandidate(inputImage, candidate);
			if (patch.width() == PATCH_SIZE && patch.height() == PATCH_SIZE) {
				newTestData.push_back(toFeatureRow(patch, useConvolution));
				projected.add(candidate);
			}
			System.out.print(String.format("\rLoading crater candidates: %d/%d", i+1, cratersConf.size()));
		}
		System.out.println("");

		Mat projectedNewData = new Mat();
		if (!projected.isEmpty())
			Core.PCAProject(newTestData, mean, eig, projectedNewData);

		System.out.print("recognizing craters:");
		for (int i = 0; i < projectedNewData.rows(); i++) {
			Mat sample = projectedNewData.row(i);
			int result;
			if (useNearestMean)
				result = calcNearestMean(sample, centroids, dominantClass);
			else result = calcNearestNeighbor(sample, knn, dominantClass, numNeighbors);
			if (result == CRATER) { projected.get(i).conf = 1; }
			else { projected.get(i).conf = 0; }
			System.out.print(String.format("\rrecognizing craters: %d/%d", i+1, projectedNewData.rows()));
		}
		System.out.println("");
	}

	/** Lists the entries of a template directory, failing loudly if it cannot be read. */
	private static List<File> listTemplateFiles(String folder) {
		File[] entries = new File(folder).listFiles();
		if (entries == null)
			throw new IllegalArgumentException("Not a readable directory: " + folder);
		return new ArrayList<File>(Arrays.asList(entries));
	}

	/**
	 * Loads every PATCH_SIZE x PATCH_SIZE grayscale image in {@code files} as one feature row.
	 * Other images, and files imread cannot decode (it returns an empty 0x0 Mat), are skipped.
	 */
	private static Mat loadTemplates(List<File> files, boolean useConvolution, String progressLabel) {
		Mat rows = new Mat();
		for (int i = 0; i < files.size(); i++) {
			Mat image = Highgui.imread(files.get(i).getPath(), Highgui.CV_LOAD_IMAGE_GRAYSCALE);
			if (image.width() == PATCH_SIZE && image.height() == PATCH_SIZE)
				rows.push_back(toFeatureRow(image, useConvolution));
			System.out.print(String.format("\r%s: %d/%d", progressLabel, i + 1, files.size()));
		}
		return rows;
	}

	/**
	 * Turns an image into a single CV_32F feature row: zero-mean pixels, optionally passed
	 * through the filter bank. Works on a float copy, so {@code image} itself is not modified and
	 * subtracting the mean cannot saturate at 0 as it would on an 8-bit Mat.
	 */
	private static Mat toFeatureRow(Mat image, boolean useConvolution) {
		Mat zeroMean = new Mat();
		image.convertTo(zeroMean, CvType.CV_32F);
		Core.subtract(zeroMean, Core.mean(zeroMean), zeroMean);
		if (useConvolution)
			return convolveWithFilters(zeroMean);
		return zeroMean.reshape(1, 1);
	}

	/** Crops a square around the candidate's centre, scaled by CROP_RATIO, and resizes it. */
	private static Mat cropCandidate(Image inputImage, Crater c) {
		// NOTE(review): enclosingRect appears to be (x0, y0, x1, y1) — confirm against Crater.
		int proposedradius = (int)((double) Math.max(c.enclosingRect[2] - c.enclosingRect[0],
			c.enclosingRect[3] - c.enclosingRect[1])/2 * CROP_RATIO);
		int centerX = (int) c.centerX;
		int centerY = (int) c.centerY;
		return inputImage.crop(centerX - proposedradius, centerY - proposedradius,
			centerX + proposedradius, centerY + proposedradius).resize(PATCH_SIZE, PATCH_SIZE).getMat();
	}

	/** Returns, per cluster, CRATER if it holds strictly more crater than non-crater rows, else NON_CRATER. */
	private static int[] computeDominantClasses(Mat labels, int numCraters, int numClusters) {
		int[] craterVotes = new int[numClusters];
		int[] nonCraterVotes = new int[numClusters];
		for (int row = 0; row < labels.rows(); row++) {
			int cluster = (int) labels.get(row, 0)[0];
			if (calcClass(row, numCraters) == CRATER) craterVotes[cluster]++;
			else nonCraterVotes[cluster]++;
		}
		int[] dominantClass = new int[numClusters];
		for (int k = 0; k < numClusters; k++)
			dominantClass[k] = craterVotes[k] > nonCraterVotes[k] ? CRATER : NON_CRATER;
		return dominantClass;
	}

	/**
	 * Moves every sample whose class differs from its cluster's dominant class into the nearest
	 * cluster of its own class, recomputing centroids after each cluster is processed (so later
	 * clusters see the updated centroids). A sample stays where it is if no cluster of its class
	 * exists.
	 */
	private static void relabelMinorityMembers(Mat projectedData, Mat labels, Mat centroids,
			int[] dominantClass, int numCraters, int numClusters) {
		for (int k = 0; k < numClusters; k++) {
			for (int row = 0; row < labels.rows(); row++) {
				if ((int) labels.get(row, 0)[0] != k)
					continue;
				int sampleClass = calcClass(row, numCraters);
				if (dominantClass[k] == sampleClass)
					continue;
				int target = nearestCluster(projectedData.row(row), centroids, dominantClass, sampleClass);
				if (target >= 0)
					labels.put(row, 0, target);
			}
			calcCentroids(projectedData, labels, numClusters, centroids);
		}
	}

	/**
	 * Index of the centroid closest (L2) to {@code sample}, restricted to clusters whose dominant
	 * class is {@code requiredClass} unless it is ANY_CLASS; -1 if no cluster qualifies.
	 */
	private static int nearestCluster(Mat sample, Mat centroids, int[] dominantClass, int requiredClass) {
		double minDistance = Double.MAX_VALUE;
		int nearest = -1;
		for (int c = 0; c < centroids.rows(); c++) {
			if (requiredClass != ANY_CLASS && dominantClass[c] != requiredClass)
				continue;
			double distance = Core.norm(sample, centroids.row(c), Core.NORM_L2);
			if (distance < minDistance) {
				minDistance = distance;
				nearest = c;
			}
		}
		return nearest;
	}

	/** Classifies a projected candidate by the dominant class of its nearest centroid. */
	private static int calcNearestMean(Mat projectedCandidate, Mat centroids, int[] dominantClass){
		int nearest = nearestCluster(projectedCandidate, centroids, dominantClass, ANY_CLASS);
		return nearest >= 0 ? dominantClass[nearest] : NON_CRATER;
	}

	/** Classifies a projected candidate by the dominant class of the cluster voted by its kNN. */
	private static int calcNearestNeighbor(Mat projectedCandidate, CvKNearest knn, int[] dominantClass, int numNeighbors){
		Mat results = new Mat();
		knn.find_nearest(projectedCandidate, numNeighbors, results, new Mat(), new Mat());
		int cluster = (int) results.get(0, 0)[0];
		return dominantClass[cluster];
	}

	/**
	 * Recomputes each cluster centroid as the mean of its member rows, in a single pass over the
	 * data. A cluster with no members keeps its previous centroid instead of becoming NaN.
	 */
	private static void calcCentroids(Mat data, Mat labels, int numClusters, Mat centroids){
		Mat[] sums = new Mat[numClusters];
		int[] counts = new int[numClusters];
		for (int k = 0; k < numClusters; k++)
			sums[k] = Mat.zeros(1, data.cols(), CvType.CV_32F);
		for (int i = 0; i < data.rows(); i++) {
			int cluster = (int) labels.get(i, 0)[0];
			Core.add(data.row(i), sums[cluster], sums[cluster]);
			counts[cluster]++;
		}
		for (int k = 0; k < numClusters; k++) {
			// centroids.row(k) shares centroids' storage, so this writes the centroid in place.
			if (counts[k] > 0)
				sums[k].convertTo(centroids.row(k), CvType.CV_32F, 1.0 / counts[k]);
		}
	}

	/** Training rows are stacked craters first, so the row index alone determines the class. */
	private static int calcClass(int i, int numCraters){
		return i < numCraters ? CRATER : NON_CRATER;
	}

	/**
	 * Builds a single CV_32F feature row from the filter-bank responses of {@code image}.
	 *
	 * <p>Each filter response is computed in 32-bit float (an 8-bit destination would saturate
	 * the sums) and cropped by FILTER_DIM/2 on every side. A copy of the image downscaled to the
	 * same size is appended. Every part is min-max normalised to [0, 1] and all parts are
	 * concatenated. {@code image} itself is not modified.
	 */
	private static Mat convolveWithFilters(Mat image){
		Mat floatImage = new Mat();
		image.convertTo(floatImage, CvType.CV_32F);
		int imageDim = floatImage.cols();
		int margin = FILTER_DIM / 2;
		int outDim = imageDim - FILTER_DIM;
		Point anchor = new Point(margin, margin);

		List<Mat> features = new ArrayList<Mat>();
		for (Mat filter : buildFilters(FILTER_DIM)) {
			Mat response = new Mat();
			Imgproc.filter2D(floatImage, response, CvType.CV_32F, filter, anchor, 0, Imgproc.BORDER_CONSTANT);
			Mat cropped = response.rowRange(margin, imageDim - margin).colRange(margin, imageDim - margin);
			features.add(cropped.clone().reshape(1, 1));
		}

		Mat resized = new Mat();
		Imgproc.resize(floatImage, resized, new Size(outDim, outDim));
		features.add(resized.reshape(1, 1));

		for (Mat feature : features)
			Core.normalize(feature, feature, 0, 1, Core.NORM_MINMAX, CvType.CV_32F);

		Mat output = new Mat();
		Core.hconcat(features, output);
		return output;
	}

	/** Builds the nine +/-1 filters: seven quadrant patterns, then a row stripe and a column stripe. */
	private static List<Mat> buildFilters(int dim) {
		List<Mat> filters = new ArrayList<Mat>();
		//                             top-left top-right bottom-left bottom-right
		filters.add(quadrantFilter(dim,  1,  1, -1, -1));
		filters.add(quadrantFilter(dim, -1, -1,  1,  1));
		filters.add(quadrantFilter(dim, -1, -1, -1,  1));
		filters.add(quadrantFilter(dim, -1,  1, -1, -1));
		filters.add(quadrantFilter(dim,  1, -1, -1, -1));
		filters.add(quadrantFilter(dim, -1, -1,  1, -1));
		filters.add(quadrantFilter(dim,  1, -1, -1,  1));
		filters.add(stripeFilter(dim, true));
		filters.add(stripeFilter(dim, false));
		return filters;
	}

	/** A dim x dim CV_32F filter whose four quadrants (split at dim/2) hold the given constants. */
	private static Mat quadrantFilter(int dim, float topLeft, float topRight, float bottomLeft, float bottomRight) {
		int half = dim / 2;
		float[] values = new float[dim * dim];
		for (int y = 0; y < dim; y++) {
			for (int x = 0; x < dim; x++) {
				boolean top = y < half;
				boolean left = x < half;
				values[y * dim + x] = top ? (left ? topLeft : topRight) : (left ? bottomLeft : bottomRight);
			}
		}
		Mat filter = new Mat(dim, dim, CvType.CV_32F);
		filter.put(0, 0, values);
		return filter;
	}

	/**
	 * A dim x dim CV_32F filter of +1 with a -1 band over indices [dim/3, 2*dim/3): the band
	 * covers rows when {@code rowBand} is true, columns otherwise.
	 */
	private static Mat stripeFilter(int dim, boolean rowBand) {
		int bandStart = dim / 3;
		int bandEnd = 2 * dim / 3;
		float[] values = new float[dim * dim];
		for (int y = 0; y < dim; y++) {
			for (int x = 0; x < dim; x++) {
				int coordinate = rowBand ? y : x;
				values[y * dim + x] = (coordinate >= bandStart && coordinate < bandEnd) ? -1f : 1f;
			}
		}
		Mat filter = new Mat(dim, dim, CvType.CV_32F);
		filter.put(0, 0, values);
		return filter;
	}
}
