package lipfd.circleHough;

import lipfd.commons.Util;

import org.opencv.core.Mat;
import org.opencv.imgproc.Imgproc;

/**
 * Single-scale Retinex (SSR) image enhancement.
 *
 * <p>Every output channel value is {@code GAIN * (log(I) - log(S)) + OFFSET}, clamped to
 * {@code [0, 255]}. Here {@code I} is the pixel's original value and {@code S} is the
 * Gaussian-weighted average of its {@code kernelSize x kernelSize} neighbourhood. Neighbours
 * that fall outside the image are replaced by the nearest edge pixel (replicate border).
 */
public class Retinex {

	/** Gaussian standard deviation used when the caller does not supply one. */
	private static final double DEFAULT_SIGMA = 10.0;

	/** Multiplier applied to the log ratio between a pixel and its surround. */
	private static final double GAIN = 50.0;

	/** Added after scaling so that a pixel equal to its surround maps to mid-grey. */
	private static final double OFFSET = 127.5;

	/** Lower bound of the output value range. */
	private static final double OUTPUT_MIN = 0.0;

	/** Upper bound of the output value range. */
	private static final double OUTPUT_MAX = 255.0;

	/** Number of colour channels the algorithm processes. */
	private static final int CHANNELS = 3;

	/** Creates a Retinex filter; the class holds no per-instance state. */
	public Retinex()
	{
		
	}
	
	/**
	 * Applies single-scale Retinex using a Gaussian surround with sigma = 10.
	 *
	 * <p>The image is modified in place and also returned. An image that does not have
	 * 3 channels is first converted to 3 channels. Every pixel is then overwritten with its
	 * Retinex value.
	 *
	 * @param mat        image to enhance; modified in place
	 * @param kernelSize side length of the square Gaussian kernel (odd values give a
	 *                   kernel centred on the pixel)
	 * @return {@code mat}, after enhancement
	 * @throws IllegalArgumentException if {@code kernelSize < 1}
	 */
	public Mat SingleScaleRetinex(Mat mat, int kernelSize)
	{
		return SingleScaleRetinex(mat, kernelSize, DEFAULT_SIGMA);
	}

	/**
	 * Applies single-scale Retinex using a Gaussian surround with the given sigma.
	 *
	 * <p>All neighbourhood reads come from a snapshot taken before any pixel is written.
	 * As a result, the outcome does not depend on processing order.
	 *
	 * @param mat        image to enhance; modified in place
	 * @param kernelSize side length of the square Gaussian kernel
	 * @param sigma      standard deviation of the Gaussian surround; must be positive
	 * @return {@code mat}, after enhancement
	 * @throws IllegalArgumentException if {@code kernelSize < 1} or {@code sigma} is not positive
	 */
	public Mat SingleScaleRetinex(Mat mat, int kernelSize, double sigma)
	{
		if (kernelSize < 1)
		{
			throw new IllegalArgumentException("kernelSize must be positive: " + kernelSize);
		}
		// Build (and validate) the kernel before touching the caller's image.
		double[][] kernel = gaussian2DKernel(kernelSize, sigma);

		if (mat.empty())
		{
			return mat;
		}

		toThreeChannels(mat);

		int rows = mat.rows();
		int cols = mat.cols();
		double[] source = snapshot(mat, rows, cols);

		// Reused for every pixel: fillBlock overwrites every entry each time.
		double[][][] block = new double[kernelSize][kernelSize][CHANNELS];
		double[] result = new double[CHANNELS];

		for (int row = 0; row < rows; row++)
		{
			for (int col = 0; col < cols; col++)
			{
				fillBlock(source, rows, cols, row, col, block, kernelSize);
				double[] surround = convolve(block, kernel, kernelSize);

				int base = (row * cols + col) * CHANNELS;
				for (int ch = 0; ch < CHANNELS; ch++)
				{
					double logRatio = safeLog(source[base + ch]) - safeLog(surround[ch]);
					result[ch] = Math.min(OUTPUT_MAX, Math.max(OUTPUT_MIN, GAIN * logRatio + OFFSET));
				}

				// put() copies the values, so reusing 'result' on the next pixel is safe.
				mat.put(row, col, result);
			}
		}
		
		return mat;
	}

	/**
	 * Converts {@code mat} in place to 3 channels.
	 *
	 * <p>A 4-channel image has its alpha channel dropped. Any other non-3-channel image is
	 * treated as greyscale and expanded to 3 channels.
	 */
	private static void toThreeChannels(Mat mat)
	{
		int channels = mat.channels();
		if (channels == 4)
		{
			Imgproc.cvtColor(mat, mat, Imgproc.COLOR_RGBA2RGB);
		}
		else if (channels != CHANNELS)
		{
			Imgproc.cvtColor(mat, mat, Imgproc.COLOR_GRAY2RGB);
		}
	}

