TensorFlow 2:如何使用散点函数?

TensorFlow 2: How to use scatter function?

我很难理解 TensorFlow 中的分散函数。例如,我想使用 tf.compat.v1.scatter_sub 从第二个索引中进行子索引,如下所示:

a = tf.Variable(tf.random.uniform(shape=[2]))
b = tf.Variable(tf.random.uniform(shape=[3, 2]))

a 是:

<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([0.62174237, 0.7431344 ], dtype=float32)>

b 是:

<tf.Variable 'Variable:0' shape=(3, 2) dtype=float32, numpy=
array([[0.01709783, 0.72348535],
       [0.48500955, 0.7092271 ],
       [0.62199426, 0.26062095]], dtype=float32)>

我想从 b 的第二行减去 a 这样我就可以实现:

array([[0.01709783, 0.72348535 ],
       [-0.13673282, -0.0339073],
       [0.62199426, 0.26062095]], dtype=float32)>

我认为 tf.compat.v1.scatter_sub(b, [1], a) 一定有效,但它失败了。我尝试转置 a 但它也失败了。完整的错误是这样的:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-37-90775abdb544> in <module>()
      9 print("------------------------")
     10 
---> 11 tf.compat.v1.scatter_sub(b, [1], a)

3 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/state_ops.py in scatter_sub(ref, indices, updates, use_locking, name)
    535   return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub(  # pylint: disable=protected-access
    536       ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
--> 537       name=name))
    538 
    539 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub(resource, indices, updates, name)
   1077     try:
   1078       return resource_scatter_sub_eager_fallback(
-> 1079           resource, indices, updates, name=name, ctx=_ctx)
   1080     except _core._SymbolicException:
   1081       pass  # Add nodes to the TensorFlow graph.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub_eager_fallback(resource, indices, updates, name, ctx)
   1095   _attrs = ("dtype", _attr_dtype, "Tindices", _attr_Tindices)
   1096   _result = _execute.execute(b"ResourceScatterSub", 0, inputs=_inputs_flat,
-> 1097                              attrs=_attrs, ctx=ctx, name=name)
   1098   _result = None
   1099   return _result

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [2], indices.shape [1], params.shape [3,2] [Op:ResourceScatterSub]

使用此功能的正确方法是什么?

你看过这里的文档了吗? https://www.tensorflow.org/api_docs/python/tf/scatter_nd

参数:

  • 指数
    张量。必须是以下类型之一:int32、int64。索引张量。
  • 更新
    张量。更新分散到输出中。
  • 形状 张量。必须与索引具有相同的类型。一维。结果张量的形状。
  • 名字
    操作的名称(可选)。

我明白了。问题是更新(这里是 a)假设是 列表 几个更新,但在这里我只给它向量 a 本身,不是仅包含 a.

的列表

现在我应该将 a 扩展一维。我的意思是 a 现在是 [0.62174237, 0.7431344],我应该将其更改为 [[0.62174237, 0.7431344 ]] 我可以通过 tf.expand_dims.

来完成

所以解决方案是:

tf.compat.v1.scatter_sub(b, [1], tf.expand_dims(a, axis=0))