机器学习实战ByMatlab(五)Logistic Regression

时间:2022-11-01 16:42:48

机器学习实战ByMatlab(五)Logistic Regression

http://blog.csdn.net/llp1992/article/details/45114421
什么叫做回归呢?举个例子,我们现在有一些数据点,然后我们打算用一条直线来对这些点进行拟合(该曲线称为最佳拟合曲线),这个拟合过程就被称为回归。

利用Logistic回归进行分类的主要思想是:

根据现有数据对分类边界线建立回归公式,以此进行分类。

这里的”回归“一词源于最佳拟合,表示要找到最佳拟合参数集。训练分类器时的嘴阀就是寻找最佳拟合曲线,使用的是最优化算法。

基于Logistic回归和Sigmoid函数的分类

优点:计算代价不高,易于理解和实现 
缺点:容易欠拟合,分类精度可能不高

使用数据类型:数值型和标称型数据

Sigmoid函数:


机器学习实战ByMatlab(五)Logistic Regression

波形如下:


机器学习实战ByMatlab(五)Logistic Regression

当z为0时,值为0.5,当z增大时,g(z)逼近1,当z减小时,g(z)逼近0

Logistic回归分类器:

对每一个特征都乘以一个回归系数,然后把所有结果都相加,再讲这个总和代入Sigmoid函数中,从而得到一个范围在0-1之间的数值。任何大于0.5的数据被分为1,小于0.5的数据被分为0.因此Logistic回归也被看成是一种概率分布。

分类器的函数形式确定之后,现在的问题就是,如何确定回归系数?

基于最优化方法的最佳回归系数确定

Sigmoid函数的输入记为z,由下面公式得出:


机器学习实战ByMatlab(五)Logistic Regression

如果采用向量的写法,则上述公式可以写成: 


机器学习实战ByMatlab(五)Logistic Regression

其中向量X就是分类器的输入数据,向量W也就是我们要找到的最佳参数,从而使分类器尽可能更加地精确。接下来将介绍几种需找最佳参数的方法。

梯度上升法

梯度上升法的基本思想:

要找到某函数的最大值,最好的方法是沿着该函数的梯度方向寻找

这里提一下梯度下降法,这个我们应该会更加熟悉,因为我们在很多代价函数J的优化的时候经常用到它,其基本思想是:

要找到某函数的最小值,最好的方法是沿着该函数的梯度方向的反方向寻找

函数的梯度表示方法如下:


机器学习实战ByMatlab(五)Logistic Regression

机器学习实战ByMatlab(五)Logistic Regression

移动方向确定了,移动的大小我们称之为步长,用α表示,用向量来表示的话,梯度下降算法的迭代公式如下:


机器学习实战ByMatlab(五)Logistic Regression

该公式已知被迭代执行,直到某个停止条件位置,比如迭代次数达到某个指定值或者算法的误差小到某个允许的误差范围内。

注:梯度下降算法中的迭代公式如下:


机器学习实战ByMatlab(五)Logistic Regression

Matlab 实现

<code class="language-matlab hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">
<span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">function</span> <span class="hljs-title" style="box-sizing: border-box;">weight</span> = <span class="hljs-title" style="box-sizing: border-box;">gradAscent</span></span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%%</span>
clc
close all
clear
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%%</span>

