首先考虑两类情况。想象一下寻找形式为$ \ phi(w ^ \ top x)$的特征,其中$ w \ in \ mathbb {R} ^ d $是一个“weight vector”$ \ phi $是某种非线性。定义良好功能的简单标准是什么?一个想法是该功能在一个类上具有较小的平均值,而在另一类上具有较大的平均值。假设$ \ phi $为非负数,则建议最大化比率\ [
w ^ * = \ arg \ max_w \ frac {\ mathbb {E} [\ phi(w ^ \ top x)| y = 1]} {\ mathbb {E} [\ phi(w ^ \ top x)| y = 0]}。
\] 对于$ \ phi(z)= z ^ 2 $的特定选择,这是易处理的,因为它会导致 瑞利商 在两个有条件的第二时刻之间,\ [
w ^ * = \ arg \ max_w \ frac {w ^ \ top \ mathbb {E} [x x ^ \ top | y = 1] w} {w ^ \ top \ mathbb {E} [x x ^ \ top | y = 0] w},
\] 可以通过广义特征值分解来解决。广义特征值问题已经在机器学习和其他领域进行了广泛的研究,并且上述想法与许多其他建议(例如, 费希尔LDA ),但它是不同的,而且在经验上更有效。我将引导您参考该论文进行更深入的讨论,但是我会提到在论文被接受之后,有人指出了与 CSP ,这是时间序列分析中的一种技术(参见 传道书1:4-11 )。
此过程产生的功能通过了气味测试。例如,从上的原始像素表示开始 mnist , the 权重向量s can be visualized as images; the first 权重向量 for discriminating 3 vs. 2 looks like
看起来像笔触,参见图1D 兰萨托等等
我们在本文中还提出了其他一些意见。首先是,如果相关的广义特征值较大,则瑞利商的多个孤立极小值将很有用,即可以从瑞利商中提取多个特征。第二个是,对于适度的$ k $,我们可以独立提取每个类对的特征,并使用所有得到的特征获得良好的结果。第三是所得方向具有附加结构,该附加结构不能被平方非线性完全捕获,这会激发(单变量)基函数展开。第四,一旦原始表示增加了其他功能,就可以重复该过程,有时还会产生其他改进。最后,我们可以将其与随机特征图组合起来,以近似RKHS中的相应操作,有时还会产生其他改进。在本文中,我们还提出了一个抛弃式的评论,即在映射简化样式的分布式框架中轻松完成计算类条件第二矩矩阵的操作,但这实际上是我们朝这个方向进行探索的主要动机,但事实并非如此。不太适合本文的论述,因此我们不再强调它。
将上述想法与Nikos的多类别预处理梯度学习结合起来, 以前的帖子 ,导致出现以下Matlab脚本,该脚本在(排列不变)mnist上获得91个测试错误。注意:您需要下载 mnist _all.mat 从Sam Roweis的网站运行。
function calgevsquared more off; clear all; close all; start=tic; load('mnist_all.mat'); xxt=[train0; train1; train2; train3; train4; train5; ... train6; train7; train8; train9]; xxs=[test0; test1; test2; test3; test4; test5; test6; test7; test8; test9]; kt=single(xxt)/255; ks=single(xxs)/255; st=[size(train0,1); size(train1,1); size(train2,1); size(train3,1); ... size(train4,1); size(train5,1); size(train6,1); size(train7,1); ... size(train8,1); size(train9,1)]; ss=[size(test0,1); size(test1,1); size(test2,1); size(test3,1); ... size(test4,1); size(test5,1); size(test6,1); size(test7,1); ... size(test8,1); size(test9,1)]; paren = @(x, varargin) x(varargin{:}); yt=zeros(60000,10); ys=zeros(10000,10); I10=eye(10); lst=1; for i=1:10; yt(lst:lst+st(i)-1,:)=repmat(I10(i,:),st(i),1); lst=lst+st(i); end lst=1; for i=1:10; ys(lst:lst+ss(i)-1,:)=repmat(I10(i,:),ss(i),1); lst=lst+ss(i); end clear i st ss lst clear xxt xxs clear train0 train1 train2 train3 train4 train5 train6 train7 train8 train9 clear test0 test1 test2 test3 test4 test5 test6 test7 test8 test9 [n,k]=size(yt); [m,d]=size(ks); gamma=0.1; top=20; for i=1:k ind=find(yt(:,i)==1); kind=kt(ind,:); ni=length(ind); covs(:,:,i)=double(kind'*kind)/ni; clear ind kind; end filters=zeros(d,top*k*(k-1),'single'); last=0; threshold=0; for j=1:k covj=squeeze(covs(:,:,j)); l=chol(covj+gamma*eye(d))'; for i=1:k 如果 j~=i covi=squeeze(covs(:,:,i)); C=l\covi/l'; CS=0.5*(C+C'); [v,L]=eigs(CS,top); V=l'\v; take=find(diag(L)>=threshold); batch=length(take); fprintf('%u,%u,%u ', i, j, batch); filters(:,last+1:last+batch)=V(:,take); last=last+batch; end end fprintf('\n'); end clear covi covj covs C CS V v L % NB: augmenting kt/ks 与 .^2 terms is very slow 和 doesn't help filters=filters(:,1:last); ft=kt*filters; clear kt; kt=[ones(n,1,'single') sqrt(1+max(ft,0))-1 sqrt(1+max(-ft,0))-1]; clear ft; fs=ks*filters; clear ks filters; ks=[ones(m,1,'single') sqrt(1+max(fs,0))-1 sqrt(1+max(-fs,0))-1]; clear fs; [n,k]=size(yt); [m,d]=size(ks); for i=1:k ind=find(yt(:,i)==1); kind=kt(ind,:); ni=length(ind); covs(:,:,i)=double(kind'*kind)/ni; clear ind kind; end filters=zeros(d,top*k*(k-1),'single'); last=0; threshold=7.5; for j=1:k covj=squeeze(covs(:,:,j)); l=chol(covj+gamma*eye(d))'; for i=1:k 如果 j~=i covi=squeeze(covs(:,:,i)); C=l\covi/l'; CS=0.5*(C+C'); [v,L]=eigs(CS,top); V=l'\v; take=find(diag(L)>=threshold); batch=length(take); fprintf('%u,%u,%u ', i, j, batch); filters(:,last+1:last+batch)=V(:,take); last=last+batch; end end fprintf('\n'); end fprintf('gamma=%g,top=%u,threshold=%g\n',gamma,top,threshold); fprintf('last=%u filtered=%u\n', last, size(filters,2) - last); clear covi covj covs C CS V v L filters=filters(:,1:last); ft=kt*filters; clear kt; kt=[sqrt(1+max(ft,0))-1 sqrt(1+max(-ft,0))-1]; clear ft; fs=ks*filters; clear ks filters; ks=[sqrt(1+max(fs,0))-1 sqrt(1+max(-fs,0))-1]; clear fs; trainx=[ones(n,1,'single') kt kt.^2]; clear kt; testx=[ones(m,1,'single') ks ks.^2]; clear ks; C=chol(0.5*(trainx'*trainx)+sqrt(n)*eye(size(trainx,2)),'lower'); w=C'\(C\(trainx'*yt)); pt=trainx*w; ps=testx*w; [~,trainy]=max(yt,[],2); [~,testy]=max(ys,[],2); for i=1:5 xn=[pt pt.^2/2 pt.^3/6 pt.^4/24]; xm=[ps ps.^2/2 ps.^3/6 ps.^4/24]; c=chol(xn'*xn+sqrt(n)*eye(size(xn,2)),'lower'); ww=c'\(c\(xn'*yt)); ppt=SimplexProj(xn*ww); pps=SimplexProj(xm*ww); w=C'\(C\(trainx'*(yt-ppt))); pt=ppt+trainx*w; ps=pps+testx*w; [~,yhatt]=max(pt,[],2); [~,yhats]=max(ps,[],2); errort=sum(yhatt~=trainy)/n; errors=sum(yhats~=testy)/m; fprintf('%u,%g,%g\n',i,errort,errors) end fprintf('%4s\t', 'pred'); for true=1:k fprintf('%5u', true-1); end fprintf('%5s\n%4s\n', '!=', 'true'); for true=1:k fprintf('%4u\t', true-1); trueidx=find(testy==true); for predicted=1:k predidx=find(yhats(trueidx)==predicted); fprintf('%5u', sum(predidx>0)); end predidx=find(yhats(trueidx)~=true); fprintf('%5u\n', sum(predidx>0)); end toc(start) end % http://arxiv.org/pdf/1309.1541v1.pdf function X = SimplexProj(Y) [N,D] = size(Y); X = sort(Y,2,'descend'); Xtmp = bsxfun(@times,cumsum(X,2)-1,(1./(1:D))); X = max(bsxfun(@minus,Y,Xtmp(sub2ind([N,D],(1:N)',sum(X>Xtmp,2)))),0); end当我在台式机上运行它时
>> calgevsquared 2,1,20 3,1,20 4,1,20 5,1,20 6,1,20 7,1,20 8,1,20 9,1,20 10,1,20 1,2,20 3,2,20 4,2,20 5,2,20 6,2,20 7,2,20 8,2,20 9,2,20 10,2,20 1,3,20 2,3,20 4,3,20 5,3,20 6,3,20 7,3,20 8,3,20 9,3,20 10,3,20 1,4,20 2,4,20 3,4,20 5,4,20 6,4,20 7,4,20 8,4,20 9,4,20 10,4,20 1,5,20 2,5,20 3,5,20 4,5,20 6,5,20 7,5,20 8,5,20 9,5,20 10,5,20 1,6,20 2,6,20 3,6,20 4,6,20 5,6,20 7,6,20 8,6,20 9,6,20 10,6,20 1,7,20 2,7,20 3,7,20 4,7,20 5,7,20 6,7,20 8,7,20 9,7,20 10,7,20 1,8,20 2,8,20 3,8,20 4,8,20 5,8,20 6,8,20 7,8,20 9,8,20 10,8,20 1,9,20 2,9,20 3,9,20 4,9,20 5,9,20 6,9,20 7,9,20 8,9,20 10,9,20 1,10,20 2,10,20 3,10,20 4,10,20 5,10,20 6,10,20 7,10,20 8,10,20 9,10,20 2,1,15 3,1,20 4,1,20 5,1,20 6,1,20 7,1,20 8,1,20 9,1,20 10,1,20 1,2,20 3,2,20 4,2,20 5,2,20 6,2,20 7,2,20 8,2,20 9,2,20 10,2,20 1,3,20 2,3,11 4,3,17 5,3,20 6,3,20 7,3,19 8,3,18 9,3,18 10,3,19 1,4,20 2,4,12 3,4,20 5,4,20 6,4,12 7,4,20 8,4,19 9,4,15 10,4,20 1,5,20 2,5,12 3,5,20 4,5,20 6,5,20 7,5,20 8,5,16 9,5,20 10,5,9 1,6,18 2,6,13 3,6,20 4,6,12 5,6,20 7,6,18 8,6,20 9,6,13 10,6,18 1,7,20 2,7,14 3,7,20 4,7,20 5,7,20 6,7,20 8,7,20 9,7,20 10,7,20 1,8,20 2,8,14 3,8,20 4,8,20 5,8,20 6,8,20 7,8,20 9,8,20 10,8,12 1,9,20 2,9,9 3,9,20 4,9,15 5,9,18 6,9,11 7,9,20 8,9,17 10,9,16 1,10,20 2,10,14 3,10,20 4,10,20 5,10,14 6,10,20 7,10,20 8,10,12 9,10,20 gamma=0.1,top=20,threshold=7.5 last=1630 filtered=170 1,0.0035,0.0097 2,0.00263333,0.0096 3,0.00191667,0.0092 4,0.00156667,0.0093 5,0.00141667,0.0091 pred 0 1 2 3 4 5 6 7 8 9 != true 0 977 0 1 0 0 1 0 1 0 0 3 1 0 1129 2 1 0 0 1 1 1 0 6 2 1 1 1020 0 1 0 0 6 3 0 12 3 0 0 1 1004 0 1 0 2 1 1 6 4 0 0 0 0 972 0 4 0 2 4 10 5 1 0 0 5 0 883 2 1 0 0 9 6 4 2 0 0 2 2 947 0 1 0 11 7 0 2 5 0 0 0 0 1018 1 2 10 8 1 0 1 1 1 1 0 1 966 2 8 9 1 1 0 2 5 2 0 4 1 993 16 Elapsed time is 186.147659 seconds.这是一个很好的混淆矩阵,可与(置换不变)mnist上的最新深度学习结果相提并论。在本文中,我们报告的数字稍差一些(96个测试错误),因为对于一篇论文,我们必须通过对训练集进行交叉验证来选择超参数,而不是像博客文章那样选择超参数。
此处所述的技术实际上仅对超薄设计矩阵有用(即,有很多示例,但没有太多特征):如果原始特征维数太大(例如,$>10 ^ 4 $)比单纯使用标准广义特征求解器变得缓慢或不可行,并且还需要其他技巧。此外,如果类数太大而不是解决$ O(k ^ 2)$广义特征值问题,那也是不合理的。我们正在努力解决这些问题,我们也很高兴将此策略扩展到结构化预测。希望我们在接下来的几届会议上能有更多的话要说。