我怎样才能改进数字预测?

how can i improve number predictions?

我有一些数字分类模型,在测试数据上它工作正常,但是当我想对其他图像进行分类时,我遇到了我的模型无法准确预测它是什么数字的问题。请帮助我提高 model.predict() 性能。

我试过用很多方法训练我的模型,在下面的代码中有一个创建分类模型的函数,我实际上用很多方法训练了这个模型,[1K < n < 60K] 输入测试数据, [3 < e < 50] 次训练迭代。

def load_data():
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    train_images = tf.keras.utils.normalize(train_images, axis = 1)
    test_images = tf.keras.utils.normalize(test_images, axis = 1)

    return (train_images, train_labels), (test_images, test_labels)

def create_model():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation = tf.nn.relu))
    model.add(tf.keras.layers.Dense(128, activation = tf.nn.relu))
    model.add(tf.keras.layers.Dense(10, activation = tf.nn.softmax))

    data = load_data(n=60000, k=5)
    model.compile(optimizer ='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])
    model.fit(data[0][0][:n], data[0][1][:n], epochs = e)# ive tried from 3-50 epochs
    model.save(config.model_name)

def load_model():
    return tf.keras.models.load_model(config.model_name)def predict(images):
    try:
        model = load_model()
    except:
        create_model()
        model = load_model()
    images = tf.keras.utils.normalize(images, axis = 0)
    d = load_data()

    plot_many_images([d[0][0][0].reshape((28,28)), images[0]],['data', 'image'])

    predictions = model.predict(images)
    return predictions

我认为我的输入数据看起来不像是预测模型的数据,但我已尽力使其尽可能相似。这张图(https://imgur.com/FfLGMEK)左边是火车数据图片,右边是我解析后的图片,都是28x28像素,都是cv2.noramalized

对于测试图像预测,我使用了这个(https://imgur.com/RMfKtag) sudoku, it's already formatted to be similar with a test data numbers, but when I test this image with the model prediction the result is not so nice(https://imgur.com/RQFvLNE) 如您所见,预测数据还有很多不足之处。

P.S。预测数据结果中的 (' ') 项是我亲手做的(我已经用 ' ' 替换了那个位置的数字),因为预测后它们都有一些值(1-9),现在不需要了。

你是什么意思"on test data it works OK"?如果您的意思是它对训练数据有效但对测试数据没有很好的预测,那么您的模型可能 over-fit 处于训练阶段。我建议使用 train/validation/test 方法来训练你的网络。