data = load(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'testSet.txt'</span>);
<span class="hljs-matrix" style="box-sizing: border-box;">[row , col]</span> = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">size</span>(data);
dataMat = data(:,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:col-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>);
dataMat = <span class="hljs-matrix" style="box-sizing: border-box;">[ones(row,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) dataMat]</span> ;
labelMat = data(:,col);
alpha = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0.001</span>;
maxCycle = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">500</span>;
weight = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">ones</span>(col,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span> = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:maxCycle
h = sigmoid((dataMat * weight)<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">');
error = (labelMat - h'</span>);
weight = weight + alpha * <span class="hljs-transposed_variable" style="box-sizing: border-box;">dataMat'</span> * error;
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span>

figure
scatter(dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">find</span>(labelMat(:) == <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>),dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">find</span>(labelMat(:) == <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>);
hold on
scatter(dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">find</span>(labelMat(:) == <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>),dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">find</span>(labelMat(:) == <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>);
hold on
x = -<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>:<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0.1</span>:<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>;
y = (-weight(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>)-weight(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>)*x)/weight(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>);
plot(x,y)
hold off

<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span>

<span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">function</span> <span class="hljs-title" style="box-sizing: border-box;">returnVals</span> = <span class="hljs-title" style="box-sizing: border-box;">sigmoid</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(inX)</span></span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% 注意这里的sigmoid函数要用点除</span>
returnVals = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.0</span>./(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.0</span>+<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">exp</span>(-inX));
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li><li style="box-sizing: border-box; padding: 0px 5px;">30</li><li style="box-sizing: border-box; padding: 0px 5px;">31</li><li style="box-sizing: border-box; padding: 0px 5px;">32</li><li style="box-sizing: border-box; padding: 0px 5px;">33</li><li style="box-sizing: border-box; padding: 0px 5px;">34</li><li style="box-sizing: border-box; padding: 0px 5px;">35</li><li style="box-sizing: border-box; padding: 0px 5px;">36</li><li style="box-sizing: border-box; padding: 0px 5px;">37</li><li style="box-sizing: border-box; padding: 0px 5px;">38</li></ul>

效图如下:


机器学习实战ByMatlab(五)Logistic Regression

由上图可以看到,回归效果还是挺不错的,只有2-4个点分类错误。

其实这是的梯度上升算法是批量梯度上升算法,每一次更新参数的时候都要讲所有的数据集都代入训练,效果并不好,下面我们将介绍改进版本:随机梯度上升算法

随机梯度上升

梯度上升算法在每次更新回归系数时都要遍历整个数据集,该方法在处理100个左右的数据集时尚可,但如果有数十亿样本和成千上万的特征,那么该方法的复杂度就太高了。一种改进方法是一次仅用一个样本点来更新回归系数,该方法就称为随机梯度上升法。由于可以在新样本到来之前对分类器进行增量式更新,因此随机梯度算法是一个在线学习算法。与”在线学习“相对应,一次处理所有数据被称作是”批处理“

随机梯度上升算法可以写成如下的伪代码:

<code class="hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">
所有回归系数初始化为1
对数据集中的每个样本
计算该样本的梯度
使用alpha x gradient 更新回归系数值
返回回归系数值
</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul>

Matlab 代码实现

<code class="language-matlab hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">
<span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">function</span> <span class="hljs-title" style="box-sizing: border-box;">stocGradAscent</span></span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%%</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% Description : LogisticRegression using stocGradAsscent</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% Author : Liulongpo</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% Time:2015-4-18 10:57:25</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%%</span>
clc
clear
close all
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%%</span>
data = load(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'testSet.txt'</span>);
<span class="hljs-matrix" style="box-sizing: border-box;">[row , col]</span> = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">size</span>(data);
dataMat = <span class="hljs-matrix" style="box-sizing: border-box;">[ones(row,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) data(:,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:col-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>)]</span>;
alpha = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0.01</span>;
labelMat = data(:,col);
weight = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">ones</span>(col,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span> = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:row
h = sigmoid(dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>,:)*weight);
error = labelMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>) - h;
dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>,:)
weight
weight = weight + alpha * error * dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>,:)<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'
end

figure
scatter(dataMat(find(labelMat(:)==0),2),dataMat(find(labelMat(:)==0),3),5);
hold on
scatter(dataMat(find(labelMat(:) == 1),2),dataMat(find(labelMat(:) == 1),3),5);
hold on
x = -3:0.1:3;
y = -(weight(1)+weight(2)*x)/weight(3);
plot(x,y)
hold off


end

function returnVals = sigmoid(inX)
% 注意这里的sigmoid函数要用点除
returnVals = 1.0./(1.0+exp(-inX));
end
</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li><li style="box-sizing: border-box; padding: 0px 5px;">30</li><li style="box-sizing: border-box; padding: 0px 5px;">31</li><li style="box-sizing: border-box; padding: 0px 5px;">32</li><li style="box-sizing: border-box; padding: 0px 5px;">33</li><li style="box-sizing: border-box; padding: 0px 5px;">34</li><li style="box-sizing: border-box; padding: 0px 5px;">35</li><li style="box-sizing: border-box; padding: 0px 5px;">36</li><li style="box-sizing: border-box; padding: 0px 5px;">37</li><li style="box-sizing: border-box; padding: 0px 5px;">38</li><li style="box-sizing: border-box; padding: 0px 5px;">39</li><li style="box-sizing: border-box; padding: 0px 5px;">40</li><li style="box-sizing: border-box; padding: 0px 5px;">41</li><li style="box-sizing: border-box; padding: 0px 5px;">42</li><li style="box-sizing: border-box; padding: 0px 5px;">43</li><li style="box-sizing: border-box; padding: 0px 5px;">44</li><li style="box-sizing: border-box; padding: 0px 5px;">45</li></ul>

效果如下:


机器学习实战ByMatlab(五)Logistic Regression

由上图可以看出,随机梯度上升算法分类效果并没有上面的的梯度上升算法分类效果好。

但是直接比较梯度上升算法和随机梯度上升算法是不公平的,前者是在整个数据集上迭代500次得到的结果,后者只是迭代了100次。一个判断算法优劣的可靠方法是看它是否收敛,也就是说求解的参数是否达到了稳定值,是否还会不断变化。

我们让随机梯度上升算法在整个数据集上运行200次,迭代过程中3个参数的变化如下图:


机器学习实战ByMatlab(五)Logistic Regression

由上图可以看到,weight1 最先达到稳定,而weight0和weight2则还需要更多的迭代次数来达到稳定。

此时的分类器跟之前的梯度上升算法的分类效果差不多,如下:


机器学习实战ByMatlab(五)Logistic Regression

但同时我们也可以看到,三个参数都有不同程度的波动。产生这种现象的原因是存在一些不能被正确分类的样本点(数据集并非线性可分),在每次迭代的时候都会引起参数的剧烈变化。我们期望算法能避免来回波动,从而收敛到某个值。另外,算法收敛速度也要加快。

改进的随机梯度上升算法

改进的随机梯度上升算法的主要两个改进点如下:

1,每一步调整alpha的值,也就是alpha的值是不严格下降的 
2.随机采取样本来更新回归参数

matlab代码如下:

<code class="language-matlab hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">
<span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">function</span> <span class="hljs-title" style="box-sizing: border-box;">ImproveStocGradAscent</span></span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%%</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% Description : LogisticRegression using stocGradAsscent</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% Author : Liulongpo</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% Time:2015-4-18 10:57:25</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%</span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%%</span>
clc
clear
close all
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%%</span>
data = load(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'testSet.txt'</span>);
<span class="hljs-matrix" style="box-sizing: border-box;">[row , col]</span> = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">size</span>(data);
dataMat = <span class="hljs-matrix" style="box-sizing: border-box;">[ones(row,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) data(:,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:col-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>)]</span>;
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%alpha = 0.01;</span>
numIter = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">20</span>;
labelMat = data(:,col);
weightVal = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">zeros</span>(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>,numIter*row);
weight = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">ones</span>(col,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>);
<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">j</span> = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>;

<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> k = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:numIter
randIndex = randperm(row);
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span> = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:row
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% 改进点 1</span>
alpha = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>/(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.0</span>+<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>+k)+<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0.01</span>;
<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">j</span> = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">j</span>+<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>;
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% 改进点 2 </span>
h = sigmoid(dataMat(randIndex(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>),:)*weight);
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% 改进点 2</span>
error = labelMat(randIndex(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>)) - h;
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% 改进点 2</span>
weight = weight + alpha * error * dataMat(randIndex(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>),:)<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">';
weightVal(1,j) = weight(1);
weightVal(2,j) = weight(2);
weightVal(3,j) = weight(3);
end
end

figure
i = 1:numIter*row;
subplot(3,1,1)
plot(i,weightVal(1,:)),title('</span><span class="hljs-transposed_variable" style="box-sizing: border-box;">weight0'</span>)<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%,axis([0 numIter*row 0.8 7])</span>
<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">j</span> = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:numIter*row;
subplot(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>)
plot(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">j</span>,weightVal(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>,:)),title(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'weight1'</span>)<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%,axis([0 numIter*row 0.3 1.2])</span>
k = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>:numIter*row;
subplot(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>)
plot(k,weightVal(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>,:)),title(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'weight2'</span>)<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">%,axis([0 numIter*row -1.2 -0.1])</span>

figure
scatter(dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">find</span>(labelMat(:)==<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>),dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">find</span>(labelMat(:)==<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>);
hold on
scatter(dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">find</span>(labelMat(:) == <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>),dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">find</span>(labelMat(:) == <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>),<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>);
hold on
x = -<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>:<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0.1</span>:<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>;
y = -(weight(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>)+weight(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>)*x)/weight(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>);
plot(x,y,<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'r'</span>)
hold off


<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span>

<span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">function</span> <span class="hljs-title" style="box-sizing: border-box;">returnVals</span> = <span class="hljs-title" style="box-sizing: border-box;">sigmoid</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(inX)</span></span>
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">% 注意这里的sigmoid函数要用点除</span>
returnVals = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.0</span>./(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1.0</span>+<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">exp</span>(-inX));
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li><li style="box-sizing: border-box; padding: 0px 5px;">30</li><li style="box-sizing: border-box; padding: 0px 5px;">31</li><li style="box-sizing: border-box; padding: 0px 5px;">32</li><li style="box-sizing: border-box; padding: 0px 5px;">33</li><li style="box-sizing: border-box; padding: 0px 5px;">34</li><li style="box-sizing: border-box; padding: 0px 5px;">35</li><li style="box-sizing: border-box; padding: 0px 5px;">36</li><li style="box-sizing: border-box; padding: 0px 5px;">37</li><li style="box-sizing: border-box; padding: 0px 5px;">38</li><li style="box-sizing: border-box; padding: 0px 5px;">39</li><li style="box-sizing: border-box; padding: 0px 5px;">40</li><li style="box-sizing: border-box; padding: 0px 5px;">41</li><li style="box-sizing: border-box; padding: 0px 5px;">42</li><li style="box-sizing: border-box; padding: 0px 5px;">43</li><li style="box-sizing: border-box; padding: 0px 5px;">44</li><li style="box-sizing: border-box; padding: 0px 5px;">45</li><li style="box-sizing: border-box; padding: 0px 5px;">46</li><li style="box-sizing: border-box; padding: 0px 5px;">47</li><li style="box-sizing: border-box; padding: 0px 5px;">48</li><li style="box-sizing: border-box; padding: 0px 5px;">49</li><li style="box-sizing: border-box; padding: 0px 5px;">50</li><li style="box-sizing: border-box; padding: 0px 5px;">51</li><li style="box-sizing: border-box; padding: 0px 5px;">52</li><li style="box-sizing: border-box; padding: 0px 5px;">53</li><li style="box-sizing: border-box; padding: 0px 5px;">54</li><li style="box-sizing: border-box; padding: 0px 5px;">55</li><li style="box-sizing: border-box; padding: 0px 5px;">56</li><li style="box-sizing: border-box; padding: 0px 5px;">57</li><li style="box-sizing: border-box; padding: 0px 5px;">58</li><li style="box-sizing: border-box; padding: 0px 5px;">59</li><li style="box-sizing: border-box; padding: 0px 5px;">60</li><li style="box-sizing: border-box; padding: 0px 5px;">61</li><li style="box-sizing: border-box; padding: 0px 5px;">62</li><li style="box-sizing: border-box; padding: 0px 5px;">63</li><li style="box-sizing: border-box; padding: 0px 5px;">64</li><li style="box-sizing: border-box; padding: 0px 5px;">65</li><li style="box-sizing: border-box; padding: 0px 5px;">66</li><li style="box-sizing: border-box; padding: 0px 5px;">67</li><li style="box-sizing: border-box; padding: 0px 5px;">68</li><li style="box-sizing: border-box; padding: 0px 5px;">69</li></ul>

改进点 1 中的alpha会随着迭代次数的增加不断减小,但由于代码中常数0.01的存在,alpha不会减少到0。这样做是为了保证在多次迭代之后新数据对于参数的更新还有一定的影响。

另一点值得注意的就是,alpha每次减少 1/(k+i) ,k 是迭代次数,i是样本的下标。所以 alpha 不是严格下降的。避免参数的严格下降也常见于模拟退火算法等其他优化算法中。

第二个改进的地方如代码注释中标记的,这里通过随机采取样本来更新回归参数,这样能够减少参数的周期性的波动。

由于alpha的动态变化,我们可以在开始的时候设置比较大的值,代码中设置0.01,alpha也就是每一次迭代的步长,步长越大,越能够加快参数的收敛速度。然后ahpha会不严格下降,这样就避免了过拟合现象的发生。至于什么是过拟合已经alpha的选取问题将在下面描述。

迭代20次后效果如下:


机器学习实战ByMatlab(五)Logistic Regression


机器学习实战ByMatlab(五)Logistic Regression

由上图可知,步长alpha动态变化之后,参数的收敛速度加快了很多,这里只是对所有样本数据集迭代20次,weight0 和 weight2很早就收敛。证明了该算法的优异性。

学习率alpha的选取

首先我们看一下梯度上升算法的核心代码,如下:

<code class="language-matlab hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">h = sigmoid(dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>,:)  *  weight);
error = labelMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>) - h;
weight = weight + alpha * error * dataMat(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">i</span>,:)<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">';</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li></ul>

第一行做的就是估计分类,第二行计算当前估计与正确分类之间的差error,第三行根据这个error来更新参数weight。

我们在迭代的时候,要做的目标就是最小化 error ,我们令 J 代表 error,令向量 θ 代表weight,则很显然,J是θ的函数。这里盗用Standfor 机器学习教程的图,如下:


机器学习实战ByMatlab(五)Logistic Regression

上图中的每个箭头就是每一次迭代的更新步长,第一幅图我们看到,在最小化 J(θ) 的时候迭代了很多次,这说明什么?说明我们要走很多步才能到达全局最优点,原因就是我们每一步走的距离太短,走得太慢,也就是我们的alpha设置得太小。但是当我们处于最优点附近的时候,这样有利我们向最优点靠近。

下图中的每个箭头也代表走一步,我们可以看到,迭代的时候,每一步都没有到达最优点,而是在最优点的附近波动。为什么呢?因为步长太大了嘛,明明就在眼前了,半步或者四分之三步就走到了,你却只能一跨而过,重新再来。但是学习率大的话,在刚开始迭代的时候有利于我们参数的快速收敛,也有利于我们避开局部最小值。

综合以上两种情况,我们就应该在开始的时候选取较大的学习率,然后不断不严格减小学习率,这样才是最优的选择。

那么,我们开始的学习率应该怎么选取?Andrew Ng 在课程中建议先试试0.01,太大就0.003,太小就0.03….