京东6.18大促主会场领京享红包更优惠

 找回密码
 立即注册

QQ登录

只需一步,快速开始

查看: 4806|回复: 0

如何学习SVM(支持向量机)以及改进实现SVM算法程序

[复制链接]

10

主题

0

回帖

10

积分

新手上路

积分
10
发表于 2019-5-8 03:16:31 | 显示全部楼层 |阅读模式 来自 中国
雷锋网 AI 科技评论按,本文为韦易笑在知乎问题如何学习SVM(支持向量机)以及改进实现SVM算法程序下面的回复,雷锋网 AI 科技评论获其授权转载。以下为正文:
: D$ l6 q. Y% ]7 o学习 SVM 的最好方法是实现一个 SVM,可讲理论的很多,讲实现的太少了。/ n4 m$ F  o5 Y* k0 m7 d
假设你已经读懂了 SVM 的原理,并了解公式怎么推导出来的,比如到这里:4 s0 w/ W8 f2 K5 b& [0 \" v& p
2 f  z8 K6 z  }1 v/ }
SVM 的问题就变成:求解一系列满足约束的 alpha 值,使得上面那个函数可以取到最小值。然后记录下这些非零的 alpha 值和对应样本中的 x 值和 y 值,就完成学习了,然后预测的时候用:
. f' i+ H- a8 ]$ y: I. k& a& E& T8 \2 w# F0 n5 v
上面的公式计算出 f(x) ,如果返回值 > 0 那么是 +1 类别,否则是 -1 类别,先把这一步怎么来的,为什么这么来找篇文章读懂,不然你会做的一头雾水。
7 B3 b9 u4 b+ ?: j0 {那么剩下的 SVM 实现问题就是如何求解这个函数的极值。方法有很多,我们先找个起点,比如 Platt 的 SMO 算法,它后面有伪代码描述怎么快速求解 SVM 的各个系数。
5 ?8 j0 V6 _! Q4 x' l, e( h% L& v( N第一步:实现传统的 SMO 算法& ]) x" H# _$ O( W% ]
现在大部分的 SVM 开源实现,源头都是 platt 的 smo 算法,读完他的文章和推导,然后照着伪代码写就行了,核心代码没几行:
1 V3 T0 r2 F! r/ {* D* c1 tprocedure takeStep(i1,i2)' v, ]8 w" G6 V8 P, n
if (i1 == i2) return 03 _6 w* U6 G0 D# e4 c
alph1 = Lagrange multiplier for i1
- ]$ w9 H  c, C6 j- ^& @
6 K- N/ b% @2 L1 ^y1 = target[i1]
1 a! T! P+ t; m0 k# a; D
& P3 B* a) s) a- N3 {: Z( ]: ^& ^E1 = SVM output on point[i1] – y1 (check in error cache)1 y  V# A7 T+ p# ?
4 m' i$ s, b8 D$ i7 Z
s = y1*y2
5 v+ D$ h  |  R# C! ]: A$ @4 ^! }7 Z9 q  S* _" _, j6 E2 T
Compute L, H via equations (13) and (14)
  e9 x5 b' ~+ F* W* _! \1 G  m4 Q  W0 K# g: y
if (L == H)
8 x% m" f1 m# e2 }/ E
" F! L! Q6 I. l- H% m- Dreturn 0' J( s6 v; Q) y* c; O6 a9 `: p
6 ]$ d) B- Y6 J" J4 X* L
k11 = kernel(point[i1],point[i1])* v5 O3 s% \$ w8 j9 N: E
: P% {. H2 m% {8 i8 w0 ]
k12 = kernel(point[i1],point[i2])/ ~1 R7 a0 ]$ b" U

