京东6.18大促主会场领京享红包更优惠

 找回密码
 立即注册

QQ登录

只需一步,快速开始

查看: 4849|回复: 0

如何学习SVM(支持向量机)以及改进实现SVM算法程序

[复制链接]

10

主题

0

回帖

10

积分

新手上路

积分
10
发表于 2019-5-8 03:16:31 | 显示全部楼层 |阅读模式 来自 中国
雷锋网 AI 科技评论按,本文为韦易笑在知乎问题如何学习SVM(支持向量机)以及改进实现SVM算法程序下面的回复,雷锋网 AI 科技评论获其授权转载。以下为正文:  n* p+ V- f, C1 N
学习 SVM 的最好方法是实现一个 SVM,可讲理论的很多,讲实现的太少了。
8 e- J9 V5 s; v假设你已经读懂了 SVM 的原理,并了解公式怎么推导出来的,比如到这里:: Q* A1 G8 Q9 j; x8 q

; P) Y5 A  g2 G: K; D( _! ZSVM 的问题就变成:求解一系列满足约束的 alpha 值,使得上面那个函数可以取到最小值。然后记录下这些非零的 alpha 值和对应样本中的 x 值和 y 值,就完成学习了,然后预测的时候用:& \; m: Z8 W) a
/ x* Y6 R. i, q; ^5 C" y; }3 R
上面的公式计算出 f(x) ,如果返回值 > 0 那么是 +1 类别,否则是 -1 类别,先把这一步怎么来的,为什么这么来找篇文章读懂,不然你会做的一头雾水。
* e+ m* e/ v& k那么剩下的 SVM 实现问题就是如何求解这个函数的极值。方法有很多,我们先找个起点,比如 Platt 的 SMO 算法,它后面有伪代码描述怎么快速求解 SVM 的各个系数。
0 v; w0 }; E3 i' D( S) S. j第一步:实现传统的 SMO 算法/ L/ f. Z. D0 M# V
现在大部分的 SVM 开源实现,源头都是 platt 的 smo 算法,读完他的文章和推导,然后照着伪代码写就行了,核心代码没几行:( a0 y' u  q7 ^
procedure takeStep(i1,i2)
: a# ]5 J4 \" M: J* L: }6 _( H0 u if (i1 == i2) return 08 \8 g9 ?% d5 U1 d% s. m
alph1 = Lagrange multiplier for i1
: J5 j6 Z9 _! F8 H& _  |5 p) K  R7 T& S' W0 D: B0 R
y1 = target[i1]/ E7 o" o' m3 f- ]
2 ^6 v2 T$ t. E9 \+ M2 T
E1 = SVM output on point[i1] – y1 (check in error cache)
5 r) k9 q7 i* V( ]; @2 T' i$ D: I& c5 F/ }/ F# h6 Y
s = y1*y2
( M4 T# Z4 }4 [9 O' c+ ]3 A- E- x* Z& Y7 F8 Z8 G; }, {
Compute L, H via equations (13) and (14)
. b3 v7 k# B; B4 ^8 S
; h' e0 G: w8 f6 ~2 {( eif (L == H)
7 ?7 n% P- H" p# {. M" b. @* e7 d) b9 Q0 R& v8 C/ C7 m
return 0
% m  ~1 s! e8 `0 F/ ]& s' s
, X7 Z1 Y) B& C- |9 Tk11 = kernel(point[i1],point[i1])' [3 I. A" y9 r

$ }4 E5 \- l8 s: r3 e2 _& Wk12 = kernel(point[i1],point[i2])! v4 w- e7 e, b* I4 w6 u. W
( Q. m, F2 r, _% V* ?0 ^8 x
k22 = kernel(point[i2],point[i2])
  |* D: h8 Q; F3 K! W
9 W# G6 z* E7 T& I5 Veta = k11+k22-2*k12- e: `! D% D4 ?( r) _& d- }8 c# q0 P9 y

