% Interpreter / console settings: long numeric display, no output paging,
% and no warning when Octave broadcasts array dimensions automatically.
format('long');
more('off');
warning('off', 'Octave:broadcast');

% Network hyperparameters, grouped per layer.
% Input patches are imageDim x imageDim.
imageDim = 28;
% NOTE(review): the labels built below take the values 1 and 2 while
% numClasses is 1; presumably cnnCost treats this as a single-output binary
% classifier — confirm in cnnCostCPU/cnnCostGPU.
numClasses = 1;
% Convolution/pooling layer 1.
filterDim1 = 9;
numFilters1 = 20;
poolDim1 = 2;
% Convolution/pooling layer 2.
filterDim2 = 5;
numFilters2 = 10;
poolDim2 = 2;

addpath ../commons/matlab/;
% Load the two classes of image patches. Each loader result is a stack of
% images along dimension 3 (counted with size(..., 3) below).
craters = loadCraterImages('../templateMatching/result/dataset/craters/');
fprintf('loaded %d crater images\n', size(craters, 3));
noncraters = loadCraterImages('../templateMatching/result/dataset/non-craters/');
fprintf('loaded %d non-crater images\n', size(noncraters, 3));

% Class labels: 2 = crater, 1 = non-crater. The evaluation codes at the end
% of the script (22 = TP, 21 = FP, 12 = FN) rely on this encoding.
numCraters = size(craters, 3);
labels = ones(numCraters + size(noncraters, 3), 1);
labels(1:numCraters, 1) = 2;

% Shuffle images and labels with the same permutation so pairs stay aligned.
randomperm = randperm(size(labels, 1));
images = cat(3, craters, noncraters);
images = images(:, :, randomperm);
labels = labels(randomperm, :);

% Fraction of the shuffled data used for training. With the default of 1
% every image is used for training and the "test" set is the training set
% itself, so the metrics reported at the end are training-set metrics.
% Set it below 1 to hold out the remaining images for evaluation.
trainFraction = 1;
lastTrainingIndex = floor(trainFraction * size(images, 3));
trainingImages = images(:, :, 1:lastTrainingIndex);
trainingLabels = labels(1:lastTrainingIndex, 1);
if lastTrainingIndex < size(images, 3)
	testImages = images(:, :, (lastTrainingIndex + 1):end);
	testLabels = labels((lastTrainingIndex + 1):end, 1);
else
	testImages = trainingImages;
	% Terminated with ';' so the label vector is not echoed to the console.
	testLabels = trainingLabels;
end

% Alternative, currently disabled data path: load the older .mat datasets
% (images/labels plus rocks and oldCraters) instead of the image folders
% loaded above. It was originally guarded by
%   ~exist('images','var') || ~exist('labels','var') || ...
%   ~exist('testImages','var') || ~exist('rocks','var')
% so that a warm workspace would not be reloaded. Set useLegacyDataset to
% true to use it; note that it starts with `clear`, wiping the workspace.
useLegacyDataset = false;
if useLegacyDataset
	clear
	imageDim = 28;
	numClasses = 1;
	filterDim1 = 9;
	numFilters1 = 20;
	filterDim2 = 5;
	numFilters2 = 10;
	poolDim1 = 2;
	poolDim2 = 2;
	addpath ../commons/matlab/;
	% NOTE(review): these paths use '../template-matching/' while the live
	% loader above uses '../templateMatching/' — confirm which one exists.
	% Presumably these files define images/labels, rocks and oldCraters
	% (all used below) — confirm.
	load('../template-matching/result/dataset/craterData.mat');
	load('../template-matching/result/dataset/rockImagesRefined.mat');
	load('../template-matching/result/dataset/oldCraterImages.mat');


	% Make images have zero mean (disabled)
	%images = bsxfun(@minus, images, sum(sum(images))/imageDim/imageDim);
	%rocks = bsxfun(@minus, rocks, sum(sum(rocks))/imageDim/imageDim);
	%oldCraters = bsxfun(@minus, oldCraters, sum(sum(oldCraters))/imageDim/imageDim);

	% Append rocks as class 1 (non-crater) and oldCraters as class 2 (crater),
	% then shuffle images and labels with the same permutation.
	images = cat(3, images, rocks, oldCraters);
	labels = cat(1, labels, ones(size(rocks, 3), 1));
	labels = cat(1, labels, 2 .* ones(size(oldCraters, 3), 1));
	randomperm = randperm(size(labels,1));
	images = images(:, :, randomperm);
	labels = labels(randomperm, :);

	fprintf('loaded %d images with labels\n', size(images, 3));

	% All data is used for training; the test set is the training set.
	lastTrainingIndex = floor(1 / 1 * size(images, 3));
	trainingImages = images(:, :, 1:lastTrainingIndex);
	testImages = images(:, :, 1:lastTrainingIndex);
	trainingLabels = labels(1:lastTrainingIndex, 1);
	testLabels = labels(1:lastTrainingIndex, 1);
