无法正确使用 tf.while_loop

Unable to use tf.while_loop properly

我的代码:

def entropy(x):
  """Return the Shannon entropy of ``x`` as a TensorFlow tensor.

  ``x`` is turned into a NumPy array with ``np.array`` and handed to
  scikit-image's ``shannon_entropy``; the (NumPy scalar) result is wrapped
  back into a tensor with ``tf.convert_to_tensor``.

  NOTE(review): ``np.array(x)`` presumably only works for eager tensors or
  array-likes, not for symbolic tensors inside a graph — confirm at call site.
  """
  # scikit-image exposes the function as ``skimage.measure.shannon_entropy``;
  # there is no ``skimage.measure_shannon_entropy`` attribute.
  # NOTE(review): assumes ``skimage.measure`` has been imported (e.g.
  # ``import skimage.measure``) — this snippet shows no imports.
  return tf.convert_to_tensor(skimage.measure.shannon_entropy(np.array(x)))
def calc_entropy(x, fn):
  """Attempt to compute one entropy value per slice ``x[0, :, :, k]``, k < fn.

  NOTE(review): as written this cannot work; it is the failing snippet from
  the question, and presumably the source of the ValueError quoted below:

  * ``loop_vars`` is the single tensor ``x[0, :, :, i]``, sliced once before
    the loop with ``i == 0``.  ``tf.while_loop`` calls the condition with the
    loop variables, so the lambda's ``i`` receives that slice, not the
    counter.  For the 4-D example input ``a`` the slice is 2x2, so
    ``tf.less(slice, fn)`` is a multi-element boolean tensor whose truth
    value is ambiguous.
  * ``entropy`` is passed as the loop body, but it replaces the loop
    variable with a scalar entropy instead of returning an updated
    ``(counter, results)`` state.
  * Nothing ever increments ``i``, and no per-iteration results are
    accumulated; ``r`` would only be the last body output.
  """
  i = tf.constant(0)
  while_condition = lambda i: tf.less(i, fn)
  # NOTE(review): loop_vars should carry the counter (and an accumulator),
  # not a slice of x computed before the loop.
  r = tf.while_loop(while_condition, entropy, x[0, :, :, i])
  return r
# Example input of shape (1, 2, 2, 2).
a = tf.constant(
  [[[[1, 2], [3, 4]],
    [[5, 6], [7, 8]]]]
)
output = calc_entropy(a, fn=2)

期望的输出:[1.23, 0.12]

但我的代码报了这个错误:`ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()`(具有多个元素的数组的真值不明确。请使用 a.any() 或 a.all())

可以试试这样写:

import tensorflow as tf
from skimage.measure.entropy import shannon_entropy
import numpy as np

def entropy(i, v, x):
  """``tf.while_loop`` body: store the entropy of slice ``i`` of ``x`` in ``v``.

  Args:
    i: loop counter; the caller starts it at ``tf.constant(0)``.
    v: float64 result vector; entry ``i`` is overwritten.
    x: input tensor, read as ``x[0, :, :, i]``; passed through unchanged.

  Returns:
    ``(i + 1, updated_v, x)``, i.e. the same structure as the loop
    variables, as ``tf.while_loop`` requires of its body.
  """
  # ``shannon_entropy`` is a NumPy/scikit-image function and cannot consume a
  # symbolic tensor; ``tf.numpy_function`` runs it on the concrete value, so
  # the body works eagerly and inside a graph (e.g. under ``tf.function``),
  # where calling ``np.array`` on the slice would fail.
  ent = tf.numpy_function(
    lambda img: np.asarray(shannon_entropy(img), dtype=np.float64),
    [x[0, :, :, i]],
    tf.float64,
  )
  # numpy_function loses static shape info; the scalar result is reshaped to
  # the (1,) update that matches the single index [[i]].
  v = tf.tensor_scatter_nd_update(v, [[i]], tf.reshape(ent, [1]))
  return tf.add(i, 1), v, x

def calc_entropy(x, fn):
  """Fill a length-``fn`` float64 vector by looping ``entropy`` over 0..fn-1.

  The loop state is ``(index, results, x)``: ``index`` starts at 0,
  ``results`` starts as zeros of shape ``(fn,)``, and ``x`` is threaded
  through unchanged.  ``entropy`` is the loop body.

  Returns:
    The final ``results`` tensor (float64, shape ``(fn,)``).
  """
  def keep_going(index, results, data):
    # Continue while the counter is below the requested number of slices.
    return index < fn

  start_state = (tf.constant(0), tf.zeros((fn,), dtype=tf.float64), x)
  end_state = tf.while_loop(cond=keep_going, body=entropy,
                            loop_vars=start_state)
  return end_state[1]

# Example input of shape (1, 2, 2, 2); the loop body slices it as a[0, :, :, i].
a = tf.constant(
  [[[[1, 2], [3, 4]],
    [[5, 6], [7, 8]]]]
)
output = calc_entropy(a, fn=2)
print(output)
tf.Tensor([2. 2.], shape=(2,), dtype=float64)

不过,我不清楚你期望的输出 [1.23, 0.12] 是怎么得到的。手动验证一下计算结果:

a = np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])
# Entropy of a[0, :, :, 0] and a[0, :, :, 1], printed one per line.
print(*(shannon_entropy(a[0, :, :, k]) for k in (0, 1)), sep="\n")
2.0
2.0