2 L. U4 ]0 |  T3 Bk22 = kernel(point[i2],point[i2])
8 f( ^- C% Z; p& Y2 h5 C! V. }; v/ O6 Z
eta = k11+k22-2*k12
8 E! t. D& B/ x( Z( b1 M4 f, G( n4 N. @- Y2 j# o$ z; w
if (eta > 0)
* l7 x6 y' @& F+ r
* ?4 i4 U3 {/ A& c$ Y$ d{
; E8 B/ f# }- d4 k" M. K) o1 ?# x7 N! c$ t
a2 = alph2 + y2*(E1-E2)/eta" {+ [& b6 [' B4 m& p
" N& i1 F$ U, e# G( @# s8 Z  i3 e
if (a2 < L) a2 = L+ V! [$ s+ f1 a: J# u5 U7 h
- ~3 h% t/ X* J
else if (a2 > H) a2 = H
4 x1 S3 C2 |* B3 ~, y4 A. j
# _0 G. r' q9 l, L}
. D, G1 r# o5 Q* x
' V% R: r$ A2 ielse, u) |) Y8 ]) z# B/ p. F$ H; o

- |$ @# P; z/ y{; G! f& ^: v" @
  k6 R% S0 o- v4 H, Z5 _) q4 h. l
Lobj = objective function at a2=L
1 [  ~$ n8 `" n6 }: Q% j7 A
. Y) L! w8 T0 U3 m: `  `+ ~Hobj = objective function at a2=H9 S4 ^  k  A5 r9 G
9 _: a0 K9 c% s/ h; Z! c
if (Lobj < Hobj-eps)4 t# [0 z/ ~7 S5 W, d- Q7 p7 Y

, P+ A3 V' M) M/ J( _9 m# g" I4 Ea2 = L
- o7 U; C+ Q# j6 c9 `
9 z* o0 V1 {# a$ z7 Z- V0 Zelse if (Lobj > Hobj+eps)
/ K# S1 P# d; J& |* v" l5 o" O
; A' `( B6 B" d8 @a2 = H
2 x# C0 F1 `6 @" |5 V/ X/ q( e8 R% J" }( m9 s
else
, U7 m* Z4 z) [" G: D* L+ o$ G! K8 r
a2 = alph22 a4 b# H- O* t3 o4 _$ z; |5 S
0 v3 M9 S: j% f+ I0 F% \# D
}' B* M$ V. ]! @, T$ |5 p* Y

0 R% c3 X1 }3 r6 Lif (|a2-alph2| < eps*(a2+alph2+eps))
2 z3 U' v3 W7 y3 N
) N5 z" V2 H5 s- A2 f' A  h! zreturn 0$ ^  v) V' q! B* d( _0 ?4 a

" x; J! f" ]* V2 B: r: D+ I3 ua1 = alph1+s*(alph2-a2)
9 e% @1 u) s7 m1 ~1 l  y% j5 H0 j! ]0 w! C8 y2 J* t
Update threshold to reflect change in Lagrange multipliers
, S6 t2 f) b% @5 H
" E6 d: u% }! }, |$ MUpdate weight vector to reflect change in a1 & a2, if SVM is linear
! d/ D5 o3 c$ W9 }3 s
0 L' F6 p# K1 G" u6 p$ oUpdate error cache using new Lagrange multipliers2 J! D2 h$ G3 L: h1 K4 a2 O, i

