Tensorflow 2.3: AttributeError: 'Tensor' object has no attribute 'numpy'


I want to load a text file borrowed from here, where each line is a JSON string like this:

{"overall": 2.0, "verified": true, "reviewTime": "02 4, 2014", "reviewerID": "A1M117A53LEI8", "asin": "7508492919", "reviewerName": "Sharon Williams", "reviewText": "DON'T CARE FOR IT.  GAVE IT AS A GIFT AND THEY WERE OKAY WITH IT.  JUST NOT WHAT I EXPECTED.", "summary": "CASE", "unixReviewTime": 1391472000}

I only want to extract the reviewText and overall features from the dataset using TensorFlow, but I get the following error:
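Setting TensorFlow aside, extracting these two fields from one line is ordinary json parsing. The sketch below uses a shortened copy of the sample line. The helper name extract is hypothetical. The line is held as bytes because that is the form TextLineDataset yields each line in once you have its value.

```python
import json

# Shortened copy of the sample line, as raw bytes (illustrative data).
raw_line = b'{"overall": 2.0, "verified": true, "reviewText": "DON\'T CARE FOR IT.", "summary": "CASE"}'

def extract(line_bytes):
    # Decode the bytes, parse the JSON object, and pick out the two fields.
    record = json.loads(line_bytes.decode("utf-8"))
    return record.get("overall"), record.get("reviewText")

overall, text = extract(raw_line)
print(overall, text)
```

The difficulty in the question is therefore not the parsing. It is getting at the bytes of each line from inside the tf.data pipeline.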

AttributeError: in user code:

    <ipython-input-4-419019a35c5e>:9 None  *
        line_dataset = line_dataset.map(lambda row: transform(row))
    <ipython-input-4-419019a35c5e>:2 transform  *
        str_example = example.numpy().decode("utf-8")

    AttributeError: 'Tensor' object has no attribute 'numpy'

My code looks like this:

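import json

import tensorflow as tf

def transform(example):
    # This line raises the AttributeError: inside Dataset.map, example is a symbolic Tensor
    str_example = example.numpy().decode("utf-8")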
    json_example = json.loads(str_example)
    overall = json_example.get('overall', None)
    text = json_example.get('reviewText', None)
    return (overall, text)

line_dataset = tf.data.TextLineDataset(filenames = [file_path])
line_dataset = line_dataset.map(lambda row: transform(row))
for example in line_dataset.take(5):
    print(example)

I am using TensorFlow 2.3.0.

To make it faster, the dataset's input pipeline is always traced into a graph, as if you had used @tf.function. Among other things, this means you cannot call .numpy(). You can, however, use tf.numpy_function to access the data as NumPy values inside the graph:

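import json

import numpy as np
import tensorflow as tf

def transform(example):
    # example is now a NumPy value (bytes for a scalar string), not a Tensor
    str_example = example.decode("utf-8")
    json_example = json.loads(str_example)
    # numpy_function does not cast return values: a Python float would come back
    # as float64 and not match the declared tf.float32, so cast explicitly.
    # Defaults replace None, which cannot be returned as a float32 or string.
    overall = np.float32(json_example.get('overall', 0.0))
    text = json_example.get('reviewText', '')
    return (overall, text)

line_dataset = tf.data.TextLineDataset(filenames=[file_path])
line_dataset = line_dataset.map(
    # inputs are passed to numpy_function as a list
    lambda row: tf.numpy_function(transform, [row], (tf.float32, tf.string)))
for example in line_dataset.take(5):
    print(example)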
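You can check the tracing claim yourself with a small sketch that assumes only TensorFlow 2.x. It records tf.executing_eagerly() from inside a Dataset.map function. The function name record_mode and the three-element toy dataset are illustrative, not from the question.

```python
import tensorflow as tf

# Collects the value of tf.executing_eagerly() each time the map body's Python code runs.
seen = []

def record_mode(row):
    # This Python code runs while tf.data traces the function into a graph,
    # with a symbolic Tensor as `row`, not once per element with an EagerTensor.
    seen.append(tf.executing_eagerly())
    return row

ds = tf.data.Dataset.from_tensor_slices(["a", "b", "c"]).map(record_mode)
values = [t.numpy() for t in ds]
print(seen)    # False for every trace: the map body is not run eagerly
print(values)
```

Because the body only ever sees symbolic tensors, anything that needs concrete values, such as .numpy() or json.loads, has to be wrapped in tf.numpy_function or tf.py_function.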

It's a bit verbose, but you can try this:

def transform(example):
    # Inside tf.py_function, example is an EagerTensor, so .numpy() works
    str_example = example.numpy().decode("utf-8")
    json_example = json.loads(str_example)
    # Defaults replace None, which cannot be converted to a float32 or string tensor
    overall = json_example.get('overall', 0.0)
    text = json_example.get('reviewText', '')
    return (overall, text)

line_dataset = tf.data.TextLineDataset(filenames=[file_path])
line_dataset = line_dataset.map(
    lambda row:
        tf.py_function(transform, [row], (tf.float32, tf.string))
)
for example in line_dataset.take(5):
    print(example)

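This pattern works with any Python function, not just NumPy code. tf.py_function runs its body eagerly, so example is an EagerTensor and .numpy() is available. That means you can also use functions such as print or input inside the map. You don't need to know all the details, but ask me if you're interested. :)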
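The following end-to-end sketch runs the py_function version against a temporary file. The two JSON records and the file name reviews.json are made up for illustration. It also records tf.executing_eagerly() inside transform to show that the body runs eagerly, once per element.

```python
import json
import os
import tempfile

import tensorflow as tf

# Hypothetical sample data in the same shape as the question's file.
records = [
    {"overall": 2.0, "reviewText": "DON'T CARE FOR IT.", "summary": "CASE"},
    {"overall": 5.0, "reviewText": "Works great.", "summary": "OK"},
]
file_path = os.path.join(tempfile.mkdtemp(), "reviews.json")
with open(file_path, "w", encoding="utf-8") as f:
    for obj in records:
        f.write(json.dumps(obj) + "\n")

eager_flags = []

def transform(example):
    # py_function runs this body eagerly, so example is an EagerTensor.
    eager_flags.append(tf.executing_eagerly())
    json_example = json.loads(example.numpy().decode("utf-8"))
    overall = json_example.get("overall", 0.0)
    text = json_example.get("reviewText", "")
    # py_function converts these Python values to the declared Tout dtypes.
    return overall, text

line_dataset = tf.data.TextLineDataset(filenames=[file_path])
line_dataset = line_dataset.map(
    lambda row: tf.py_function(transform, [row], (tf.float32, tf.string)))

results = [(float(o.numpy()), t.numpy().decode("utf-8")) for o, t in line_dataset]
print(results)
print(eager_flags)
```

One trade-off applies to both wrappers: tf.py_function and tf.numpy_function hold the Python GIL. Parsing in them is convenient but can become a bottleneck on large files.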