Error calling adapt in TextVectorization Keras

I have the following code with a custom standardization function:

import tensorflow as tf

def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    regex = tf.strings.regex_replace(lowercase, r'[^\w]', ' ')
    regex = tf.strings.regex_replace(regex, ' +', ' ')

    return tf.strings.split(regex)

vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    standardize=custom_standardization,
    max_tokens=50000,
    output_mode="int",
    output_sequence_length=100,
)

But when I call adapt like this, I get the following error:

vectorize_layer.adapt(['the cat'])
# Error:
InvalidArgumentError: Expected 'tf.Tensor(False, shape=(), dtype=bool)' to be true. Summarized data: b'the given axis (axis = 2) is not squeezable!'

According to the documentation:

When using a custom callable for split, the data received by the callable will have the 1st dimension squeezed out - instead of [["string to split"], ["another string to split"]], the Callable will see ["string to split", "another string to split"]. The callable should return a Tensor with the first dimension containing the split tokens - in this example, we should see something like [["string", "to", "split"], ["another", "string", "to", "split"]]. This makes the callable site natively compatible with tf.strings.split().


But I can't see where the error is.

EDIT: I did some more digging into my code. When I pass an array such as ['The other day was raining', 'Please call me later'], the function custom_standardization() returns something like this:

[['the', 'other', 'day', 'was', 'raining'], ['please', 'call', 'me', 'later']]

So it seems the shape is not preserved. Why does the shape change?

I checked the documentation you shared earlier. It says the following about a custom standardization callable:

When using a custom callable for standardize, the data received by the callable will be exactly as passed to this layer. The callable should return a tensor of the same shape as the input.
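The difference between the two contracts can be illustrated without TensorFlow. The following is a minimal plain-Python sketch, not the layer's actual implementation, that mimics the two stages on a (2, 1) batch like the test input used in this answer: standardization maps each string to exactly one string, so the shape is unchanged, and only the split stage adds a token axis. The helper names `standardize` and `whitespace_split` are hypothetical.

```python
import re

# Plain-Python stand-in (an assumption, not TensorFlow) for the order of
# operations described in the docs: standardize keeps the shape, split adds
# a token axis.

def standardize(text):
    """Hypothetical mirror of custom_standardization that returns a string."""
    text = text.lower()
    text = re.sub(r"[^\w]", " ", text)  # non-word characters -> spaces
    text = re.sub(r" +", " ", text)     # collapse runs of spaces
    return text

def whitespace_split(text):
    """Hypothetical whitespace tokenizer: one string -> list of tokens."""
    return text.split()

# A (2, 1) batch, shaped like tf.constant([["..."], ["..."]]).
batch = [["foo !  @ qux  #bar"], ["qux baz"]]

# Stage 1: element-wise standardization, so the (2, 1) shape is unchanged.
standardized = [[standardize(s) for s in row] for row in batch]

# Stage 2: squeeze the trailing size-1 axis, then split each string.
squeezed = [row[0] for row in standardized]
tokens = [whitespace_split(s) for s in squeezed]

print(standardized)
print(tokens)
```

Returning `tokens` from stage 1, as the original `custom_standardization` effectively did, would hand the layer a nested structure where it expects one string per element.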

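So I replaced return tf.strings.split(regex) with return regex, because splitting changes the shape here (each input string becomes a list of tokens). Tokenization is then left to the layer itself, through its default split="whitespace". Please try it like this: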

import tensorflow as tf

def custom_standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    regex = tf.strings.regex_replace(lowercase, r'[^\w]', ' ')
    regex = tf.strings.regex_replace(regex, ' +', ' ')

    return regex

vectorize_layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    standardize=custom_standardization,
    max_tokens=50000,
    output_mode="int",
    output_sequence_length=100,
)

# Check whether the output shape matches the input shape
sample = tf.constant([["foo !  @ qux  #bar"], ["qux baz"]])
print(sample)
print(custom_standardization(sample))

vectorize_layer.adapt(["foo qux bar"])

A gist is provided for reference.
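A variant that follows both documentation passages more literally is to keep the standardization string-to-string and pass the tokenization as the layer's split callable instead. This is a sketch, assuming TensorFlow 2.6 or later, where the layer is exposed as tf.keras.layers.TextVectorization (on older releases, substitute the experimental path used above); `custom_split` and the two-sentence corpus are illustrative, not from the original post.

```python
import tensorflow as tf

def custom_standardization(input_data):
    # Shape-preserving cleanup: N strings in, N strings out (no split here).
    lowercase = tf.strings.lower(input_data)
    regex = tf.strings.regex_replace(lowercase, r"[^\w]", " ")
    return tf.strings.regex_replace(regex, " +", " ")

def custom_split(input_data):
    # Illustrative split callable: receives the standardized strings and
    # returns a ragged tensor of tokens, one row per input string.
    return tf.strings.split(input_data)

vectorize_layer = tf.keras.layers.TextVectorization(
    standardize=custom_standardization,
    split=custom_split,
    max_tokens=50000,
    output_mode="int",
    output_sequence_length=100,
)

# Illustrative corpus, fed to adapt() as a batched tf.data.Dataset of strings.
corpus = ["The other day was raining", "Please call me later"]
vectorize_layer.adapt(tf.data.Dataset.from_tensor_slices(corpus).batch(2))

vocab = vectorize_layer.get_vocabulary()
ids = vectorize_layer(tf.constant(["Please call me later"]))

print(vocab[:2])         # index 0 is padding, index 1 is the OOV token
print(tuple(ids.shape))  # one row, padded to output_sequence_length
```

Keeping the two stages separate means the standardize callable satisfies the same-shape contract, and the split callable receives input in the form the split documentation describes.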