function [hme] = hmeFit(hme,x,y,iter,tol,split,vis,prefix)
% [hme] = hmeFit(hme,x,y,iter,tol,split,vis,prefix)
%
% Fit a k-class HME model.
%
% Any optional argument may be omitted or passed as [] to select its
% default value.
%
% INPUTS
%   hme       HME model with params initialized
%   x         dxn matrix of samples
%   y         kxn matrix of class assignments
%   [iter]    Number of EM iterations.  If a vector, then perform
%             at least min(iter) and at most max(iter) iterations.
%             Default value is [10 50].  A single number is valid.
%   [tol]     Stop EM when the mean test-set log likelihood improves
%             by a fraction less than tol, i.e. (old-new)/new < tol.
%             Used when a range of iterations is given and the test
%             set is not empty.
%             Default value is 1e-4.
%   [split]   Value in (0,1] giving fraction of data to use as
%             the training set.  The remaining data is used as
%             the test set.  Default value is 0.5.
%   [vis]     Visualization level:
%               0 = none [default]
%               1 = plot log likelihood vs. iteration
%               2 = image model details
%   [prefix]  Prefix for visualization output files.
%
% OUTPUTS
%   hme       HME model with params fitted
%
% Calls logistK, logistK_eval (through Estep/Mstep), hmeEval, hmeVis.
%
% David Martin <dmartin@eecs.berkeley.edu>
% Charless Fowlkes <fowlkes@eecs.berkeley.edu>
% May 7, 2002

% Copyright (C) 2002 David R. Martin <dmartin@eecs.berkeley.edu>
% Copyright (C) 2002 Charless C. Fowlkes <fowlkes@eecs.berkeley.edu>
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful, but
% WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
% General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
% 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.

error(nargchk(3,8,nargin));

% fill in defaults; [] selects the default as well
if nargin < 4 || isempty(iter), iter = [10 50]; end
if nargin < 5 || isempty(tol), tol = 1e-4; end
if nargin < 6 || isempty(split), split = 0.5; end
if nargin < 7 || isempty(vis), vis = 0; end
if nargin < 8 || isempty(prefix), prefix = ''; end

[k,n] = size(y);
if size(x,2) ~= n,
  error('x and y must have the same number of columns (samples)!');
end

minIter = min(iter);
maxIter = max(iter);

% split the samples into training and test sets
n1 = round(n*split);
n2 = n - n1;
if n1 <= 0, error('there is no training data!'); end
if n1 > n, error('split must be in (0,1]!'); end
perm = randperm(n);
ind1 = perm(1:n1);
ind2 = perm(n1+1:n);
x1 = x(:,ind1);
x2 = x(:,ind2);
y1 = y(:,ind1);
y2 = y(:,ind2);

% mean log likelihood per iteration for training (L1) and test (L2) sets
L1 = zeros(1,maxIter);
L2 = zeros(1,maxIter);

if vis > 0,
  h = figure;
  set(h,'DoubleBuffer','on');
end

