京东6.18大促主会场领京享红包更优惠

 找回密码
 立即注册

QQ登录

只需一步,快速开始

查看: 4929|回复: 0

如何学习SVM(支持向量机)以及改进实现SVM算法程序

[复制链接]

10

主题

0

回帖

10

积分

新手上路

积分
10
发表于 2019-5-8 03:16:31 | 显示全部楼层 |阅读模式 来自 中国
雷锋网 AI 科技评论按,本文为韦易笑在知乎问题如何学习SVM(支持向量机)以及改进实现SVM算法程序下面的回复,雷锋网 AI 科技评论获其授权转载。以下为正文:: S( ~" L% r9 W' T% S/ M) q; H$ t! O
学习 SVM 的最好方法是实现一个 SVM,可讲理论的很多,讲实现的太少了。
; t+ R' i8 o2 M/ U9 H假设你已经读懂了 SVM 的原理,并了解公式怎么推导出来的,比如到这里:
! k' Z3 A5 J, T8 n4 ~3 P
* ?4 P1 V+ T3 _. v8 YSVM 的问题就变成:求解一系列满足约束的 alpha 值,使得上面那个函数可以取到最小值。然后记录下这些非零的 alpha 值和对应样本中的 x 值和 y 值,就完成学习了,然后预测的时候用:
7 y6 r- V9 ?# _3 w! r% D; R2 j8 t( D* t
上面的公式计算出 f(x) ,如果返回值 > 0 那么是 +1 类别,否则是 -1 类别,先把这一步怎么来的,为什么这么来找篇文章读懂,不然你会做的一头雾水。% P! O1 |! W0 f$ k
那么剩下的 SVM 实现问题就是如何求解这个函数的极值。方法有很多,我们先找个起点,比如 Platt 的 SMO 算法,它后面有伪代码描述怎么快速求解 SVM 的各个系数。
' `+ d6 T1 t7 N. J/ H; H) _第一步:实现传统的 SMO 算法
/ G, Y1 r5 J( Q+ N6 _4 ~2 s- T& e现在大部分的 SVM 开源实现,源头都是 platt 的 smo 算法,读完他的文章和推导,然后照着伪代码写就行了,核心代码没几行:
, g9 y. B3 l0 Q; [procedure takeStep(i1,i2)
1 b. U5 G6 x3 |: W& o" n if (i1 == i2) return 0/ D- s" D, L# ?+ Y7 c( S
alph1 = Lagrange multiplier for i1
5 {. ^1 G) }$ }) W2 i. ?& X) u& Q! _  [. s) ?+ |8 `
y1 = target[i1]
8 N( p0 m. j/ f( ?8 a1 Y8 R; `& ?& T: k
E1 = SVM output on point[i1] – y1 (check in error cache)
' {6 r3 F" c5 T* W5 y$ \. J, n8 A6 P/ z% ?) N' Y  `  L
s = y1*y2
' Q& `$ @# c9 [8 r+ t' a( c% g, w* w' n$ A2 C
Compute L, H via equations (13) and (14)
, o' Z1 ~  V7 A# K2 Y6 L
, x( v4 t2 f$ ]" g* ?& Rif (L == H)
' V* `. W0 X+ u  ?# v
0 P+ N" R4 Y8 ]8 |) Treturn 0
- G) \  \' N% M% F
% h* {3 t- C4 n, Q% dk11 = kernel(point[i1],point[i1])* W5 ]2 O4 t  E1 C7 Y% o( g0 E
1 B6 H! G( U% Y$ L# G7 a
k12 = kernel(point[i1],point[i2])) O  d8 q5 V" {, r) h

6 O- z3 R4 g/ u, ~. R  J: e3 z# bk22 = kernel(point[i2],point[i2])0 \+ e. ~% l% d2 }% k

- l5 |. @% S$ q$ G. r7 m, seta = k11+k22-2*k12
2 _& p# j* |) y- s# U/ I! u3 S0 V4 X! J
if (eta > 0)
- Q! i# g% e* r7 h. T
2 h3 F) [; m/ s+ l' m) S4 X7 \{( }( i7 ~9 _$ N# L2 `5 t: p
4 M: W$ P6 s/ P1 k$ q& Z3 [
a2 = alph2 + y2*(E1-E2)/eta
3 E) l& ]0 g; A4 y! G/ W9 ?% X% T4 p7 L6 b
if (a2 < L) a2 = L( x) S# s, u; P( m( g$ U! w
9 t. k) Y0 P5 q. y- q/ ]4 `
else if (a2 > H) a2 = H# }$ L, n* W- w! f2 N$ @1 i# L

. o3 W8 Z) l6 x1 a) C9 p' I& _}
" T* ]+ O, l& e2 h9 h
1 @' f* [. P: _% W: y% v2 Helse+ `# t3 \/ |8 f! s# Y