end

% Choose the cost/gradient implementation (cnnCostGPU or cnnCostCPU, both
% defined elsewhere in the project). The CPU version is the default.
useGPU = false;
if useGPU
	cnnCost = @cnnCostGPU;
else
	cnnCost = @cnnCostCPU;
end

fprintf('training set size: %d\n', size(trainingImages, 3));
fprintf('test set size: %d\n', size(testImages, 3));

% Per-class weights passed to cnnCost: uniform, one entry per class
% (numClasses x 1, matching the costWeights(k,1) indexing below). The previous
% ones(size(numClasses, 1)) took the size of a scalar and was therefore
% always a single 1, regardless of numClasses.
costWeights = ones(numClasses, 1);
% Inverse-frequency weighting alternative (disabled):
% for k=1:numClasses
%     costWeights(k,1) = size(labels, 1)/sum(labels == k, 1);
% end
% costWeights = costWeights / sum(costWeights) * numClasses;


% Build the initial parameter vector for the two-layer network
% (initialize() is defined elsewhere in the project).
theta = initialize(imageDim, filterDim1, numFilters1, ...
                   filterDim2, numFilters2, ...
                   poolDim1, poolDim2, numClasses);
% To resume from a previously trained model instead, uncomment:
%load('../templateMatching/result/dataset/craterDetectorCESaman.mat','opttheta');
%theta = opttheta;


% Derived layer sizes. These are informational only: none of them is
% referenced elsewhere in this script.
% Layer 1: output side after a 'valid' convolution, then after pooling.
convDim1 = imageDim - filterDim1 + 1;
inputDim1 = convDim1 / poolDim1;
inputSize1 = inputDim1^2 * numFilters1;
% Layer 2: same computation on the pooled layer-1 maps. The element count
% includes numFilters1 as well as numFilters2; presumably each layer-2 filter
% is applied to every layer-1 map separately — confirm against cnnCost.
convDim2 = inputDim1 - filterDim2 + 1;
inputDim2 = convDim2 / poolDim2;
inputSize2 = inputDim2^2 * numFilters2 * numFilters1;

% Regularisation strength passed to cnnCost during training.
lambda = 0.001;

% Optimiser settings consumed by sgd().
options.epochs    = 50;
options.minibatch = 256;
options.alpha     = 1e-1;   % learning rate
options.momentum  = 0.90;
% Per-layer limits read by sgd(); presumably max-norm caps on the conv (wc*)
% and dense (wd*) weights — confirm in sgd.m.
options.wc1max    = 16;
options.wc2max    = 4;
options.wd1max    = 9900;
options.wd2max    = 9900;
options.wd3max    = 24;