& Z; S0 b. M3 x3 ]1 T! e2 k' ~+ PStore a1 in the alpha array
' M0 F$ s! n! M( f$ t
, F$ ]8 X& z! q  E% X! m% IStore a2 in the alpha array, }  U; ^* z" ~; G( A8 T
4 r# Q$ s% \2 a
return 1: V$ u6 w9 w4 Z, ^9 P+ H0 P
endprocedure8 T, r  g3 g& O2 Q: G* J* H+ P
核心代码很紧凑,就是给定两个 ai, aj,然后迭代出新的 ai, aj 出来,还有一层循环会不停的选择最需要被优化的系数 ai, aj,然后调用这个函数。如何更新权重和 b 变量(threshold)文章里面都有说,再多调试一下,可以用 python 先调试,再换成 C/C++,保证得到一个正确可用的 SVM 程序,这是后面的基础。
5 |- L4 P! i; Z  L第二步:实现核函数缓存
9 O" g2 z# U, X观察下上面的伪代码,开销最大的就是计算核函数 K(xi, xj),有些计算又反复用到,一个 100 个样本的数据集求解,假设总共要调用核函数 20 万次,但是 xi, xj 的组和只有 100x100=1 万种,有缓存的话你的效率可以提升 20 倍。
% K: e7 \# W1 D6 b% k样本太大时,如果你想存储所有核函数的组和,需要 N*N * sizeof(double) 的空间,如果训练集有 10 万个样本,那么需要 76 GB 的内存,显然是不可能实现的,所以核函数缓存是一个有限空间的 LRU 缓存,SVM 的 SMO 求解过程中其实会反复用到特定的几个有限的核函数求解,所以命中率不用担心。' N1 K, U8 ?! N
有了这个核函数缓存,你的 SVM 求解程序能瞬间快几十倍。. S8 f6 y4 V5 d5 R1 i; P/ q2 X; Q
第三步:优化误差值求解
8 k" w+ }: w8 @4 l* J- N. f注意看上面的伪代码,里面需要计算一个估计值和真实值的误差 Ei 和 Ej,他们的求解方法是:# P$ r; o0 f1 Y/ @' h5 e
E(i) = f(xi) - yi
9 L0 b5 j( W% `# m9 Y( ?8 i这就是目前为止 SMO 这段为代码里代价最高的函数,因为回顾下上面的公式,计算一遍 f(x) 需要 for 循环做乘法加法。
5 t+ o7 j7 O2 y$ Q! Qplatt 的文章建议是做一个 E 函数的缓存,方便后面选择 i, j 时比较,我看到很多入门版本 SVM 实现都是这么做。其实这是有问题的,后面我们会说到。最好的方式是定义一个 g(x) 令其等于:
+ Q# \6 p7 T8 g2 D. j* ]5 z% v0 x7 ]$ F
也就是 f(x) 公式除了 b 以外前面的最费时的计算,那么我们随时可以计算误差:
% o7 ~  `3 _( E! ~E(j) = g(xj) + b - yj
2 m% w, ?8 K4 n) m所以最好的办法是对 g(x) 进行缓存,platt 的方法里因为所有 alpha 值初始化成了 0,所以 g(x) 一开始就可以全部设置成 0,稍微观察一下 g(x) 的公式,你就会发现,因为去掉了 b 的干扰,而每次 SMO 迭代更新 ai, aj 参数时,这两个值都是线性变化的,所以我们可以给 g(x) 求关于 a 的偏导,假设 ai,aj 变化了步长 delta,那么所有样本对应的 g(x) 加上 delta 乘以针对 ai, aj 的偏导数就行了,具体代码类似:
7 G1 H; ~3 M3 A3 m, k- Rdouble Kik = kernel(i, k);
" ~6 M. z, `8 L( K$ M( mdouble Kjk = kernel(j, k);' h( j+ V9 Z4 p) a
G[k] += delta_alpha_i * Kik * y + delta_alpha_j * Kjk * y[j];" B, w. h( E& Y0 Z4 G
把这段代码放在 takeStep 后面,每次成功更新一对 ai, aj 以后,更新所有样本对应的 g(x) 缓存,这样通过每次迭代更新 g(x) 避免了大量的重复计算。
4 M  J& }% A7 U8 y( v8 d这其实是很直白的一种优化方式,我查了一下,有人专门发论文就讲了个类似的方法。
0 i& e( N/ |7 z第四步:实现冷热数据分离0 ^% J* D$ u! q, h/ y& R% |
Platt 的文章里也证明过一旦某个 alpha 出于边界(0 或者 C)的时候,就很不容易变动,而且伪代码也是优先在工作集里寻找 > 0 and < C 的 alpha 值进行优化,找不到了,再对工作集整体的 alpha 值进行迭代。
4 B( x6 o0 W: ~( u3 [" C8 I那么我们势必就可以把工作集分成两个部分,热数据在前(大于 0 小于 C 的 alpha 值),冷数据在后(小于等于 0 或者大于等于 C 的 alpha)。* {2 R3 X9 e9 ~8 z
随着迭代加深,会发现大部分时候只需要在热数据里求解,并且热数据的大小会逐步不停的收缩,所以区分了冷热以后 SVM 大部分都在针对有限的热数据迭代,偶尔不行了,再全部迭代一次,然后又回到冷热迭代,性能又能提高不少。2 _% g8 R, I, |8 f( C" O
第五步:支持 Ensemble
" N, a! _, M: m. N0 f" t6 Q: R大家都知道,通过 Ensemble 可以让多个不同的弱模型组和成一个强模型,而传统 SVM 实现并不能适应一些类似 AdaBoost 的集成方法,所以我们需要做一些改动。可以让外面针对某一个分类传入一个“权重”过来,修正 SVM 的识别结果。
/ k1 P  g2 D! k8 l最传统的修改方式就是将不等式约束 C 分为 Cp 和 Cn 两个针对 +1 分类的 C 及针对 -1 分类的 C。修改方式是直接用原始的 C 乘以各自分类的权重,得到 Cp 和 Cn,然后迭代时,不同的样本根据它的 y 值符号,用不同的 C 值带入计算。
& v7 B1 h8 {6 S' O这样 SVM 就能用各种集成方法同其他模型一起组成更为强大精准的模型了。# G& l2 o: o  r6 u
实现到这一步你就得到了功能上和性能上同 libsvm 类似的东西,接下来我们继续优化。5 F- Q7 k0 \4 O3 J/ u
第六步:继续优化核函数计算
0 b1 i$ ]) S! T1 z; Q" o. ]+ m& }核函数缓存非常消耗内存,libsvm 数学上已经没得挑了,但是工程方面还有很大改进余地,比如它的核缓存实现。
$ q* A+ f7 Q5 }& M, \( @4 C由于标准 SVM 核函数用的是两个高维矢量的内积,根据内积的几个条件,SVM 的核函数又是一个正定核,即 K(xi, xj) = K(xj, xi),那么我们同样的内存还能再多存一倍的核函数,性能又能有所提升。
1 j9 p) f" Y. K5 h) a% \1 N针对核函数的计算和存储有很多优化方式,比如有人对 NxN 的核函数矩阵进行采样,只计算有限的几个核函数,然后通过插值的方式求解出中间的值。还有人用 float 存储核函数值,又降低了一倍空间需求。: l  q6 Z$ j7 M8 }! e
第七步:支持稀疏向量和非稀疏向量
0 Y# Q% w2 ?% ]对于高维样本,比如文字这些,可能有上千维,每个样本的非零特征可能就那么几个,所以稀疏向量会比较高效,libsvm 也是用的稀疏向量。
/ ~4 o1 v. q2 ~" k8 ?0 ]但是还有很多时候样本是密集向量,比如一共 200 个特征,大部分样本都有 100个以上的非零特征,用稀疏向量存储的话就非常低效了,openCV 的 SVM 实现就是非稀疏向量。
! w7 [6 ~* Q, i" b非稀疏向量直接是用数组保存样本每个特征的值,在工程方面就有很多优化方式了,比如用的最多的求核函数的时候,直接上 SIMD 指令或者 CUDA,就能获得更好的计算性能。
+ _4 Z: W' _9 n$ _所以最好的方式是同时支持稀疏和非稀疏,兼顾时间和空间效率,对不同的数据选择最适合的方式。: R' `$ {/ q1 w2 ]
第八步:针对线性核进行优化6 U( A' E2 A, M7 R: U+ r
传统的 SMO 方法,是 SVM 的通用求解方法,然而针对线性核,就是:
! t9 x& M9 Z; r$ a4 \K(xi, xj) = xi . xj! U. v. O5 T( h0 J$ u
还有很多更高效的求解思路,比如 Pegasos 算法就用了一种类似随机梯度下降的方法,快速求 SVM 的解权重 w,如果你的样本适合线性核,使用一些针对性的非 SMO 算法可以极大的优化 SVM 求解,并且能处理更加庞大的数据集,LIBLINEAR 就是做这件事情的。5 r4 f+ \- |& R
同时这类算法也适合 online 训练和并行训练,可以逐步更新增量训练新的样本,还可以用到多核和分布式计算来训练模型,这是 SMO 算法做不到的地方。6 l5 V  J; K" ~8 z
但是如果碰到非线性核,权重 w 处于高维核空间里(有可能无限维),你没法梯度下降迭代 w,并且 pegasos 的 pdf 里面也没有提到如何用到非线性核上,LIBLINEAR 也没有办法处理非线性核。
5 U' g/ H( s1 X或许哪天出个数学家又找到一种更好的方法,可以用类似 pegasos 的方式求解非线性核,那么 SVM 就能有比较大的进展了。
, w6 M$ P) _8 \9 i! l后话+ h+ L# d. }) W$ S: V* ^; r
上面八条,你如果实现前三条,基本就能深入理解 SVM 的原理了,如果实现一大半,就可以得到一个类似 libsvm 的东西,全部实现,你就能得到一个比 libsvm 更好用的 SVM 库了。
  z' R! o- L! ^  @/ @上面就是如何实现一个相对成熟的 SVM 模型的思路,以及配套优化方法,再往后还有兴趣,可以接着实现支持向量回归,也是一个很有用的东西。
% A; {6 E: R" d. P3 ?# a. j0 J1 d5 X. J' P1 C8 G( W$ S. h) `
来源:http://www.yidianzixun.com/article/0Lv0UIiC
/ c" c; @" i6 y, D; v% N. r+ {4 I免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×

帖子地址: 

梦想之都-俊月星空 优酷自频道欢迎您 http://i.youku.com/zhaojun917
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

站长推荐上一条 /6 下一条

QQ|手机版|小黑屋|梦想之都-俊月星空 ( 粤ICP备18056059号 )|网站地图

GMT+8, 2025-7-16 03:21 , Processed in 0.040232 second(s), 23 queries .

Powered by Mxzdjyxk! X3.5

© 2001-2025 Discuz! Team.

快速回复 返回顶部 返回列表