IDE breakpoint in TensorFlow Dataset API mapped py_function?

I'm using the TensorFlow Dataset API to prepare my data for input into my network. During this process, I have some custom Python functions that are mapped over the dataset using tf.py_function. I want to be able to debug the data going into these functions and what happens to that data inside them. When a py_function is called, it calls back into the main Python process (according to this answer). Since the function is written in Python and runs in the main process, I expected regular IDE breakpoints to stop inside it. However, this does not seem to be the case (the breakpoint in the example below does not stop execution). Is there a way to set a breakpoint inside a py_function used by a dataset map?

Example in which the breakpoint does not stop execution:

import tensorflow as tf

def add_ten(example, label):
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label

examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels =   [ 0,  0,  1,  1,  1,  1,  0,  0]

examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))

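The TensorFlow 2.0 implementation of tf.data.Dataset runs each call of the mapped function on a C++ thread, without notifying your debugger. Python debuggers such as pydevd install their trace function per thread, so code running on a thread the debugger was never told about is not traced, and breakpoints there are ignored. You can work around this by calling pydevd's settrace manually from inside the function: it connects to the default debugger server and starts feeding it debug data from the current thread.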

import pydevd
pydevd.settrace()
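The per-thread behaviour of tracing can be illustrated without TensorFlow or an IDE. The sketch below is only an analogy using the standard library: sys.settrace stands in for the debugger's trace function, and a plain threading.Thread (started without threading.settrace) stands in for a thread the debugger was not told about. It is not what tf.data actually does internally. The names tracer, add_ten_with_settrace, run_in_thread and demo are hypothetical. A tracer installed on the main thread never sees add_ten running on the worker thread; installing it from inside the function, which is what pydevd.settrace() does for the debugger, makes the call visible.

```python
import sys
import threading

traced_calls = []  # names of functions the tracer was notified about


def tracer(frame, event, arg):
    # Global trace function: record each new Python function frame.
    if event == "call":
        traced_calls.append(frame.f_code.co_name)
    return None  # no per-line (local) tracing needed


def add_ten(example, label):
    return example + 10, label


def add_ten_with_settrace(example, label):
    # Stand-in for pydevd.settrace(): install the tracer on *this* thread.
    previous = sys.gettrace()
    sys.settrace(tracer)
    try:
        return add_ten(example, label)  # new frame, so the tracer sees it
    finally:
        sys.settrace(previous)


def run_in_thread(func, *args):
    # Run func(*args) on a separate thread and return its result.
    result = []
    worker = threading.Thread(target=lambda: result.append(func(*args)))
    worker.start()
    worker.join()
    return result[0]


def demo():
    traced_calls.clear()
    previous = sys.gettrace()
    sys.settrace(tracer)  # "debugger" attached to the main thread only
    try:
        run_in_thread(add_ten, 10, 0)
        seen_without = "add_ten" in traced_calls
        run_in_thread(add_ten_with_settrace, 10, 0)
        seen_with = "add_ten" in traced_calls
    finally:
        sys.settrace(previous)
    return seen_without, seen_with


print(demo())  # (False, True)
```

sys.settrace only affects the thread that calls it, which is why the trace has to be installed from inside the function that runs on the other thread.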

Applied to your code example:

import tensorflow as tf
import pydevd

def add_ten(example, label):
    pydevd.settrace(suspend=False)  # Attach the debugger to this thread without pausing here.
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label

examples = [10, 20, 30, 40, 50, 60, 70, 80]
labels =   [ 0,  0,  1,  1,  1,  1,  0,  0]

examples_dataset = tf.data.Dataset.from_tensor_slices(examples)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((examples_dataset, labels_dataset))
dataset = dataset.map(map_func=lambda example, label: tf.py_function(func=add_ten, inp=[example, label],
                                                                     Tout=[tf.int32, tf.int32]))
dataset = dataset.batch(2)
example_and_label = next(iter(dataset))

Note: if your IDE already bundles pydevd (e.g. PyDev or PyCharm), you don't need to install pydevd separately; it will be available during the debug session.
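Because pydevd may not be importable (or may have no debugger server to connect to) when you run the script outside a debug session, an unconditional import pydevd / pydevd.settrace() in the mapped function can make the pipeline fail in normal runs. A minimal sketch of a guard, assuming that the IDE's debugger has already loaded the module under the name pydevd in sys.modules: the hypothetical helper attach_debugger_to_current_thread calls pydevd.settrace(suspend=False) only in that case and does nothing otherwise.

```python
import sys


def attach_debugger_to_current_thread():
    """Hypothetical helper: attach pydevd to the calling thread if (and only if)
    a pydevd-based debug session has already loaded the module."""
    pydevd = sys.modules.get("pydevd")  # assumption: the IDE registered it here
    if pydevd is None:
        return False  # not under a pydevd debugger; leave the thread alone
    # Start tracing this thread for the existing debugger, without pausing here.
    pydevd.settrace(suspend=False)
    return True


def add_ten(example, label):
    attach_debugger_to_current_thread()
    example_plus_ten = example + 10  # Breakpoint here.
    return example_plus_ten, label
```

add_ten can then be passed to tf.py_function exactly as in the example above; outside a debug session the helper simply returns False.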