Keras 梯度 wrt 其他东西

Keras gradient wrt something else

我正在努力实施文章 https://drive.google.com/file/d/1s-qs-ivo_fJD9BU_tM5RY8Hv-opK4Z-H/view 中描述的方法。最终使用的算法在这里(在第 6 页):

我们的想法是进行对抗训练,方法是在网络对小变化最敏感的方向上修改数据,并使用修改后的数据训练网络,但标签与原始数据相同。

我正在尝试使用 MNIST 数据集和 100 个数据的小批量在 Keras 中实现此方法,但我无法理解梯度 wrt r 的计算(第 3 行的第一行算法的步骤)。我不知道如何用 Keras 计算它。这是我的代码:

loss = losses.SparseCategoricalCrossentropy()

for epoch in range(5):
    print(f"Start of epoch {epoch}")
    for step, (xBatchTrain,yBatchTrain) in enumerate(trainDataset):
        #Generating the 100 unit vectors
        randomVectors = np.random.random(xBatchTrain.shape)
        U = randomVectors / np.linalg.norm(randomVectors,axis=1)[:, None]

        #Generating the r vectors
        Xi = 2
        R = tf.convert_to_tensor(U * Xi[:, None],dtype='float32')

        dataNoised = xBatchTrain + R

        with tf.GradientTape(persistent=True) as imTape:
            imTape.watch(R)
            #Geting the losses
            C = [loss(label,pred) for label, pred in zip(yBatchTrain,dumbModel(dataNoised,training=False))]

        #Getting the gradient wrt r for each images
        for l,r in zip(C,R):
            print(imTape.gradient(l,r))

每个样本的“print”行 returns None。我应该 return 我是一个包含 784 个值的向量,每个值代表一个像素?

(我很抱歉部分代码很丑,我是 Keras、tf 和深度学习的新手)

[编辑]

这是整个笔记本的要点:https://gist.github.com/DridriLaBastos/136a8e9d02b311e82fe22ec1c2850f78

首先把with tf.GradientTape(persistent=True) as imTape:里面的dataNoised = xBatchTrain + R移动到记录R

相关的操作

其次,而不是使用:

for l,r in zip(C,R):
    print(imTape.gradient(l,r))

你应该使用 imTape.gradient(C,R) 来获取梯度集,因为 zip 会破坏 R 的张量中的操作依赖性,打印出来 return类似于 xBatchTrain:

形状相同的东西
tf.Tensor(
[[-1.4924371e-06  1.0490652e-05 -1.8195267e-05 ...  1.5640746e-05
   3.3767541e-05 -2.0983218e-05]
 [ 2.3668531e-02  1.9133706e-02  3.1396169e-02 ... -1.4431887e-02
   5.3144591e-03  6.2225698e-03]
 [ 2.0492254e-03  7.1049971e-04  1.6121448e-03 ... -1.0579333e-03
   2.4968456e-03  8.3572773e-04]
 ...
 [-4.5572519e-03  6.2278998e-03  6.8322839e-03 ... -2.1966733e-03
   1.0822206e-03  1.8687058e-03]
 [-6.3691144e-03 -4.1699030e-02 -9.3158096e-02 ... -2.9496195e-02
  -7.0264392e-02 -3.2520775e-02]
 [-1.4666058e-02  2.0758331e-02  2.9009990e-02 ... -3.2206681e-02
   3.1550713e-02  4.9267178e-03]], shape=(100, 784), dtype=float32)