+ _- K: @  [( H: g8 e. n! Aif (eta > 0)1 [% H) ?, `$ ?$ k9 y) L5 a

' U% M( M" v; |1 m{) b* V. x, _% g9 d8 G" T

' q& E, \7 L, n1 oa2 = alph2 + y2*(E1-E2)/eta- t+ G3 X( o9 N! l0 t

+ V. R% L% O; x( G; T& U& h5 y6 Fif (a2 < L) a2 = L4 W* T0 J/ O1 P1 i

3 K7 G7 N8 K: @else if (a2 > H) a2 = H
# k0 s8 k# _! a; f1 c  O# v
, p1 e! G  e1 U: z+ D}
8 X/ Z, G' m" j1 U6 N: T1 D" r
1 R; n% a/ g" w: a7 B! aelse8 _/ L/ S. l0 d' g. f0 t8 F
0 m4 n$ n# b8 R1 O  S
{
  j, u3 U# ~/ s
; `6 a7 Y$ b# [% Y# m4 o$ `; `Lobj = objective function at a2=L1 }/ a2 s& P& C( H/ J3 U# ~0 ~# X  o  b

' |5 M- V6 r! ~4 d" e' mHobj = objective function at a2=H
' S! t5 w5 n5 q! p0 J' I9 s( P5 r& @" l4 F" C2 w6 d4 i' N
if (Lobj < Hobj-eps)+ i/ {! l! h9 B0 X* r4 D& J
/ u8 O& ^. W" i! I1 ]$ i; m, z3 C3 |1 R
a2 = L( A4 w, z5 v& k8 j- C# Q: k

# [7 v" I8 _" E5 B3 a6 ~3 lelse if (Lobj > Hobj+eps)
! B% X9 K' D8 i+ \/ U# d, u# Q$ T0 g8 f- ~3 {5 G6 C
a2 = H
" |/ j/ C: O2 l2 R: p  `' p, T( P; c/ }* L- V9 \( M3 n1 S# b
else
% \5 r3 S# X" _' G" l: Z: D7 H4 q2 V# M* [
a2 = alph2
  a% q  P; `& B0 @9 b
, t1 A( J0 z) }2 O* l% m" g! h}
  E1 G" j) u* C# Q$ h1 ?% A
/ v0 i$ B: W5 Z$ z1 g1 \; Z6 S. z( qif (|a2-alph2| < eps*(a2+alph2+eps))7 W& Q6 g# m1 O) C) d

6 d  H& L! }% b) B: E3 `  Lreturn 0
7 E2 h9 `8 Q! A4 c* j) N$ h
& s4 |; b+ s$ @$ ~6 a# ]a1 = alph1+s*(alph2-a2)& ~5 ~7 L! B) o

7 l0 [6 Q3 ]8 V( J. r0 S& DUpdate threshold to reflect change in Lagrange multipliers
% Y% ]  D0 g" @  c4 V% k3 ~( _4 S  |0 @, }' l/ l8 H
Update weight vector to reflect change in a1 & a2, if SVM is linear
! \6 I; {4 ~6 S
' w7 k% Y% X. F' B: ?Update error cache using new Lagrange multipliers
5 X5 g$ s7 j) `! ?! G6 Z
$ \! \( W# c7 e  b9 }. NStore a1 in the alpha array0 ~! X, Z6 U: c2 e; e2 B
/ F3 ]6 z" V/ T. s6 r% u7 d2 B
Store a2 in the alpha array
9 W- U$ k1 ]3 [5 P( Z" R4 }9 K* i9 c
return 1
4 R% i$ t" e' ^$ Q3 Qendprocedure& ?- D( m0 R1 E; `7 ]$ T
核心代码很紧凑,就是给定两个 ai, aj,然后迭代出新的 ai, aj 出来,还有一层循环会不停的选择最需要被优化的系数 ai, aj,然后调用这个函数。如何更新权重和 b 变量(threshold)文章里面都有说,再多调试一下,可以用 python 先调试,再换成 C/C++,保证得到一个正确可用的 SVM 程序,这是后面的基础。" j/ x( o1 K; v: R8 {" q
第二步:实现核函数缓存9 }- X3 L" k+ v1 i+ ]( n$ u: V
观察下上面的伪代码,开销最大的就是计算核函数 K(xi, xj),有些计算又反复用到,一个 100 个样本的数据集求解,假设总共要调用核函数 20 万次,但是 xi, xj 的组和只有 100x100=1 万种,有缓存的话你的效率可以提升 20 倍。
# x7 m7 a8 o% C0 @& Q) b样本太大时,如果你想存储所有核函数的组和,需要 N*N * sizeof(double) 的空间,如果训练集有 10 万个样本,那么需要 76 GB 的内存,显然是不可能实现的,所以核函数缓存是一个有限空间的 LRU 缓存,SVM 的 SMO 求解过程中其实会反复用到特定的几个有限的核函数求解,所以命中率不用担心。+ G: l# _  q* S
有了这个核函数缓存,你的 SVM 求解程序能瞬间快几十倍。! F, ?. B8 K/ c
第三步:优化误差值求解
0 z# V9 q1 m6 d- g& F注意看上面的伪代码,里面需要计算一个估计值和真实值的误差 Ei 和 Ej,他们的求解方法是:
  p9 X, P3 o+ S& y/ g+ Q/ J' {E(i) = f(xi) - yi
1 P! t. ]# i1 }- e: h) W4 v这就是目前为止 SMO 这段为代码里代价最高的函数,因为回顾下上面的公式,计算一遍 f(x) 需要 for 循环做乘法加法。
! B1 m8 G! c3 C) Wplatt 的文章建议是做一个 E 函数的缓存,方便后面选择 i, j 时比较,我看到很多入门版本 SVM 实现都是这么做。其实这是有问题的,后面我们会说到。最好的方式是定义一个 g(x) 令其等于:$ U$ U% T" y! B, b6 I3 k+ G. ]
9 ]  h  V5 |* M1 t& |2 @
也就是 f(x) 公式除了 b 以外前面的最费时的计算,那么我们随时可以计算误差:! Y; C5 X0 s# V- X; J
E(j) = g(xj) + b - yj$ |2 q  X5 e' I$ s9 y
所以最好的办法是对 g(x) 进行缓存,platt 的方法里因为所有 alpha 值初始化成了 0,所以 g(x) 一开始就可以全部设置成 0,稍微观察一下 g(x) 的公式,你就会发现,因为去掉了 b 的干扰,而每次 SMO 迭代更新 ai, aj 参数时,这两个值都是线性变化的,所以我们可以给 g(x) 求关于 a 的偏导,假设 ai,aj 变化了步长 delta,那么所有样本对应的 g(x) 加上 delta 乘以针对 ai, aj 的偏导数就行了,具体代码类似:0 c$ u4 w: L9 U7 J
double Kik = kernel(i, k);& I# N% x7 T0 b! X$ p' N# O. e
double Kjk = kernel(j, k);$ _' e3 Q/ u' E! \$ T
G[k] += delta_alpha_i * Kik * y + delta_alpha_j * Kjk * y[j];
; ?$ V1 {+ r/ _2 W: T; W4 _把这段代码放在 takeStep 后面,每次成功更新一对 ai, aj 以后,更新所有样本对应的 g(x) 缓存,这样通过每次迭代更新 g(x) 避免了大量的重复计算。' O7 D# ^" k; O6 f
这其实是很直白的一种优化方式,我查了一下,有人专门发论文就讲了个类似的方法。1 q' A; K) T. U, p
第四步:实现冷热数据分离
; |8 P. C, ]9 ]Platt 的文章里也证明过一旦某个 alpha 出于边界(0 或者 C)的时候,就很不容易变动,而且伪代码也是优先在工作集里寻找 > 0 and < C 的 alpha 值进行优化,找不到了,再对工作集整体的 alpha 值进行迭代。4 J. N* B9 m2 S; ?( t
那么我们势必就可以把工作集分成两个部分,热数据在前(大于 0 小于 C 的 alpha 值),冷数据在后(小于等于 0 或者大于等于 C 的 alpha)。
% s" q, R- @: D1 a  p+ }( }随着迭代加深,会发现大部分时候只需要在热数据里求解,并且热数据的大小会逐步不停的收缩,所以区分了冷热以后 SVM 大部分都在针对有限的热数据迭代,偶尔不行了,再全部迭代一次,然后又回到冷热迭代,性能又能提高不少。
7 y( o' t+ c( J/ k7 ?/ U& e. F第五步:支持 Ensemble) Y- U7 g6 y* \5 [, p3 ^: D- M4 v! w. m
大家都知道,通过 Ensemble 可以让多个不同的弱模型组和成一个强模型,而传统 SVM 实现并不能适应一些类似 AdaBoost 的集成方法,所以我们需要做一些改动。可以让外面针对某一个分类传入一个“权重”过来,修正 SVM 的识别结果。
% e. |3 Y& b/ H最传统的修改方式就是将不等式约束 C 分为 Cp 和 Cn 两个针对 +1 分类的 C 及针对 -1 分类的 C。修改方式是直接用原始的 C 乘以各自分类的权重,得到 Cp 和 Cn,然后迭代时,不同的样本根据它的 y 值符号,用不同的 C 值带入计算。
8 f& S  ~* o: u这样 SVM 就能用各种集成方法同其他模型一起组成更为强大精准的模型了。3 @4 l# e5 F; N* k  E0 m
实现到这一步你就得到了功能上和性能上同 libsvm 类似的东西,接下来我们继续优化。: \5 {1 g* O4 ^* K
第六步:继续优化核函数计算
, C8 Q+ s* `2 h, K; N7 P! }- n9 b0 G核函数缓存非常消耗内存,libsvm 数学上已经没得挑了,但是工程方面还有很大改进余地,比如它的核缓存实现。
+ A* O: C) X1 }' ~6 c由于标准 SVM 核函数用的是两个高维矢量的内积,根据内积的几个条件,SVM 的核函数又是一个正定核,即 K(xi, xj) = K(xj, xi),那么我们同样的内存还能再多存一倍的核函数,性能又能有所提升。# c: c) \! Y3 \4 b$ ^' [# h
针对核函数的计算和存储有很多优化方式,比如有人对 NxN 的核函数矩阵进行采样,只计算有限的几个核函数,然后通过插值的方式求解出中间的值。还有人用 float 存储核函数值,又降低了一倍空间需求。
$ d- ?3 q9 u, O% d* N0 f- M第七步:支持稀疏向量和非稀疏向量# c$ q" g& x4 t: A8 M5 W
对于高维样本,比如文字这些,可能有上千维,每个样本的非零特征可能就那么几个,所以稀疏向量会比较高效,libsvm 也是用的稀疏向量。
# U. v5 Y, q( Z& U( J! G但是还有很多时候样本是密集向量,比如一共 200 个特征,大部分样本都有 100个以上的非零特征,用稀疏向量存储的话就非常低效了,openCV 的 SVM 实现就是非稀疏向量。" F0 f& R+ I3 h6 ]. d
非稀疏向量直接是用数组保存样本每个特征的值,在工程方面就有很多优化方式了,比如用的最多的求核函数的时候,直接上 SIMD 指令或者 CUDA,就能获得更好的计算性能。) J- z  U+ @9 }* _# \/ ^
所以最好的方式是同时支持稀疏和非稀疏,兼顾时间和空间效率,对不同的数据选择最适合的方式。
0 m" c; R! F( ~8 Y第八步:针对线性核进行优化+ {, w5 _' z; Y# P
传统的 SMO 方法,是 SVM 的通用求解方法,然而针对线性核,就是:  `3 e9 s6 Y$ X' ?9 a( T
K(xi, xj) = xi . xj
- y5 a" _! I9 U还有很多更高效的求解思路,比如 Pegasos 算法就用了一种类似随机梯度下降的方法,快速求 SVM 的解权重 w,如果你的样本适合线性核,使用一些针对性的非 SMO 算法可以极大的优化 SVM 求解,并且能处理更加庞大的数据集,LIBLINEAR 就是做这件事情的。: b. ^/ o3 y$ F: `/ K
同时这类算法也适合 online 训练和并行训练,可以逐步更新增量训练新的样本,还可以用到多核和分布式计算来训练模型,这是 SMO 算法做不到的地方。7 |- W5 J" s4 k) ~$ L- J
但是如果碰到非线性核,权重 w 处于高维核空间里(有可能无限维),你没法梯度下降迭代 w,并且 pegasos 的 pdf 里面也没有提到如何用到非线性核上,LIBLINEAR 也没有办法处理非线性核。: J- H2 h0 p, f7 @
或许哪天出个数学家又找到一种更好的方法,可以用类似 pegasos 的方式求解非线性核,那么 SVM 就能有比较大的进展了。( _7 B. A; x) M  `" `: l9 I# f) N
后话
  g; c1 Q" s+ a1 H# Z上面八条,你如果实现前三条,基本就能深入理解 SVM 的原理了,如果实现一大半,就可以得到一个类似 libsvm 的东西,全部实现,你就能得到一个比 libsvm 更好用的 SVM 库了。
6 A+ o/ y% C5 T& w* G8 c0 j上面就是如何实现一个相对成熟的 SVM 模型的思路,以及配套优化方法,再往后还有兴趣,可以接着实现支持向量回归,也是一个很有用的东西。
, y6 M8 [; X0 W1 d9 i3 `
+ a  ^& ]! n7 n( c- i来源:http://www.yidianzixun.com/article/0Lv0UIiC) m. {7 n2 i9 O2 i+ W# I
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×

帖子地址: 

梦想之都-俊月星空 优酷自频道欢迎您 http://i.youku.com/zhaojun917
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|手机版|小黑屋|梦想之都-俊月星空 ( 粤ICP备18056059号 )|网站地图

GMT+8, 2025-10-29 19:34 , Processed in 0.068729 second(s), 23 queries .

Powered by Mxzdjyxk! X3.5

© 2001-2025 Discuz! Team.

快速回复 返回顶部 返回列表