& i' H. V, j; d5 r9 N% @. U9 I7 O{5 R; w* F' ?8 E5 R

" C3 s- n7 G4 p& V% HLobj = objective function at a2=L) Q$ F2 U( P0 m5 o' Q
! I; I/ U. J1 v  K
Hobj = objective function at a2=H
: p0 g$ k8 M% L7 H
2 T- K# Q2 M) N  P" d( H' t2 wif (Lobj < Hobj-eps), ]# \& s8 z5 t& F  W" L, H
! V0 S; w+ a/ U9 T
a2 = L
9 @# F+ b0 b) I) k7 q: k; t8 Y& I5 t  B' g
else if (Lobj > Hobj+eps)
2 P$ v, ?$ l0 t% p1 f& g  x- k  \0 s$ W
a2 = H/ d, K& E4 b: A
. B( V" g3 ~) o+ J0 g& Z! F+ `/ h; d
else- C% X1 E; I0 Q+ _
# E' k- U1 h+ I0 s
a2 = alph2
& Y( c/ ~' p! q% r. |
: g3 [  ~' V2 |  j) M# A* _! G}& y$ K+ J5 z/ o! f: N

5 ^8 W8 p8 {2 |0 D3 H% z, ~. Kif (|a2-alph2| < eps*(a2+alph2+eps))
9 T8 Y. P* w8 e! p
" R+ h. K2 X5 ~return 0& `* u/ q1 s  c' b- E

4 k/ Y- u; a) `% e0 Y; Sa1 = alph1+s*(alph2-a2)
6 u. q# l; s- p3 \. `) L8 Z% d! n7 [- |3 C7 b  D3 o; L
Update threshold to reflect change in Lagrange multipliers
( y1 g. D; N; ^/ p
8 F- v( V3 v5 n9 g. S0 sUpdate weight vector to reflect change in a1 & a2, if SVM is linear
; n7 X% `" z8 f( j( ]) ]- q( I# U' Q4 [; @$ i) |7 K# @8 C% p4 i
Update error cache using new Lagrange multipliers) U, v% f2 I$ h" E& C% Y  i4 }
& X8 T) S! A: o( D
Store a1 in the alpha array2 ~& V- R. l3 g4 D8 L* }3 n
3 h; t9 O4 i$ @* v. ]1 c# I
Store a2 in the alpha array
$ ]4 e& v  j# a- m
; W2 Y/ a% ]0 |1 y! Sreturn 1
, o, B9 p9 R4 a1 L/ a& I0 N, Y1 U6 Pendprocedure
6 V6 V- ^: F  Z核心代码很紧凑,就是给定两个 ai, aj,然后迭代出新的 ai, aj 出来,还有一层循环会不停的选择最需要被优化的系数 ai, aj,然后调用这个函数。如何更新权重和 b 变量(threshold)文章里面都有说,再多调试一下,可以用 python 先调试,再换成 C/C++,保证得到一个正确可用的 SVM 程序,这是后面的基础。
1 l) h% D" }- s% ?, u4 O第二步:实现核函数缓存
/ c4 ?. V3 R, n* A$ k观察下上面的伪代码,开销最大的就是计算核函数 K(xi, xj),有些计算又反复用到,一个 100 个样本的数据集求解,假设总共要调用核函数 20 万次,但是 xi, xj 的组和只有 100x100=1 万种,有缓存的话你的效率可以提升 20 倍。
; ^8 O. U4 Q2 U样本太大时,如果你想存储所有核函数的组和,需要 N*N * sizeof(double) 的空间,如果训练集有 10 万个样本,那么需要 76 GB 的内存,显然是不可能实现的,所以核函数缓存是一个有限空间的 LRU 缓存,SVM 的 SMO 求解过程中其实会反复用到特定的几个有限的核函数求解,所以命中率不用担心。
' [5 A! N6 R" J  d( T, h有了这个核函数缓存,你的 SVM 求解程序能瞬间快几十倍。
: y! d- i, z" O2 ?) u第三步:优化误差值求解# A" k( s, p( i: c- C" w, b8 r
注意看上面的伪代码,里面需要计算一个估计值和真实值的误差 Ei 和 Ej,他们的求解方法是:
( D+ K( @0 w* v9 mE(i) = f(xi) - yi) `+ `3 l" g. t
这就是目前为止 SMO 这段为代码里代价最高的函数,因为回顾下上面的公式,计算一遍 f(x) 需要 for 循环做乘法加法。
; w+ b8 g9 h$ D+ @platt 的文章建议是做一个 E 函数的缓存,方便后面选择 i, j 时比较,我看到很多入门版本 SVM 实现都是这么做。其实这是有问题的,后面我们会说到。最好的方式是定义一个 g(x) 令其等于:
6 K* b+ t! Q) X  a* P1 c0 R$ F
' i+ ]( I" J9 p# r, L$ }也就是 f(x) 公式除了 b 以外前面的最费时的计算,那么我们随时可以计算误差:
3 ~" N: ^! q6 _- B5 d7 s% OE(j) = g(xj) + b - yj- H2 K5 o7 I1 m$ O
所以最好的办法是对 g(x) 进行缓存,platt 的方法里因为所有 alpha 值初始化成了 0,所以 g(x) 一开始就可以全部设置成 0,稍微观察一下 g(x) 的公式,你就会发现,因为去掉了 b 的干扰,而每次 SMO 迭代更新 ai, aj 参数时,这两个值都是线性变化的,所以我们可以给 g(x) 求关于 a 的偏导,假设 ai,aj 变化了步长 delta,那么所有样本对应的 g(x) 加上 delta 乘以针对 ai, aj 的偏导数就行了,具体代码类似:% q$ w* z; N- N( ]  X- X
double Kik = kernel(i, k);* P- U8 l% E) U* @$ S0 t
double Kjk = kernel(j, k);
: \/ S# P2 g# ^8 d; O$ V  ~' r. dG[k] += delta_alpha_i * Kik * y + delta_alpha_j * Kjk * y[j];
( Y9 q  d' a( ^9 M8 [0 n9 D把这段代码放在 takeStep 后面,每次成功更新一对 ai, aj 以后,更新所有样本对应的 g(x) 缓存,这样通过每次迭代更新 g(x) 避免了大量的重复计算。
6 {1 w) Z+ _- f这其实是很直白的一种优化方式,我查了一下,有人专门发论文就讲了个类似的方法。
/ c3 G& p# g6 b( y3 `第四步:实现冷热数据分离2 E' x2 b- j; U: O1 R
Platt 的文章里也证明过一旦某个 alpha 出于边界(0 或者 C)的时候,就很不容易变动,而且伪代码也是优先在工作集里寻找 > 0 and < C 的 alpha 值进行优化,找不到了,再对工作集整体的 alpha 值进行迭代。) o; X; v% ]! w5 y0 _2 F& s
那么我们势必就可以把工作集分成两个部分,热数据在前(大于 0 小于 C 的 alpha 值),冷数据在后(小于等于 0 或者大于等于 C 的 alpha)。6 y- D( L% k6 C5 @( |( r2 p' c
随着迭代加深,会发现大部分时候只需要在热数据里求解,并且热数据的大小会逐步不停的收缩,所以区分了冷热以后 SVM 大部分都在针对有限的热数据迭代,偶尔不行了,再全部迭代一次,然后又回到冷热迭代,性能又能提高不少。3 x: h( g) L: z8 u5 {
第五步:支持 Ensemble3 N0 s( {. a8 ^
大家都知道,通过 Ensemble 可以让多个不同的弱模型组和成一个强模型,而传统 SVM 实现并不能适应一些类似 AdaBoost 的集成方法,所以我们需要做一些改动。可以让外面针对某一个分类传入一个“权重”过来,修正 SVM 的识别结果。, G; e# @1 ?' Y
最传统的修改方式就是将不等式约束 C 分为 Cp 和 Cn 两个针对 +1 分类的 C 及针对 -1 分类的 C。修改方式是直接用原始的 C 乘以各自分类的权重,得到 Cp 和 Cn,然后迭代时,不同的样本根据它的 y 值符号,用不同的 C 值带入计算。
& I/ M7 V' `. B" `/ t" }这样 SVM 就能用各种集成方法同其他模型一起组成更为强大精准的模型了。0 x* l# ?+ U* s2 K- \' z
实现到这一步你就得到了功能上和性能上同 libsvm 类似的东西,接下来我们继续优化。( e, f7 }. Y( M2 A5 ?+ @1 ?
第六步:继续优化核函数计算
' ?3 x  ~/ t5 n. v核函数缓存非常消耗内存,libsvm 数学上已经没得挑了,但是工程方面还有很大改进余地,比如它的核缓存实现。
1 H+ z! E0 z5 R+ [由于标准 SVM 核函数用的是两个高维矢量的内积,根据内积的几个条件,SVM 的核函数又是一个正定核,即 K(xi, xj) = K(xj, xi),那么我们同样的内存还能再多存一倍的核函数,性能又能有所提升。
7 O5 Z% @+ T0 |  Z针对核函数的计算和存储有很多优化方式,比如有人对 NxN 的核函数矩阵进行采样,只计算有限的几个核函数,然后通过插值的方式求解出中间的值。还有人用 float 存储核函数值,又降低了一倍空间需求。
8 E' Y8 [+ \7 ^) _. R第七步:支持稀疏向量和非稀疏向量
. m$ _5 [1 J6 s. ]. D$ g对于高维样本,比如文字这些,可能有上千维,每个样本的非零特征可能就那么几个,所以稀疏向量会比较高效,libsvm 也是用的稀疏向量。2 M. E) l3 o( X' }
但是还有很多时候样本是密集向量,比如一共 200 个特征,大部分样本都有 100个以上的非零特征,用稀疏向量存储的话就非常低效了,openCV 的 SVM 实现就是非稀疏向量。
6 u" M/ m& x+ p) W5 w非稀疏向量直接是用数组保存样本每个特征的值,在工程方面就有很多优化方式了,比如用的最多的求核函数的时候,直接上 SIMD 指令或者 CUDA,就能获得更好的计算性能。+ ^; [! m; u' w+ e
所以最好的方式是同时支持稀疏和非稀疏,兼顾时间和空间效率,对不同的数据选择最适合的方式。
: a( U7 \4 B/ c. f第八步:针对线性核进行优化
2 S. \- z, Z& w: a# p. \* I传统的 SMO 方法,是 SVM 的通用求解方法,然而针对线性核,就是:
2 t. t! h- d/ M2 e: B1 BK(xi, xj) = xi . xj6 ^4 @' B, b+ h8 o4 O% w
还有很多更高效的求解思路,比如 Pegasos 算法就用了一种类似随机梯度下降的方法,快速求 SVM 的解权重 w,如果你的样本适合线性核,使用一些针对性的非 SMO 算法可以极大的优化 SVM 求解,并且能处理更加庞大的数据集,LIBLINEAR 就是做这件事情的。4 @( H# K* z) k9 p' ]) `' l0 M
同时这类算法也适合 online 训练和并行训练,可以逐步更新增量训练新的样本,还可以用到多核和分布式计算来训练模型,这是 SMO 算法做不到的地方。1 M) D1 G/ k9 {( ?
但是如果碰到非线性核,权重 w 处于高维核空间里(有可能无限维),你没法梯度下降迭代 w,并且 pegasos 的 pdf 里面也没有提到如何用到非线性核上,LIBLINEAR 也没有办法处理非线性核。7 W, c4 t0 c/ \5 I2 J4 ^
或许哪天出个数学家又找到一种更好的方法,可以用类似 pegasos 的方式求解非线性核,那么 SVM 就能有比较大的进展了。
( {4 O% H# k$ I; k后话
% v( S$ ?0 S8 s) V, H上面八条,你如果实现前三条,基本就能深入理解 SVM 的原理了,如果实现一大半,就可以得到一个类似 libsvm 的东西,全部实现,你就能得到一个比 libsvm 更好用的 SVM 库了。
) u  V7 S2 G& i4 }% A1 m, ~上面就是如何实现一个相对成熟的 SVM 模型的思路,以及配套优化方法,再往后还有兴趣,可以接着实现支持向量回归,也是一个很有用的东西。
( d% v: n1 t0 @$ c) P: y/ v+ [; d$ ]: X7 q' `+ v5 o
来源:http://www.yidianzixun.com/article/0Lv0UIiC" \  r4 H6 P0 `; |
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×

帖子地址: 

梦想之都-俊月星空 优酷自频道欢迎您 http://i.youku.com/zhaojun917
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|手机版|小黑屋|梦想之都-俊月星空 ( 粤ICP备18056059号 )|网站地图

GMT+8, 2026-4-20 11:15 , Processed in 0.040072 second(s), 24 queries .

Powered by Mxzdjyxk! X3.5

© 2001-2026 Discuz! Team.

快速回复 返回顶部 返回列表