% ---- Session settings -------------------------------------------------
format('long');                          % print numbers with full precision
more('off');                             % disable output paging
warning('off', 'Octave:broadcast');      % silence automatic-broadcast warnings

% ---- Network hyperparameters -------------------------------------------
imageDim   = 28;    % side length used for the input images
numClasses = 1;     % class count handed to the cost function

% first convolution / pooling stage
filterDim1  = 9;
numFilters1 = 20;
poolDim1    = 2;

% second convolution / pooling stage
filterDim2  = 5;
numFilters2 = 10;
poolDim2    = 2;

% ---- Load the two image sets --------------------------------------------
% Both sets are indexed along the third dimension (one image per slice).
addpath('../commons/matlab/');

craters = loadCraterImages('../dataset/craters/');
numCraters = size(craters, 3);
fprintf('loaded %d crater images\n', numCraters);

noncraters = loadCraterImages('../dataset/non-craters/');
numNonCraters = size(noncraters, 3);
fprintf('loaded %d non-crater images\n', numNonCraters);

% ---- Build labels: 2 = crater, 1 = non-crater ---------------------------
% Label order matches the concatenation order below (craters first).
labels = [2 * ones(numCraters, 1); ones(numNonCraters, 1)];
images = cat(3, craters, noncraters);

% ---- Shuffle images and labels with one shared permutation --------------
randomperm = randperm(size(labels, 1));
images = images(:, :, randomperm);
labels = labels(randomperm, :);

% ---- Split ---------------------------------------------------------------
% The training fraction is 1/1, i.e. every image goes into the training
% set, and the "test" set is the very same data (no held-out images).
trainingFraction = 1 / 1;
lastTrainingIndex = floor(trainingFraction * size(images, 3));

trainingImages = images(:, :, 1:lastTrainingIndex);
trainingLabels = labels(1:lastTrainingIndex, 1);
testImages = trainingImages;
testLabels = trainingLabels;

% Select the cost-function implementation. The flag is hard-wired to the
% CPU variant; set it to true to use the GPU variant instead.
useGPU = false;
if useGPU
	cnnCost = @cnnCostGPU;
else
	cnnCost = @cnnCostCPU;
end

% Report how many images ended up in each set.
numTrainingImages = size(trainingImages, 3);
fprintf('training set size: %d\n', numTrainingImages);
numTestSetImages = size(testImages, 3);
fprintf('test set size: %d\n', numTestSetImages);

% Network parameters come from a previously saved file rather than from a
% fresh initialisation. The alternative, kept for reference, would be:
%   theta = initialize(imageDim, filterDim1, numFilters1, ...
%                      filterDim2, numFilters2, poolDim1, poolDim2, numClasses);
modelFile = '../dataset/craterDetectorCE.mat';
load(modelFile, 'opttheta');
theta = opttheta;

% Spatial sizes after each stage. With the hyperparameters set at the top
% of the script (28 / 9 / 2 / 5 / 2) these evaluate to 20, 10, 6 and 3.
convDim1  = imageDim - filterDim1 + 1;      % side after the first convolution
inputDim1 = convDim1 / poolDim1;            % side after the first pooling
convDim2  = inputDim1 - filterDim2 + 1;     % side after the second convolution
inputDim2 = convDim2 / poolDim2;            % side after the second pooling

% Number of values produced by each stage.
inputSize1 = inputDim1 ^ 2 * numFilters1;
inputSize2 = inputDim2 ^ 2 * numFilters2 * numFilters1;

% Regularisation weight (not referenced again in this part of the script).
lambda = 0.001;

% ---- Batched prediction over the test set --------------------------------
% The test images are pushed through the network in chunks of `stepsize`
% images; the predicted class of every image is collected in `preds`
% (one row per test image, same order as testImages / testLabels).
numTestImages = size(testImages, 3);
preds = zeros(numTestImages, 1);
stepsize = 500;
for i = 1:stepsize:numTestImages
	% Clamp the batch end so the final, possibly shorter, batch is handled
	% by the same call as the full-size ones.
	batchEnd = min(i + stepsize - 1, numTestImages);
	batchIdx = i:batchEnd;

	% Uses `theta`, the parameter vector the script sets up for the network
	% (a copy of the loaded opttheta). The fourth output holds the
	% predictions; the third output is ignored.
	% NOTE(review): the trailing `true, 0, 0` arguments are passed through
	% unchanged; presumably a prediction-only flag plus two unused
	% training settings -- confirm against the cnnCost signature.
	[cost, time, ~, predstemp, probstemp] = cnnCost(theta, testImages(:, :, batchIdx), ...
								testLabels(batchIdx), numClasses, ...
								filterDim1, numFilters1, filterDim2, numFilters2, ...
								poolDim1, poolDim2, true, 0, 0);

	% Accept row or column prediction vectors, but fail loudly if the
	% cost function did not return exactly one prediction per image.
	if numel(predstemp) ~= numel(batchIdx)
		error('cnnCost returned %d predictions for a batch of %d images', ...
			numel(predstemp), numel(batchIdx));
	end
	preds(batchIdx, 1) = predstemp(:);

	% Report progress through the end of the batch just processed, so the
	% last line reads 100%.
	fprintf('Progress: %2.2f%%\n', batchEnd/numTestImages*100);
end

% ---- Precision / recall / F1 for the crater class -------------------------
% Each (prediction, label) pair is encoded as a two-digit number:
% tens digit = predicted class, ones digit = true label. With the labels
% built earlier in the script (2 = crater, 1 = non-crater) this gives
%   22 -> true positive, 21 -> false positive, 12 -> false negative.
results = (preds .* 10) + testLabels(1:numTestImages);

% No trailing semicolons: the three counts are echoed to the console.
tp = sum(results == 22)
fp = sum(results == 21)
fn = sum(results == 12)

% Guard every ratio against a zero denominator (no predicted positives,
% no actual positives, or no true positives) so the report shows 0%
% instead of NaN.
if tp + fp > 0
	precision = tp / (tp + fp);
else
	precision = 0;
end

if tp + fn > 0
	recall = tp / (tp + fn);
else
	recall = 0;
end

if precision + recall > 0
	f1score = 2 * precision * recall / (precision + recall);
else
	f1score = 0;
end

fprintf('precision is %0.2f%%\n', precision * 100);
fprintf('recall is %0.2f%%\n', recall * 100);
fprintf('f1score is %0.2f%%\n', f1score * 100);