% Architecture description, mirrored into options for sgd().
options.imageDim    = imageDim;
options.filterDim1  = filterDim1;
options.filterDim2  = filterDim2;
options.numFilters1 = numFilters1;
options.numFilters2 = numFilters2;
options.poolDim1    = poolDim1;
options.poolDim2    = poolDim2;
options.numClasses  = numClasses;



% Train with sgd() (defined elsewhere). The anonymous function fixes every
% cnnCost argument except the first three, which sgd supplies (presumably
% parameters, a minibatch of images and its labels — confirm in sgd.m). The
% trailing false presumably selects training rather than prediction mode
% (the evaluation loop below passes true).
opttheta = sgd( ...
	@(x, y, z) cnnCost(x, y, z, numClasses, ...
		filterDim1, numFilters1, filterDim2, numFilters2, ...
		poolDim1, poolDim2, false, lambda, costWeights), ...
	theta, trainingImages, trainingLabels, options);

% Ask whether to persist the trained parameters. Pressing Enter accepts the
% advertised default [Y]; only an answer starting with 'n'/'N' skips saving.
% The isempty() check comes first: comparing '' with a char yields an empty
% array, which is not a valid && operand, so the default used to error out.
m = strtrim(input('Do you want to save theta, Y/N [Y]:', 's'));
if isempty(m) || lower(m(1)) ~= 'n'
	% NOTE(review): save() fails if ./data-files/ does not exist.
	save('./data-files/craterDetectorCE.mat', 'opttheta');
	fprintf('saved the optimal weights\n');
end


% Predict the test set in batches of stepsize images (presumably to limit
% memory per cnnCost call). cnnCost is called with its trailing flag set to
% true and without lambda/costWeights; presumably that selects prediction
% mode — confirm in cnnCostCPU/cnnCostGPU. The output list is kept as-is in
% case cnnCost branches on nargout.
numTestImages = size(testImages, 3);
preds = zeros(numTestImages, 1);
stepsize = 500;
for batchStart = 1:stepsize:numTestImages
	% The final batch may be shorter than stepsize.
	batchEnd = min(batchStart + stepsize - 1, numTestImages);
	[cost, time, ~, predstemp, probstemp] = cnnCost(opttheta, ...
		testImages(:, :, batchStart:batchEnd), ...
		testLabels(batchStart:batchEnd), numClasses, ...
		filterDim1, numFilters1, filterDim2, numFilters2, ...
		poolDim1, poolDim2, true);
	% Index by the batch range rather than size(predstemp, 1), so a row- or
	% column-vector result lands in the same slots and a wrong-sized result
	% raises an error instead of being stored silently.
	preds(batchStart:batchEnd, 1) = predstemp(:);
	% Report progress after the batch completes, so the last line is 100%.
	fprintf('Progress: %2.2f%%\n', batchEnd / numTestImages * 100);
end

% Score the predictions against testLabels (2 = crater, 1 = non-crater, as
% built above). NOTE(review): assumes cnnCost's predictions use the same 1/2
% encoding — confirm.
% Encode each pair as 10*prediction + truth: 22 = true positive,
% 21 = false positive, 12 = false negative (11 = true negative).
results = (preds .* 10) + testLabels(1:numTestImages);
tp = sum(results == 22);
fp = sum(results == 21);
fn = sum(results == 12);
fprintf('tp = %d, fp = %d, fn = %d\n', tp, fp, fn);

% Each ratio is defined as 0 when its denominator is 0 (e.g. no positive
% predictions), instead of producing NaN from 0/0.
precision = 0;
if tp + fp > 0
	precision = tp / (tp + fp);
end
recall = 0;
if tp + fn > 0
	recall = tp / (tp + fn);
end
f1score = 0;
if precision + recall > 0
	f1score = 2 * precision * recall / (precision + recall);
end

fprintf('precision is %0.2f%%\n', precision * 100);
fprintf('recall is %0.2f%%\n', recall * 100);
fprintf('f1score is %0.2f%%\n', f1score * 100);
