具有字符串输入的 Tensorflow 数据集不保留数据类型

Tensorflow Datasets with string inputs do not preserve data type

下面的所有可重现代码是运行在Google Colab with TF 2.2.0-rc2.

改编 documentation 中的简单示例以从简单的 Python 列表创建数据集:

import numpy as np
import tensorflow as tf
tf.__version__
# '2.2.0-rc2'
np.version.version
# '1.18.2'

dataset1 = tf.data.Dataset.from_tensor_slices([1, 2, 3]) 
for element in dataset1: 
  print(element) 
  print(type(element.numpy()))

我们得到结果

tf.Tensor(1, shape=(), dtype=int32)
<class 'numpy.int32'>
tf.Tensor(2, shape=(), dtype=int32)
<class 'numpy.int32'>
tf.Tensor(3, shape=(), dtype=int32)
<class 'numpy.int32'>

其中所有数据类型都是 int32,正如预期的那样。

但是改变这个简单的例子来提供一个字符串列表而不是整数:

dataset2 = tf.data.Dataset.from_tensor_slices(['1', '2', '3']) 
for element in dataset2: 
  print(element) 
  print(type(element.numpy()))

给出结果

tf.Tensor(b'1', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'2', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'3', shape=(), dtype=string)
<class 'bytes'>

令人惊讶的是,尽管张量本身是 dtype=string,但它们的评估类型是 bytes

此行为并不局限于 .from_tensor_slices 方法;这是 .list_files 的情况(以下片段 运行 在新的 Colab 笔记本中很简单):

disc_data = tf.data.Dataset.list_files('sample_data/*.csv') # 4 csv files
for element in disc_data: 
  print(element) 
  print(type(element.numpy()))

结果是:

tf.Tensor(b'sample_data/california_housing_test.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/mnist_train_small.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/california_housing_train.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/mnist_test.csv', shape=(), dtype=string)
<class 'bytes'>

再次,评估张量中的文件名返回为 bytes,而不是 string,尽管张量本身是 dtype=string.

使用 .from_generator 方法(此处未显示)也观察到类似的行为。

最后一个演示:如.as_numpy_iterator方法documentation所示,下面的相等条件求值为True

dataset3 = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 
                                               'b': [5, 6]}) 

list(dataset3.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5}, 
                                       {'a': (2, 4), 'b': 6}] 
# True

但如果我们将 b 的元素更改为字符串,则相等条件现在令人惊讶地评估为 False!

dataset4 = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 
                                               'b': ['5', '6']})   # change elements of b to strings

list(dataset4.as_numpy_iterator()) == [{'a': (1, 3), 'b': '5'},   # here
                                       {'a': (2, 4), 'b': '6'}]   # also
# False

可能是由于数据类型不同,因为值本身显然是相同的。


我不是通过学术实验偶然发现这种行为的;我正在尝试使用自定义函数将我的数据传递给 TF 数据集,这些函数从磁盘中读取成对的文件

f = ['filename1', 'filename2']

哪些自定义函数本身运行良好,但通过 TF 数据集映射可以提供

RuntimeError: not a string

经过这次挖掘,如果返回的数据类型确实是 bytes 而不是 string.

,这似乎至少不是无法解释的

那么,这是一个错误(看起来),还是我在这里遗漏了什么?

这是已知行为:

发件人:https://github.com/tensorflow/tensorflow/issues/5552#issuecomment-260455136

TensorFlow converts str to bytes in most places, including sess.run, and this is unlikely to change. The user is free to convert back, but unfortunately it's too large a change to add a unicode dtype to the core. Closing as won't fix for now.

我想 TensorFlow 没有任何变化 2.x - 仍然有一些地方将字符串转换为字节,您必须手动处理。

issue你自己打开来看,他们似乎把这个问题当作Numpy的问题,而不是Tensorflow本身的问题。