TensorFlow 2:如何使用散点函数?
TensorFlow 2: How to use scatter function?
我很难理解 TensorFlow 中的分散函数。例如,我想使用 tf.compat.v1.scatter_sub
从第二个索引中进行子索引,如下所示:
a = tf.Variable(tf.random.uniform(shape=[2]))
b = tf.Variable(tf.random.uniform(shape=[3, 2]))
a
是:
<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([0.62174237, 0.7431344 ], dtype=float32)>
和 b
是:
<tf.Variable 'Variable:0' shape=(3, 2) dtype=float32, numpy=
array([[0.01709783, 0.72348535],
[0.48500955, 0.7092271 ],
[0.62199426, 0.26062095]], dtype=float32)>
我想从 b
的第二行减去 a
这样我就可以实现:
array([[0.01709783, 0.72348535 ],
[-0.13673282, -0.0339073],
[0.62199426, 0.26062095]], dtype=float32)>
我认为 tf.compat.v1.scatter_sub(b, [1], a)
一定有效,但它失败了。我尝试转置 a
但它也失败了。完整的错误是这样的:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-37-90775abdb544> in <module>()
9 print("------------------------")
10
---> 11 tf.compat.v1.scatter_sub(b, [1], a)
3 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/state_ops.py in scatter_sub(ref, indices, updates, use_locking, name)
535 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access
536 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
--> 537 name=name))
538
539
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub(resource, indices, updates, name)
1077 try:
1078 return resource_scatter_sub_eager_fallback(
-> 1079 resource, indices, updates, name=name, ctx=_ctx)
1080 except _core._SymbolicException:
1081 pass # Add nodes to the TensorFlow graph.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub_eager_fallback(resource, indices, updates, name, ctx)
1095 _attrs = ("dtype", _attr_dtype, "Tindices", _attr_Tindices)
1096 _result = _execute.execute(b"ResourceScatterSub", 0, inputs=_inputs_flat,
-> 1097 attrs=_attrs, ctx=ctx, name=name)
1098 _result = None
1099 return _result
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [2], indices.shape [1], params.shape [3,2] [Op:ResourceScatterSub]
使用此功能的正确方法是什么?
你看过这里的文档了吗? https://www.tensorflow.org/api_docs/python/tf/scatter_nd
参数:
- 指数
张量。必须是以下类型之一:int32、int64。索引张量。
- 更新
张量。更新分散到输出中。
- 形状
张量。必须与索引具有相同的类型。一维。结果张量的形状。
- 名字
操作的名称(可选)。
我明白了。问题是更新(这里是 a
)假设是 列表 几个更新,但在这里我只给它向量 a
本身,不是仅包含 a
.
的列表
现在我应该将 a
扩展一维。我的意思是 a
现在是 [0.62174237, 0.7431344]
,我应该将其更改为 [[0.62174237, 0.7431344 ]]
我可以通过 tf.expand_dims
.
来完成
所以解决方案是:
tf.compat.v1.scatter_sub(b, [1], tf.expand_dims(a, axis=0))
我很难理解 TensorFlow 中的分散函数。例如,我想使用 tf.compat.v1.scatter_sub
从第二个索引中进行子索引,如下所示:
a = tf.Variable(tf.random.uniform(shape=[2]))
b = tf.Variable(tf.random.uniform(shape=[3, 2]))
a
是:
<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([0.62174237, 0.7431344 ], dtype=float32)>
和 b
是:
<tf.Variable 'Variable:0' shape=(3, 2) dtype=float32, numpy=
array([[0.01709783, 0.72348535],
[0.48500955, 0.7092271 ],
[0.62199426, 0.26062095]], dtype=float32)>
我想从 b
的第二行减去 a
这样我就可以实现:
array([[0.01709783, 0.72348535 ],
[-0.13673282, -0.0339073],
[0.62199426, 0.26062095]], dtype=float32)>
我认为 tf.compat.v1.scatter_sub(b, [1], a)
一定有效,但它失败了。我尝试转置 a
但它也失败了。完整的错误是这样的:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-37-90775abdb544> in <module>()
9 print("------------------------")
10
---> 11 tf.compat.v1.scatter_sub(b, [1], a)
3 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/state_ops.py in scatter_sub(ref, indices, updates, use_locking, name)
535 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub( # pylint: disable=protected-access
536 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
--> 537 name=name))
538
539
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub(resource, indices, updates, name)
1077 try:
1078 return resource_scatter_sub_eager_fallback(
-> 1079 resource, indices, updates, name=name, ctx=_ctx)
1080 except _core._SymbolicException:
1081 pass # Add nodes to the TensorFlow graph.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_resource_variable_ops.py in resource_scatter_sub_eager_fallback(resource, indices, updates, name, ctx)
1095 _attrs = ("dtype", _attr_dtype, "Tindices", _attr_Tindices)
1096 _result = _execute.execute(b"ResourceScatterSub", 0, inputs=_inputs_flat,
-> 1097 attrs=_attrs, ctx=ctx, name=name)
1098 _result = None
1099 return _result
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [2], indices.shape [1], params.shape [3,2] [Op:ResourceScatterSub]
使用此功能的正确方法是什么?
你看过这里的文档了吗? https://www.tensorflow.org/api_docs/python/tf/scatter_nd
参数:
- 指数
张量。必须是以下类型之一:int32、int64。索引张量。 - 更新
张量。更新分散到输出中。 - 形状 张量。必须与索引具有相同的类型。一维。结果张量的形状。
- 名字
操作的名称(可选)。
我明白了。问题是更新(这里是 a
)假设是 列表 几个更新,但在这里我只给它向量 a
本身,不是仅包含 a
.
现在我应该将 a
扩展一维。我的意思是 a
现在是 [0.62174237, 0.7431344]
,我应该将其更改为 [[0.62174237, 0.7431344 ]]
我可以通过 tf.expand_dims
.
所以解决方案是:
tf.compat.v1.scatter_sub(b, [1], tf.expand_dims(a, axis=0))