function [opttheta] = sgd(funObj,theta,data,labels,options)
%SGD Minibatch ADADELTA optimisation with per-layer max-norm projection.
%
%   opttheta = sgd(funObj,theta,data,labels,options)
%
%   Inputs:
%     funObj  - function handle called as
%               [cost,~,grad,~,~] = funObj(theta,mb_data,mb_labels);
%               only the first (cost) and third (gradient) outputs are used.
%     theta   - initial parameter vector; it is unpacked with
%               cnnParamsToStack and must use that function's layout.
%     data    - training examples, indexed along the third dimension.
%     labels  - one label per example; length(labels) is taken as the
%               training set size.
%     options - struct with fields:
%         epochs, alpha, minibatch            (required)
%         momentum                            (optional, default 0.95;
%                                              used as the ADADELTA decay)
%         wc1max, wc2max, wd1max, wd2max,
%         wd3max                              (optional, default Inf;
%                                              maximum SQUARED L2 norm of
%                                              each layer's weights+bias)
%         imageDim, filterDim1, numFilters1, filterDim2, numFilters2,
%         poolDim1, poolDim2, numClasses      (required; forwarded to
%                                              cnnParamsToStack)
%
%   Output:
%     opttheta - parameter vector after the last processed minibatch.
%
%   Notes:
%     * Examples that do not fill a complete minibatch at the end of an
%       epoch are skipped for that epoch.
%     * options.alpha is required and halved every epoch, but the active
%       ADADELTA update does not use it; it only matters for the
%       commented-out momentum update kept below for reference.

%% Setup
assert(all(isfield(options,{'epochs','alpha','minibatch'})),...
        'Some options not defined');
if ~isfield(options,'momentum')
    options.momentum = 0.95;
end

% A missing max-norm bound means "no constraint" for that layer.
maxNormFields = {'wc1max','wc2max','wd1max','wd2max','wd3max'};
for k = 1:numel(maxNormFields)
    if ~isfield(options,maxNormFields{k})
        options.(maxNormFields{k}) = Inf;
    end
end

epochs = options.epochs;
alpha = options.alpha;
minibatch = options.minibatch;

% Loop-invariant architecture arguments for cnnParamsToStack.
stackArgs = {options.imageDim,options.filterDim1,options.numFilters1,...
             options.filterDim2,options.numFilters2,options.poolDim1,...
             options.poolDim2,options.numClasses};

m = length(labels); % training set size

% Setup for the running averages
mom = 0.5;
momIncrease = 1;
gradient_ac = zeros(size(theta)); % running average of squared gradients
update_ac = zeros(size(theta));   % running average of squared updates
epsilon = 1e-6;                   % keeps the step-size ratio finite

%%======================================================================
%% SGD loop
it = 0;
for e = 1:epochs

    % randomly permute indices of data for quick minibatch sampling
    rp = randperm(m);

    for s = 1:minibatch:(m-minibatch+1)
        it = it + 1;

        % switch to the configured decay once momIncrease iterations
        % have been reached
        if it == momIncrease
            mom = options.momentum;
        end

        % get next randomly selected minibatch
        idx = rp(s:s+minibatch-1);
        mb_data = data(:,:,idx);
        mb_labels = labels(idx);

        % evaluate the objective function on the next minibatch; five
        % outputs are still requested so nargout inside funObj is unchanged
        [cost,~,grad,~,~] = funObj(theta,mb_data,mb_labels);

        %%% MOMENTUM METHOD (disabled alternative) %%%
        % velocity = mom * velocity + alpha * grad;
        % theta = theta - velocity;

        %%% ADADELTA METHOD %%%
        gradient_ac = mom * gradient_ac + (1 - mom) * (grad .^ 2);
        velocity = sqrt(update_ac + epsilon) ./ sqrt(gradient_ac + epsilon) .* grad;
        update_ac = mom * update_ac + (1 - mom) * (velocity .^ 2);
        theta = theta - velocity;

        %%% Max-norm Regularization %%%
        % Unpack the UPDATED parameters so that each layer's norm is
        % measured on the values that are actually being rescaled.
        [Wc1, Wc2, Wd1, Wd2, Wd3, bc1, bc2, bd1, bd2, bd3] = ...
            cnnParamsToStack(theta,stackArgs{:});

        [Wc1, bc1] = projectMaxNorm(Wc1, bc1, options.wc1max);
        [Wc2, bc2] = projectMaxNorm(Wc2, bc2, options.wc2max);
        [Wd1, bd1] = projectMaxNorm(Wd1, bd1, options.wd1max);
        [Wd2, bd2] = projectMaxNorm(Wd2, bd2, options.wd2max);
        [Wd3, bd3] = projectMaxNorm(Wd3, bd3, options.wd3max);

        % Repack in the same order the original code used.
        theta = [Wc1(:); Wc2(:); Wd1(:); Wd2(:); Wd3(:); ...
                 bc1(:); bc2(:); bd1(:); bd2(:); bd3(:)];

        fprintf('Epoch %d: Cost on iteration %d is %f\n',e,it,cost);
    end

    % anneal learning rate by factor of two after each epoch
    % (no effect on the ADADELTA update above, which does not use alpha)
    alpha = alpha/2;

end

opttheta = theta;

end

function [W, b] = projectMaxNorm(W, b, maxSqNorm)
%PROJECTMAXNORM Rescale one layer so its joint L2 norm is at most sqrt(maxSqNorm).
%   The norm is taken over the weights W and bias b together. The layer is
%   returned unchanged when it is already within the bound.
    layerNorm = sqrt(sum(W(:) .^ 2) + sum(b(:) .^ 2));
    limit = sqrt(maxSqNorm);
    if layerNorm > limit
        scale = limit / layerNorm;
        W = W * scale;
        b = b * scale;
    end
end