	/** Copies all pixels of a 3-channel {@code mat} into a row-major array of length rows*cols*3. */
	private static double[] snapshot(Mat mat, int rows, int cols)
	{
		double[] pixels = new double[rows * cols * CHANNELS];
		for (int row = 0; row < rows; row++)
		{
			for (int col = 0; col < cols; col++)
			{
				System.arraycopy(mat.get(row, col), 0, pixels, (row * cols + col) * CHANNELS, CHANNELS);
			}
		}
		return pixels;
	}

	/** Fills {@code block} with the neighbourhood of (row, col), clamping out-of-range indices to the edge. */
	private static void fillBlock(double[] pixels, int rows, int cols, int row, int col,
			double[][][] block, int kernelSize)
	{
		int center = kernelSize / 2;
		for (int i = 0; i < kernelSize; i++)
		{
			int r = clampIndex(row - center + i, rows);
			for (int j = 0; j < kernelSize; j++)
			{
				int c = clampIndex(col - center + j, cols);
				System.arraycopy(pixels, (r * cols + c) * CHANNELS, block[i][j], 0, CHANNELS);
			}
		}
	}

	/** Clamps {@code index} into {@code [0, length - 1]}. */
	private static int clampIndex(int index, int length)
	{
		return Math.max(0, Math.min(length - 1, index));
	}

	/**
	 * Natural log with zero and negative arguments raised to the smallest positive double.
	 *
	 * <p>Positive inputs are unaffected. A black pixel in a black surround then yields a log
	 * ratio of 0, and so mid-grey, instead of {@code log(0) - log(0) = NaN}.
	 */
	private static double safeLog(double value)
	{
		return Math.log(Math.max(value, Double.MIN_VALUE));
	}
	
	/**
	 * Unnormalised 1-D Gaussian: {@code exp(-(x - mu)^2 / (2 * sigma^2))}.
	 *
	 * @param x     evaluation point
	 * @param mu    mean
	 * @param sigma standard deviation
	 * @return the Gaussian value at {@code x}; 1.0 when {@code x == mu}
	 */
	public double gaussian(double x, double mu, double sigma)
	{
		return Math.exp( -(Math.pow(x-mu, 2.0)/(2*Math.pow(sigma, 2.0))) );
	}
	
	/**
	 * Builds a normalised square Gaussian kernel with sigma = 10.
	 *
	 * <p>With sigma this large, small kernels (for example 3x3) are close to a uniform box filter.
	 *
	 * @param kernelSize side length of the kernel
	 * @return a {@code kernelSize x kernelSize} kernel whose entries sum to 1
	 */
	public double[][] gaussian2DKernel(int kernelSize)
	{
		return gaussian2DKernel(kernelSize, DEFAULT_SIGMA);
	}

	/**
	 * Builds a normalised square Gaussian kernel centred at index {@code kernelSize / 2}.
	 *
	 * @param kernelSize side length of the kernel
	 * @param sigma      standard deviation; must be positive
	 * @return a {@code kernelSize x kernelSize} kernel whose entries sum to 1
	 * @throws IllegalArgumentException if {@code sigma} is not positive (or is NaN)
	 */
	public double[][] gaussian2DKernel(int kernelSize, double sigma)
	{
		if (!(sigma > 0.0))
		{
			throw new IllegalArgumentException("sigma must be positive: " + sigma);
		}
		int center=kernelSize/2;
		double[][] kernel = new double[kernelSize][kernelSize];
		double sum=0.0;
		
		for (int row=0; row<kernelSize; row++)
		{
			for (int col=0; col<kernelSize; col++)
			{
				// Separable 2-D Gaussian: product of the row and column 1-D profiles.
				kernel[row][col] = gaussian(row, center, sigma) * gaussian(col, center, sigma);
				sum+=kernel[row][col];
			}
		}
		
		for (int row=0; row<kernelSize; row++)
		{
			for (int col=0; col<kernelSize; col++)
			{
				kernel[row][col] /= sum;
			}
		}
		
		return kernel;
	}
	
	/**
	 * Weighted sum of a neighbourhood block, computed separately for the first three channels.
	 *
	 * @param block      {@code [kernelSize][kernelSize][>=3]} pixel values
	 * @param kernel     {@code [kernelSize][kernelSize]} weights
	 * @param kernelSize side length of both arrays
	 * @return a new 3-element array: {@code sum(block[i][j][c] * kernel[i][j])} for c = 0..2
	 */
	public double[] convolve(double[][][] block, double[][] kernel, int kernelSize)
	{
		double[] result={0.0,0.0,0.0};
		
		for (int i=0; i<kernelSize; i++)
		{
			for (int j=0; j<kernelSize; j++)
			{
				result[0]+=block[i][j][0]*kernel[i][j];
				result[1]+=block[i][j][1]*kernel[i][j];
				result[2]+=block[i][j][2]*kernel[i][j];
			}
		}
		
		return result;
	}

}