% lli starts at -inf so that the first stopping test can never fire
% before two real test-set evaluations are available
lli = -inf;
for i = 1:maxIter,

  % EM using the training set
  hme = Estep(hme,x1,y1);
  hme = Mstep(hme,x1,y1,ones(1,n1));

  if vis > 1,
    % visualize the current state of the model
    hmeVis(hme,k,[0 1],[0 1],x1,prefix);
  end

  if n2 > 0 && (i > minIter || vis > 0),
    % evaluate the model using the test set
    [post2,lik2,lli2] = hmeEval(hme,k,x2,y2);
    L2(i) = lli2/n2;
    lli_prev = lli;
    lli = lli2/n2;
    disp(sprintf('hme iter=%d lli=%g',i,lli));
  else
    disp(sprintf('hme iter=%d',i));
  end

  if vis > 0,
    % evaluate the model using the training set
    [post1,lik1,lli1] = hmeEval(hme,k,x1,y1);
    L1(i) = lli1/n1;
    % plot log likelihood for training,test sets over time
    figure(h); hold off;
    plot(L1(1:i),'b-o');
    if n2 > 0,
      hold on;
      plot(L2(1:i),'r-o');
      shown = [L1(1:i) L2(1:i)];
    else
      shown = L1(1:i);
    end
    % axis() rejects non-increasing or non-finite limits, which
    % happens for a single-iteration run or a flat/infinite curve;
    % keep the automatic limits in those cases instead of failing
    ylo = 1.01*min(shown);
    yhi = 0.99*max(shown);
    if maxIter > 1 && isfinite(ylo) && isfinite(yhi) && ylo < yhi,
      axis([1 maxIter ylo yhi]);
    end
    xlabel('iteration');
    ylabel('mean log likelihood');
    if n2 > 0,
      % 'SouthEast' is the named form of the obsolete numeric
      % position 4 (lower right corner inside the axes)
      legend('training set','test set','Location','SouthEast');
    end
    print(h,'-depsc',[prefix 'lli']);
  end

  if n2 > 0 && i > minIter,
    % stop iterating once the relative improvement of the mean
    % log-likelihood on the test set falls below tol (this also
    % stops as soon as the test-set log-likelihood gets worse)
    if (lli_prev-lli)/lli < tol, break, end
  end

end

if vis == 1,
  % visualize the current state of the model
  hmeVis(hme,k,[0 1],[0 1],x1,prefix);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% requires hme.param in all nodes
% creates hme.clik in
% internal nodes
% creates hme.lik in all nodes
function [hme] = Estep(hme,x,y)
% E-step of EM, applied recursively to the tree rooted at hme.
% Leaves get hme.lik from logistK_eval on (x,y).  Internal nodes get
% hme.lik as the gate-weighted sum of their children's likelihoods and
% hme.clik as the per-child share of that sum.

if hme.leaf,
  % expert node: only the likelihood output of the evaluation is kept
  [expertPost,expertLik] = logistK_eval(hme.param,x,y);
  hme.lik = expertLik;
  return
end

nChildren = length(hme.children);

% every child needs its own likelihood before this node can be scored
for c = 1:nChildren,
  hme.children{c} = Estep(hme.children{c},x,y);
end

% gating function evaluated on the samples
gateOut = logistK_eval(hme.param,x);        % bxn matrix

% gate-weighted child likelihoods, one row per child
hme.clik = zeros(size(gateOut));
for c = 1:nChildren,
  hme.clik(c,:) = gateOut(c,:) .* hme.children{c}.lik;
end

% node likelihood is the column-wise total of the weighted rows
hme.lik = sum(hme.clik,1);

% normalize each row by the node likelihood; eps guards against
% division by zero
denom = hme.lik + eps;
for c = 1:nChildren,
  hme.clik(c,:) = hme.clik(c,:) ./ denom;
end

% end Estep

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% requires hme.{lik,clik} from Estep
% sets hme.param in all nodes
% destroys hme.{lik,clik} in all nodes
function [hme] = Mstep(hme,x,y,w)
% M-step of EM, applied recursively to the tree rooted at hme.
% w holds one weight per sample; each child is trained with w scaled
% by that child's row of hme.clik.

if hme.leaf,
  % expert node: refit on the weighted samples, then drop the
  % Estep scratch field
  hme.param = logistK(x,y,w,hme.param);
  hme = rmfield(hme,{'lik'});
  return
end

% internal node: train the children first, each with its own weights
for c = 1:length(hme.children),
  childW = w .* hme.clik(c,:);
  hme.children{c} = Mstep(hme.children{c},x,y,childW);
end

% gating function is fitted with hme.clik as its targets
hme.param = logistK(x,hme.clik,w,hme.param);

% blow away stuff from Estep
hme = rmfield(hme,{'lik','clik'});

% end Mstep

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%