如何从张量流数据集中获取标签
How to get the labels from tensorflow dataset
ds_test = tf.data.experimental.make_csv_dataset(
file_pattern = "./dfj_test/part-*.csv.gz",
batch_size=batch_size, num_epochs=1,
#column_names=use_cols,
label_name='label_id',
#select_columns= select_cols,
num_parallel_reads=30, compression_type='GZIP',
shuffle_buffer_size=12800)
这是我训练时的测试集。完成模型后,我想压缩 df_test
的预测和标签列。
preds = model.predict(df_test)
获取预测很简单,而且是numpy数组格式。但是,我不知道如何从 df_test 中获取相应的标签。
我想 zip(preds, labels) 进行进一步分析。
有什么提示吗?谢谢
(tf 版本 2.3.1)
您可以将每个示例映射到return您想要的字段
# load some exemplary data
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
dataset = tf.data.experimental.make_csv_dataset(train_file_path, batch_size=100, num_epochs=1)
# get field by unbatching
labels_iterator= dataset.unbatch().map(lambda x: x['survived']).as_numpy_iterator()
labels = np.array(list(labels_iterator))
# get field by concatenating batches
labels_iterator= dataset.map(lambda x: x['survived']).as_numpy_iterator()
labels = np.concatenate(list(labels_iterator))
ds_test = tf.data.experimental.make_csv_dataset(
file_pattern = "./dfj_test/part-*.csv.gz",
batch_size=batch_size, num_epochs=1,
#column_names=use_cols,
label_name='label_id',
#select_columns= select_cols,
num_parallel_reads=30, compression_type='GZIP',
shuffle_buffer_size=12800)
这是我训练时的测试集。完成模型后,我想压缩 df_test
的预测和标签列。
preds = model.predict(df_test)
获取预测很简单,而且是numpy数组格式。但是,我不知道如何从 df_test 中获取相应的标签。 我想 zip(preds, labels) 进行进一步分析。 有什么提示吗?谢谢
(tf 版本 2.3.1)
您可以将每个示例映射到return您想要的字段
# load some exemplary data
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
dataset = tf.data.experimental.make_csv_dataset(train_file_path, batch_size=100, num_epochs=1)
# get field by unbatching
labels_iterator= dataset.unbatch().map(lambda x: x['survived']).as_numpy_iterator()
labels = np.array(list(labels_iterator))
# get field by concatenating batches
labels_iterator= dataset.map(lambda x: x['survived']).as_numpy_iterator()
labels = np.concatenate(list(labels_iterator))