MNIST:试图获得高精度
MNIST: Trying to obtain high accuracy
我目前正在研究手写数字识别问题。
首先,我针对 MNIST 数据集测试了示例手写数字。
我的准确率为 53%,我需要 90% 以上的准确率。
以下是我迄今为止为提高准确性所做的尝试。
创建了我自己的数据集
我已经创建了 41,000 个示例。首先,我制作了一个小型数据集,其中包含 10000 个示例(每个数字 1000 个)。
数据集是按照mnist格式的行创建的(后期可能想把我的数据集和mnist数据集拼凑起来)。
以此为基础建立的模型准确率接近65%。
接近
- 我一直在使用 softmax 测试我的输入数据,但它在准确性方面没有多大帮助。所以我开始尝试 cnn 方式。
所以我的问题是:
有没有其他的approach/algorithm,可以更准确的检测数字
我需要更多训练模型吗?
我需要清理图像吗?
我正在研究将 mnist 数据集和我的数据集(41,000 位数据)组合在一起,看看它是否会提高准确性。
代码
针对 mnist(运行 此代码之前的 mnist)测试我的图像
您可以找到以下 Ipyhton 笔记本:
针对 MNIST(脚本 - 1)测试我的示例数字
针对我的数据集测试我的示例数字(脚本 - 2)
脚本和图像可在此 link
先做几点说明:
- cnn 和 softmax 并不互斥。你可以让你的 cnn 用于较低级别。您应该使用 softmax 进行预测(往往效果最好,因为只有一个答案)。
- 关于你的训练很难说。请考虑在训练数据集和评估集上发布损失图。通常我们应该看到两条线向下并趋于平缓。如果他们没有变平,你应该训练更长时间。如果两条线开始分叉,您需要更多的正则化或提前停止。
- 您应该始终尝试各种参数(层数、神经元激活函数的数量等)这称为 hyper-parameter 调整,有一些工具可以帮助您,它们通常会改进很多。
- 除了cnn,你还应该试试深度神经网络。我已经看到罐装 DNNClassifier 的良好效果。
假设您执行了所有这些操作但没有看到任何改进,这可能意味着您的数据存在问题。
查看混淆矩阵,看看模型哪里出了问题。看一些错误分类的例子。根据我的经验,我发现数据集中的 1 和 7 几乎无法区分。这不完全是一个解决方案,但应该为您指明正确的方向以解决您需要修复的问题。
你可以尝试把图片单色化(因为Mnist中每个像素值都在0到255之间),用"is pixel i > 0"做个测试。这将我们的算法提高到 80%。
另外,您可以尝试将图片分成几帧(尝试 4 或 8)。
此外,您还可以基于直线、曲线等构建测试。
你可以看看我的实现,它让我达到了 91%:
https://github.com/orlevy08/Data-Analysis
我目前正在研究手写数字识别问题。
首先,我针对 MNIST 数据集测试了示例手写数字。
我的准确率为 53%,我需要 90% 以上的准确率。
以下是我迄今为止为提高准确性所做的尝试。
创建了我自己的数据集
我已经创建了 41,000 个示例。首先,我制作了一个小型数据集,其中包含 10000 个示例(每个数字 1000 个)。
数据集是按照mnist格式的行创建的(后期可能想把我的数据集和mnist数据集拼凑起来)。 以此为基础建立的模型准确率接近65%。
接近
- 我一直在使用 softmax 测试我的输入数据,但它在准确性方面没有多大帮助。所以我开始尝试 cnn 方式。
所以我的问题是:
有没有其他的approach/algorithm,可以更准确的检测数字
我需要更多训练模型吗?
我需要清理图像吗?
我正在研究将 mnist 数据集和我的数据集(41,000 位数据)组合在一起,看看它是否会提高准确性。
代码
针对 mnist(运行 此代码之前的 mnist)测试我的图像
您可以找到以下 Ipyhton 笔记本:
针对 MNIST(脚本 - 1)测试我的示例数字
针对我的数据集测试我的示例数字(脚本 - 2)
脚本和图像可在此 link
先做几点说明:
- cnn 和 softmax 并不互斥。你可以让你的 cnn 用于较低级别。您应该使用 softmax 进行预测(往往效果最好,因为只有一个答案)。
- 关于你的训练很难说。请考虑在训练数据集和评估集上发布损失图。通常我们应该看到两条线向下并趋于平缓。如果他们没有变平,你应该训练更长时间。如果两条线开始分叉,您需要更多的正则化或提前停止。
- 您应该始终尝试各种参数(层数、神经元激活函数的数量等)这称为 hyper-parameter 调整,有一些工具可以帮助您,它们通常会改进很多。
- 除了cnn,你还应该试试深度神经网络。我已经看到罐装 DNNClassifier 的良好效果。
假设您执行了所有这些操作但没有看到任何改进,这可能意味着您的数据存在问题。 查看混淆矩阵,看看模型哪里出了问题。看一些错误分类的例子。根据我的经验,我发现数据集中的 1 和 7 几乎无法区分。这不完全是一个解决方案,但应该为您指明正确的方向以解决您需要修复的问题。
你可以尝试把图片单色化(因为Mnist中每个像素值都在0到255之间),用"is pixel i > 0"做个测试。这将我们的算法提高到 80%。 另外,您可以尝试将图片分成几帧(尝试 4 或 8)。 此外,您还可以基于直线、曲线等构建测试。
你可以看看我的实现,它让我达到了 91%: https://github.com/orlevy08/Data-Analysis