一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器

时间:2024-03-31 12:56:03

一.介绍

这是一个利用tensorflow.js库的网页端的网络。
在没有使用TensorFlow.js库之前,想要实现这个功能,需要花很长时间来完成算法编写,包括数据图像的采集、模型的训练、参数的调整,最终结果可能得经过分类模型(如:VGG、ResNet、ShuffleNet等)的卷积层、全连接层,最终以概率的方式呈现,预期效果可以达到,但是花费的时间代价也很大,
利用这个tensorflow.js库可以在网页端完成这些操作,训练包括预测的时间取决于自己使用的计算机。
一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器

二.实现过程

1.数据集

首先这是训练所使用的数据集,其中包含2893张剪刀石头布的图片,所有这些数据都是以白色背景构成的,每幅图像为300×300像素:
http://www.laurencemoroney.com/rock-paper-scissors-dataset/
一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器

2.网页数据集加载

解决浏览器加载图片的限制,使用精灵表单(spritesheet)将一组图像粘合成一个图像,此时,图像中每个像素都变成1像素高清图像,将它们堆叠创建一个保存所有图像的10MB大小的大图像,所有内容都合并为一个图像输入图片:
一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器
经过转换后的采集结果如下,图像收缩为6464大小每个,共有2520个图象,即成像为40962520像素:
一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器

3.选择模型进行训练

接下来可以选择一个搭建一个简单的训练模型还是一个复杂的训练模型,如果选择高级模型,首先,它需要花更长的时间训练样本甚至结果也没有预想的那么好用。此外,如果训练时间过长,高级模型会出现过拟合数据的问题。
一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器
选择一个简单模型:
一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器

4.开始训练

一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器在训练模型时,我们会获得每批次更新的图表,包括512个图像,以及每个时期更新的另一张图表,包括所有的2100个训练图像,一个健康的训练迭代应具有损失减少,准确性提高等特征。精度图中的橙色线表示验证数据的准确度,即用训练模型去预测剩余的420个未训练图像时的准确度。发现代表验证数据的橙色线与训练数据精度几乎重合,这说明建立的模型是可行的。

一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器

5.训练结果

一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器
我们可以看到石头和剪刀的识别精度很高,而布的识别精度只有0.95,进一步挖掘原因,查看其混淆矩阵,发现布被错误的识别成剪刀5次,识别成石头3次,可以选择增加这方面的不同尺度和姿态的样本来加强训练,
一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器

6.模型测试

使用网络摄像头检查自己做出的代表石头剪刀布的手势图像。需要注意的是我们的手势图像应与训练图像类似,没有旋转角度且背景为白色,便于模型进行识别,点击识别按钮可以看到摄像头图像被裁剪为64×64送入网络进行预测,识别为石头,概率为100%。一个简单的机器学习案例:10分钟,训练一个“剪刀石头布”识别器

最后附上链接网站,大家可以去体验:网站