Python教程网
--左手Python,右手AI!

深度学习入门笔记5——基于 tensorflow 的手写数字的识别(基础)

基于 tensorflow 的手写数字的识别(简单版本)

本系列将分为 8 篇 。本次为第 5 篇 ,结合上一篇的应用实例 ,将前边学到一些基础知识用到手写数字的识别分类上 。

1.关于 MNIST 数据集

首先 ,我们得了解 MNIST 数据集 。这是一个手写数字数据集 ,在深度学习入门学习中极具代表性 。可以从官网下载该数据集 ,但事实上 TensorFlow 中提供了一个类来处理 MNIST 数据 ,这个类会自动下载并转化格式 ,将数据从原始的数据包中解析成训练和测试神经网络时使用的格式 ,具体相关函数在接下来代码中介绍 。

MNIST 数据集被分为训练数据集(60000张手写数字图片)和测试数据集(10000张手写数字图片)。

每一张图片包含 28*28 个像素 ,图片里的某个像素的强度值介于0-1之间。例如 ,数字 1 对应一个 28*28 像素图片 ,其像素强度如下 :

我们把这一个数组展开成一个向量 ,长度是 28*28=784 。因此在MNIST训练数据集中 mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。

2.one-hot 向量和 Softmax 函数 

MNIST 数据集标签为 0-9 十个数字 ,我们用 one-hot 向量来表示 。以MNIST 数据集为例 ,one-hot 向量指的是除了某一位数字为 1 ,其他维度都为 0 ,比如数字 1 对应 [0,1,0,0,0,0,0,0,0,0] 。

那么我们就可以得到数据集中对应的标签(labels)是若干个 one-hot 向量组成的矩阵 。以训练集为例 ,是一个 [60000,10] 的数字矩阵 。

另一个重要的知识就是 Softmax 函数 。如果是二分类问题 ,我们可以考虑用 sigmoid 或 tanh 等进行分类 ,即分为是或否 。这里是多分类问题 ,softmax 就很合适了 。这里小詹不知道怎么描述容易让大家理解 ,借鉴一个博客链接给出一段较为生动的描述 。

我们知道 max ,假如说我有两个数 ,a 和 b ,并且 a > b ,如果取 max ,那么就直接取 a ,没有第二种可能 。但有的时候我希望分值大的那一项(a) 经常取到 ,分值小的那一项 (b) 也偶尔可以取到 ,那么我用 softmax 就可以了 。(尊重原创 ,附上这段话链接:https://blog.csdn.net/supercally/article/details/54234115)

3.MNIST 数据集识别实战

以上已经对基本的知识进行了介绍 ,这里进行实战讲解 。我们首先要设计一个网络结构 ,然后根据第四讲中的 “三步走” 步骤进行实现 。这里简单版本先设计一个简单到不能更简单的网络实现手写数字的识别分类 。

训练过程 ,每一张图片输入的可以看成一个长度为 784 的向量 ,输出为 0-9 中的一个 ,即有 10 种可能 ,或者说这就是一个 10 分类问题 。所以我们采取输入层 784 个神经元 ,全连接到输出层 10 个神经元 。( 哪个帅哥写的字 ?这么丑 !哈哈)

首先 ,需要读取 MNIST 数据集 ,利用 TF 框架自带类进行下载读取 。

接下来就是根据之前的 “三步走” 进行实践 。实现上述的最简单的网络结构 ,并依旧选择二次代价函数和梯度下降法 。

再在会话 Session 中执行 。代码如下 :

这里小詹讲一下下面这两行代码如何求出了 accuracy 。

因为这里无论是数据集中的 labels ,还是预测值 prediction 都是以 one-hot 向量形式存在 。tf.argmax 返回一维张量中最大值所在位置 ,若某一张图片数据的 label 和对该图片的预测 最大值在同一个位置(例如数字 3 ,预测结果和 label 对应的 one-hot 向量都为[0,0,0,1,0,0,0,0,0,0]),此时 tf.equal 则返回值为 1 ,反之为 0 。即预测正确为 1 ,错误为 0 。

之后利用 tf.reduce_mean() 函数将所有的correct_prediction 求平均值 ,比如测试 10 张图片 ,上述有 9 张正确(9个1),1 张错误(1个0)。则平均值为 0.9,就是预测精度了 。

那么利用以上网络和代码得到的结果是怎样的呢 ?下面给出结果 。

往期推荐:

1. 深度学习入门笔记系列 ( 四 )

2. 深度学习入门笔记系列 ( 三 )

以上代码获取方式 :公号后台回复关键词【MNIST】即可获取 。

赞(9) 打赏
未经允许不得转载:Python教程网 » 深度学习入门笔记5——基于 tensorflow 的手写数字的识别(基础)
分享到: 更多 (0)

评论 抢沙发

一块钱也是爱,支持最重要~

支付宝扫一扫打赏

微信扫一扫打赏