A Detailed Walkthrough of the scikit-learn Logistic Regression Example

Date: 2022-03-01 00:33:49

Iris flower classification

1 The code file

http://scikit-learn.org/stable/auto_examples/linear_model/plot_iris_logistic.html#sphx-glr-auto-examples-linear-model-plot-iris-logistic-py

2 Loading the raw data

import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, datasets

# import some data to play with
iris = datasets.load_iris()

The resulting iris data can be inspected with:

print(iris.data)

It is a two-dimensional array:

[[ 5.1  3.5  1.4  0.2]
 [ 4.9  3.   1.4  0.2]
 [ 4.7  3.2  1.3  0.2]
 [ 4.6  3.1  1.5  0.2]
 [ 5.   3.6  1.4  0.2]
 ...,
 [ 6.7  3.   5.2  2.3]
 [ 6.3  2.5  5.   1.9]
 [ 6.5  3.   5.2  2. ]
 [ 6.2  3.4  5.4  2.3]
 [ 5.9  3.   5.1  1.8]]
(150 rows x 4 columns; the middle rows are omitted here)
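
To get a quick feel for the data, you can also inspect its shape and the bundled metadata (a small addition, not in the original post):

print(iris.data.shape)     # (150, 4): 150 samples, 4 features
print(iris.feature_names)  # sepal length/width and petal length/width, in cm
print(iris.target_names)   # ['setosa' 'versicolor' 'virginica']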

3 Extracting the X feature array

The goal is to take the first two columns of the data (sepal length and sepal width) using NumPy's two-dimensional slicing syntax; the resulting two-dimensional array is stored in X.

3.1 Code

X = iris.data[:, :2]  # we only take the first two features.

3.2 Explanation

  • The comma is a separator: the part on its left slices the first axis (dimension), and the part on its right slices the second axis.
  • Each per-axis slice follows the same syntax as slicing a one-dimensional sequence; see the Python array access reference.
  • Slicing the first axis of a 2-D array selects rows; the bare : used here means "take all rows".
  • Slicing the second axis selects columns; :2 here means "take columns 0 and 1". (A small sketch follows this list.)
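
To make the slicing rules concrete, here is a minimal sketch on a small made-up array (not from the original post):

import numpy as np

a = np.arange(12).reshape(3, 4)  # [[0 1 2 3] [4 5 6 7] [8 9 10 11]]
print(a[:, :2])    # all rows, columns 0 and 1 -> [[0 1] [4 5] [8 9]]
print(a[0, :])     # row 0, all columns -> [0 1 2 3]
print(a[:2, 1:3])  # rows 0-1, columns 1-2 -> [[1 2] [5 6]]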

3.3 Result

This is again a two-dimensional array; in each row the first element is the sepal length and the second is the sepal width:

[[ 5.1  3.5]
 [ 4.9  3. ]
 [ 4.7  3.2]
 [ 4.6  3.1]
 [ 5.   3.6]
 ...,
 [ 6.7  3. ]
 [ 6.3  2.5]
 [ 6.5  3. ]
 [ 6.2  3.4]
 [ 5.9  3. ]]
(150 rows x 2 columns; the middle rows are omitted here)

4 Extracting the Y dependent-variable array

iris.target holds the species of every iris sample, encoded as 0, 1, and 2:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
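
The fit call in the next section needs these labels in a variable; the linked example stores them as:

Y = iris.target  # 150 labels: 0 = setosa, 1 = versicolor, 2 = virginica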

5 Fitting the X features

Set the parameters and fit the data:

logreg = linear_model.LogisticRegression(C=1e5)

# create an instance of the logistic regression classifier and fit the data
logreg.fit(X, Y)

Here the parameter C is set to 10^5 (1e5). What is C for? Below is an answer quoted from *:

Regularization is applying a penalty to increasing the magnitude of parameter values in order to reduce overfitting. When you train a model such as a logistic regression model, you are choosing parameters that give you the best fit to the data. This means minimizing the error between what the model predicts for your dependent variable given your data compared to what your dependent variable actually is.

The problem comes when you have a lot of parameters (a lot of independent variables) but not too much data. In this case, the model will often tailor the parameter values to idiosyncrasies in your data -- which means it fits your data almost perfectly. However, because those idiosyncrasies don't appear in future data you see, your model predicts poorly.

To solve this, as well as minimizing the error as already discussed, you add to what is minimized and also minimize a function that penalizes large values of the parameters. Most often the function is λ∑θⱼ², which is some constant λ times the sum of the squared parameter values θⱼ². The larger λ is, the less likely it is that the parameters will be increased in magnitude simply to adjust for small perturbations in the data.

In your case, however, rather than specifying λ, you specify C = 1/λ.
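
To see the effect of C in practice, here is a small standalone sketch (my addition, not from the original post) that fits the same two features with three different values of C; since C = 1/λ, a smaller C means a stronger penalty, and the coefficient magnitudes should typically shrink:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris2 = load_iris()
X2, Y2 = iris2.data[:, :2], iris2.target

for C in (0.01, 1.0, 1e5):
    model = LogisticRegression(C=C, max_iter=1000).fit(X2, Y2)
    # larger C -> weaker regularization -> typically larger coefficients
    print(C, abs(model.coef_).max())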

6 Predicting the flower species

6.1 Organizing the data with numpy

In the Cartesian plane, x is the sepal length and y is the sepal width, so each point has coordinates (x, y). First, take the minimum and maximum of the first column of X (length) and generate an array with step size h; then do the same for the second column (width); finally, use the meshgrid function to generate the two grid matrices xx and yy.
For the meshgrid concept, see the meshgrid reference.

h = .02  # step size in the mesh

# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max] x [y_min, y_max].
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

The values of xx and yy look like this:

xx:
[[ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
 [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
 [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
 ...,
 [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
 [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
 [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]]

yy:
[[ 1.5   1.5   1.5  ...,  1.5   1.5   1.5 ]
 [ 1.52  1.52  1.52 ...,  1.52  1.52  1.52]
 [ 1.54  1.54  1.54 ...,  1.54  1.54  1.54]
 ...,
 [ 4.86  4.86  4.86 ...,  4.86  4.86  4.86]
 [ 4.88  4.88  4.88 ...,  4.88  4.88  4.88]
 [ 4.9   4.9   4.9  ...,  4.9   4.9   4.9 ]]

The code below uses the ravel() function, so a quick note first:
xx.ravel() and yy.ravel() flatten the two matrices (2-D arrays) into 1-D arrays (these are actually views; no data is copied). Since the two matrices have the same size, the two 1-D arrays have the same size as well.

xx.ravel(): [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
yy.ravel(): [ 1.5   1.5   1.5  ...,  4.9   4.9   4.9 ]

np.c_[xx.ravel(), yy.ravel()] then combines them column-wise into the matrix:

[[ 3.8   1.5 ]
 [ 3.82  1.5 ]
 [ 3.84  1.5 ]
 ...,
 [ 8.36  4.9 ]
 [ 8.38  4.9 ]
 [ 8.4   4.9 ]]
  • To summarize what we have done so far:
    1. Divide the sepal-length range into steps of h to get a row vector, then repeat that row to build the grid matrix xx.
    2. Divide the sepal-width range into steps of h to get a column vector, then repeat that column to build the grid matrix yy.
    3. Flatten xx and yy into two 1-D arrays and combine them with np.c_[] into a single 2-D array of grid points (see the sketch below).
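
A tiny sketch of the same pipeline on a made-up 2x3 grid (illustrative values, not from the original post):

import numpy as np

xs = np.array([0., 1., 2.])
ys = np.array([10., 20.])
gx, gy = np.meshgrid(xs, ys)  # gx repeats xs as rows, gy repeats ys as columns
print(gx)                     # [[0. 1. 2.] [0. 1. 2.]]
print(gy)                     # [[10. 10. 10.] [20. 20. 20.]]
print(np.c_[gx.ravel(), gy.ravel()])
# [[ 0. 10.] [ 1. 10.] [ 2. 10.] [ 0. 20.] [ 1. 20.] [ 2. 20.]]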

6.2 Predicting with logistic regression

The predict function is documented as follows:

Predict class labels for samples in X.

Parameters:
    X : {array-like, sparse matrix}, shape = [n_samples, n_features]
        Samples.

Returns:
    C : array, shape = [n_samples]
        Predicted class label per sample.

You can see that the argument X here has shape (39501, 2): 39501 samples (rows) and 2 features (the two columns, length and width).
The return value (the array named C in the docstring) is a 1-D array of predicted labels, which we store in Z.

Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot
Z = Z.reshape(xx.shape)

The concrete values:

Z (output of predict): [1 1 1 ..., 2 2 2]
size: 39501
Z (after reshape), shape (171, 231):
[[1 1 1 ..., 2 2 2]
 [1 1 1 ..., 2 2 2]
 [0 1 1 ..., 2 2 2]
 ...,
 [0 0 0 ..., 2 2 2]
 [0 0 0 ..., 2 2 2]
 [0 0 0 ..., 2 2 2]]
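
As a hypothetical spot check (my addition, not in the original post), you can also call predict on a couple of hand-picked (length, width) points; this reuses the fitted logreg from section 5:

import numpy as np

test_points = np.array([[5.0, 3.6],   # a point in what looks like the class-0 region
                        [7.0, 3.2]])  # a point in what looks like the class-2 region
print(logreg.predict(test_points))    # e.g. [0 2]; the exact labels depend on the fit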

7 Displaying the prediction results

The pcolormesh function draws the two grid matrices xx and yy together with the corresponding predictions Z; the result is three colored blocks showing the regions assigned to each class.
This function is quite convenient for drawing the three regions that represent the three classes.

plt.figure(1, figsize=(4, 3))
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired)

plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.show()
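
One caveat (my addition, not from the original post): on newer matplotlib versions, calling pcolormesh with xx, yy, and Z all of the same shape can warn or error under the default flat shading; passing shading='auto' avoids this:

plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired, shading='auto')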

[Figure: the three colored decision regions produced by pcolormesh]

8 Checking the result

Now let's also draw the training points on the figure, to check visually how well the learned classifier matches the training data. Only one line of code needs to be added, placed right after the pcolormesh call:

plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors='k', cmap=plt.cm.Paired)
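
In context, the assembled plotting block would then look roughly like this (a sketch following the linked example):

plt.figure(1, figsize=(4, 3))
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired)

# Plot also the training points, colored by their true class Y
plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors='k', cmap=plt.cm.Paired)
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.show()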

The result looks like this:

[Figure: the decision regions with the training points drawn on top]