Demo侠可能是我等小白进阶的必经之路了,如今在AI领域,我也是个研究Demo的小白。用了两三天装好环境,跑通Demo,自学Python语法,进而研究这个Demo。当然过程中查了很多资料,充分发挥了小白的主观能动性,总算有一些收获需要总结下。
不多说,算法在代码中,一切也都在代码中。
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '' #获得数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) import tensorflow as tf #输入图像数据占位符
x = tf.placeholder(tf.float32, [None, 784]) #权值和偏差
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10])) #使用softmax模型
y = tf.nn.softmax(tf.matmul(x, W) + b) #代价函数占位符
y_ = tf.placeholder(tf.float32, [None, 10]) #交叉熵评估代价
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) #使用梯度下降算法优化:学习速率为0.5
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) #Session(交互方式)
sess = tf.InteractiveSession() #初始化变量
tf.global_variables_initializer().run() #训练模型,训练1000次
for _ in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) #计算正确率
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
看完这个Demo,顿时感觉Python真是一门好语言,Tensorflow是一个好框架,就跟之前掌握Matlab以后,用Matlab做仿真的感觉一样。
为什么看这几行代码看了两三天,因为看懂很容易,但了解代码背后的意义更重要,如果把一个Demo看透了,那么后边举一反三就会很容易了,我向来就是这样学习的,本小白当年也是个学霸?!
来一起看下这里边有什么玄机和坑吧,记录一下,人老了记性不好(^-^)。
看到1,2行代码,不要懵,这个作用是设置日志级别,os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error,等于1是显示所有信息。不加这两行会有个提示(Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2,具体可以看这里)
第5行是一个引用声明,从tensorflow.examples.tutorials.mnist 引用一个名为 input_data 的函数,可以看一下input_data是什么样子的:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import gzip
import os
import tempfile import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
原来input_data里边也是引用声明,真正想用到的实际是tensorflow.contrib.learn.python.learn.datasets.mnist里的read_data_sets,看一下代码:
def read_data_sets(train_dir,
fake_data=False,
one_hot=False,
dtype=dtypes.float32,
reshape=True,
validation_size=5000,
seed=None,
source_url=DEFAULT_SOURCE_URL):
if fake_data:
... if not source_url: # empty string check
... local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
source_url + TRAIN_IMAGES)
with gfile.Open(local_file, 'rb') as f:
train_images = extract_images(f) ... if not 0 <= validation_size <= len(train_images):
raise ValueError('Validation size should be between 0 and {}. Received: {}.'
.format(len(train_images), validation_size)) validation_images = train_images[:validation_size]
validation_labels = train_labels[:validation_size]
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:] options = dict(dtype=dtype, reshape=reshape, seed=seed) train = DataSet(train_images, train_labels, **options)
validation = DataSet(validation_images, validation_labels, **options)
test = DataSet(test_images, test_labels, **options) return base.Datasets(train=train, validation=validation, test=test)
mnist最终得到的是base.Datasets,完成了数据读取。这里边的细节还需要完了再仔细研究下。
顺便记录下自编的函数的定义方法:
def Mycollect(My , thing): try:
count = My[thing]
except KeyError:
count = 0 return count from TestFunction import Mycollect
My = {'a':10, 'b':15, 'c':5}
thing = 'a'
print(Mycollect(My , thing));
第11行的placeholder,需要注意下,是用了占位符,也就是先安排位置,而不先提供具体数据,也就是说都是模型(管道)的构建过程(这里用管道来类比,我觉得比较恰当)。注意下placeholder的语法就可以,指定了type和shape,这里的None表示有多少幅图片是未知的,也就是说样本数是未知的。这里的坑在于,如果我们用print看的话会发现,构建的是张量(Tensor)而不是矩阵,这里对熟悉matlab的同学来说可能是个坑。可以注意下张量的定义方式。
第14和15行是定义了变量,如果只看tf.zeros([10])的话也是个张量的,只是外边又加了变量的声明。所以后边可以直接乘的,这个也不难理解了。
第18行的matmul是张量相乘,然后使用了softmax模型,目的是把结果进行概率化。巧妙,只想说这两个字,这个就是进行归一化,搞算法这个是比较常用的,学校时候这个词很火,我们最终想得到的是一个指定的数组,所以用这个模型来匹配我的规则。
21行是什么,看完就知道是实际的输出,然后在24行做交叉熵。终于又碰到熵这个老朋友了。交叉熵简单理解为概率分布的距离,在这里作为一个loss_function。第27行使用了梯度下降来优化这个loss_function,最终是想找到最优时候的一个模型,这里的最优指的是通过这个模型,得到的结果和实际值最接近。
第30行,创建一个session。
第33行,初始化变量。
第37行,可以去看下next_batch的源码,作用是选取100个样本来训练。
第41行,注意equal函数的作用,第43行来做类型转换,然后取平均值。(代码很巧妙,很优雅,很爽)
最终第44行输出模型的准确率。
好了,这大概就是我的一点点总结了,算是入了个门,接下来我会更多的举一反三,深入掌握其精髓,我会努力走得更远。
作为一个小白,我要继续努力向大牛学习,吃饭去咯,下周再战。