深度学习 -- 基于 keras 的手写数字识别示例
MNIST 是 keras 中一个入门级的计算机视觉数据集,这个数据集包含各种各样的手写数字图片,其中包含 60000 张训练图像和 10000 张测试图像。
是解决手写数字分类问题的经典数据集,可以看做是深度学习的 ‘Hello World’。
下载数据
1 | from keras.datasets import mnist |
这两行代码需要先运行,下载数据集。
下载过程会失败,多试几次,不需要翻墙。
train_images
和 train_labels
为训练集, 分别为图片数据和标签数据。模型将从这些数据中进行学习。
然后在测试集: test_images
和 test_labels
上对模型进行测试。
构建网络
1 | from keras import models |
神经网络的核心组件是层(layer),它是一种数据处理模块,可以将它看成数据过滤器。
大多数深度学习都是将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)。深度学习模型就像是数据处理的筛子,包含一系列越来越精细的数据过滤器(即层)。
编译
1 | network.compile(optimizer='rmsprop', |
loss: 损失函数,网络如何衡量在训练数据上的性能,即网络如何朝着正确的方向前进。
optimizer: 优化器,基于训练数据和损失函数来更新网络的机制
metrucs: 监控指标,本例只关心精度,即正确分类的图像所占的比例。
准备图像数据
1 | train_images = train_images.reshape((60000, 28 * 28)) |
开始之前,我们需要转换数据格式,让网络能够处理。
训练图像保存在一个uint8类型的数组中,其形状为 (60000, 28, 28),取值区间为 [0, 255]。我们需要将其转换成一个 float32数组,形状为 (60000, 28 * 28),取值范围为 0 ~ 1。
准备标签
1 | from keras.utils import to_categorical |
开始训练
1 | network.fit(train_images, train_labels, epochs=5, batch_size=128) |
测试数据
1 | test_loss, test_acc = network.evaluate(test_images, test_labels) |
测试集精度为97.98%,比训练集精度低不少。
训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。
这些东西接下来再继续学习。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 M-finder!
评论