%% This file uses MATLAB cell mode. To run a cell, press Ctrl+Enter.


%% Cross-validation: divide the genes into 10 random groups and average the results
% time_point and gene_data are not defined in this script; presumably they are
% provided by gene_data.mat -- confirm. Genes are the columns of gene_data
% (Ngene is read from its second dimension).
load('gene_data');
K = 10;                     % number of gene groups (folds)
x = time_point;
y = gene_data;
Ngene = size(y, 2);
% Assign every gene a random group label in 1..K.
random_cluster00 = ceil(K * rand(1, Ngene));
sigma_t = 1;                % forwarded unchanged to both fitting routines
Nstarts = 20;               % forwarded unchanged; presumably the number of restarts -- confirm

% Fit both models once per fold, holding out the genes labelled hstp.
for hstp = 1:K
    % 2 marks the held-out (test) genes of this fold, 1 the training genes.
    random_cluster = 1 + (random_cluster00 == hstp);
    in_fold = (random_cluster == 2);
    y_test = y(:, in_fold);
    y_training = y(:, ~in_fold);

    % One column of fitted parameters per fold; the call order is kept because
    % the routines may consume random numbers.
    para_xVV( :, hstp)  =   max_tau_gene(x, y_training,  sigma_t, Nstarts);
    theta_xVV(:, hstp)  = max_theta_only(x, y_training,  sigma_t, Nstarts);
    disp({'Step:', hstp});
end

% Calculate the MSE.
% For each of 100 replicates, one time index is knocked out from every group
% of repeated measurements, and loocv_mse is evaluated on the held-out genes
% of every fold with (a) the para_xVV fit and (b) the theta-only fit.
numrep = [2     2     3     3     2     3     3     3     3];  % size of each group
group_end = cumsum(numrep)';          % last index of each group: [2 4 7 ... 24]'
all_idx = (1:length(time_point))';
for rstp = 1:100
    % Assumes rand_integer(numrep, 1, 9) returns a 1x9 vector whose k-th entry
    % lies in 1..numrep(k), so each ind_knock(k) falls inside group k -- TODO confirm.
    ind_knock = 1 + group_end - rand_integer(numrep, 1, 9)';

    % luo is 1 for time indices that are kept and 0 for knocked-out ones;
    % indk lists the kept indices. Neither is used again in this section.
    luo = double(~ismember(all_idx, ind_knock));
    indk = find(luo);

    % Loop over the same K folds that were used for parameter estimation
    % (previously hard-coded as 1:10, which breaks if K is changed).
    for hstp = 1:K
        y_test = y(:, random_cluster00 == hstp);   % genes held out in this fold
        para_xv   = para_xVV(:, hstp);
        theta_xv  = theta_xVV(:, hstp);
        err_ww(:, hstp) = loocv_mse(x, y_test, para_xv, ind_knock)';
        % NOTE(review): the 24 leading zeros presumably stand in for the entries
        % that the theta-only fit does not estimate -- confirm against loocv_mse.
        err2_ww(:, hstp) = loocv_mse(x, y_test, [zeros(24,1); theta_xv], ind_knock)';
    end
    % Average over folds; keep one column per replicate.
    err_ww_m(:, rstp) = mean(err_ww, 2);
    err2_ww_m(:, rstp) = mean(err2_ww, 2);
    ind_knockVec(:, rstp) = ind_knock;
end
% Average the MSE over the replicates in which each time index was knocked out.
% rati is the relative reduction of err_ww_m with respect to err2_ww_m.
rati = (err2_ww_m - err_ww_m)./err2_ww_m;
n_pts = length(x);
% Preallocate with NaN so that an index that was never knocked out is reported
% as missing instead of raising an indexing error or keeping a stale value.
err_rep  = nan(1, n_pts);
err2_rep = nan(1, n_pts);
rati_rep = nan(1, n_pts);
rati_std = nan(1, n_pts);
for i = 1:n_pts
    % (a, b) are the (row, replicate) positions where index i was knocked out.
    [a, b] = find(ind_knockVec == i);
    if isempty(a)
        continue;
    end
    % Address exactly the matched (row, replicate) pairs. Indexing with (a, b)
    % directly would select the full cross product of rows and columns.
    idx = sub2ind(size(err_ww_m), a, b);
    err_rep(i)  = mean(err_ww_m(idx));
    err2_rep(i) = mean(err2_ww_m(idx));
    rati_rep(i) = mean(rati(idx));
    rati_std(i) = std(rati(idx));
end

% Plot the improvement ratio with its error bars: first lay out the bar positions.
% Bars within a group sit wth apart; the first bar of the next group is placed
% wth2 after the last bar of the previous group.
% time_p2 is not defined in this script (presumably loaded with the data -- confirm).
numrep = [2; 2; 3; 3; 2; 3; 3; 3; 3];   % number of bars in each group
barbase = 2*(1:length(time_p2));        % not used in the visible code
wth = 1;
wth2 = 1.8;
barx = [0; wth];                        % first group (two bars) starts at 0
for i = 2:length(time_p2)
    switch numrep(i)
        case 2
            barx = [barx; barx(end) + wth2 + [0; wth]];
        case 3
            barx = [barx; barx(end) + wth2 + [0; wth; 2*wth]];
    end
end

% Bar chart of the mean relative MSE reduction per knocked-out time index,
% with red error bars of one standard deviation (rati_std).
figure
subplot(10,1,1:5);
hold on;
bar(barx, rati_rep, 1.2*wth);
errorbar(barx, rati_rep, rati_std,'.', 'LineWidth', 1.5, 'Color', 'red');
colormap summer
v = axis;
ylabel('Percentage Reduction of MSE');
% Keep the automatic y-limits; pad the x-range by half a bar width on each side.
axis([barx(1)- wth/2, barx(end)+ wth/2, v(3:4)]);

% Freeze the y-ticks and build the percent labels from the actual tick values.
% A fixed label list combined with automatic ticks attaches labels to the wrong
% values whenever the data range or figure size changes the tick locations.
ytick = get(gca, 'YTick');
ylab = cell(size(ytick));
for k = 1:numel(ytick)
    if ytick(k) == 0
        ylab{k} = '0';
    else
        ylab{k} = sprintf('%g%%', 100*ytick(k));
    end
end
set(gca, 'YTick', ytick, 'YTickLabel', ylab);

% One x-tick per bar, labelled as <group>-<replicate>.
set(gca, 'XTick', barx);
set(gca,'XTickLabel', {'23-1', '23-2', '25-1', '25-2', '27-1', '27-2',...
                      '27-3', '29-1', '29-2', '29-3', '34-1', '34-2',...
                      '37-1', '37-2', '37-3', '39-1', '39-2', '39-3',...
                      '41-1', '41-2', '41-3', '44-1', '44-2', '44-3'});

% rotateticklabel is not defined in this file; it is expected on the MATLAB path.
rotateticklabel(gca, 60) 


