在 TensorFlow 中,函数 'tf.one_hot' 中的参数 'axis' 是什么

In TensorFlow, what is the argument 'axis' in the function 'tf.one_hot'

谁能帮忙解释一下 TensorFlowone_hot 函数中的 axis 是什么?

根据 documentation:

axis: The axis to fill (default: -1, a new inner-most axis)

最近我在 SO was an explanation relevant to Pandas 上找到答案:

不确定上下文是否同样适用。

这是一个例子:

x = tf.constant([0, 1, 2])

... 是输入张量,N=4(每个索引被转换为 4D 向量)。

axis=-1

计算 one_hot_1 = tf.one_hot(x, 4).eval() 产生一个 (3, 4) 张量:

[[ 1.  0.  0.  0.]
 [ 0.  1.  0.  0.]
 [ 0.  0.  1.  0.]]

...最后一个维度是单热编码的(清晰可见)。这对应于默认的 axis=-1,即 last

axis=0

现在,计算 one_hot_2 = tf.one_hot(x, 4, axis=0).eval() 会产生一个 (4, 3) 张量,它不能立即识别为单热编码:

[[ 1.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  1.]
 [ 0.  0.  0.]]

这是因为单热编码是沿着0轴完成的,必须转置矩阵才能看到之前的编码。当输入的维度更高时,情况变得更加复杂,但想法是相同的:区别在于用于单热编码的 extra 维度的放置。

对我来说,轴翻译为“你在哪里添加额外的数字来增加维度”。至少我是这样解释它并作为助记符的。

例如你有 [1,2,3,0,2,1] 并且这是形状 (1,6)。这意味着它是一个一维数组。 one_hot 在原始数组的每个位置添加零并将位置转换为 1,为此,原始数组的维度必须比原始数组多 1,轴告诉函数将其添加到何处,这个新维度将识别示例


轴=1

您添加第二个维度并保留第一个维度。这将导致 (6,4) 数组。因此,对于生成的数组,您使用第一个维度 (0) 来了解您看到的示例,使用第二个维度(1,新的)来了解 class 是否处于活动状态。 newArr[0][1]=1 表示示例 0,class 1,在这种情况下表示示例 0 是 class 1。
   0   1   2   3  <- class

[[ 0.  1.  0.  0.]   <- example 0
 [ 0.  0.  1.  0.]   <- example 1
 [ 0.  0.  0.  1.]   <- example 2
 [ 1.  0.  0.  0.]   <- example 3
 [ 0.  0.  1.  0.]   <- example 4
 [ 0.  1.  0.  0.]]  <- example 5

轴=0

您添加第一个维度,现有维度被移动。这将产生一个 (4,6) 数组。因此,对于生成的数组,您使用第一个维度(0,新维度)来了解 class 是否处于活动状态,并使用第二个维度 (1) 来了解您看到的示例。 newArr[0][1]=0 表示 class 0,示例 1,在这种情况下表示示例 1 不是 class 0。
   0   1   2   3   4   5  <- example

[[ 0.  0.  0.  1.  0.  0.]   <- class 0
 [ 1.  0.  0.  0.  0.  1.]   <- class 1
 [ 0.  1.  0.  0.  1.  0.]   <- class 2
 [ 0.  0.  1.  0.  0.  0.]]  <- class 3

对我来说,我是这样理解的—— (注意documentation中的indices只是classes的标签信息,可以是标量也可以是向量也可以是矩阵) 如果您的索引只是标量,则不需要轴。 但是,如果它是一个向量,您可以选择特征的方向和 classes 2` 在这里,one-hot 向量的图像有一行作为深度(class),一列作为相应的特征(标签),所以对于这种情况,轴的值为 0。 同样,如果您想要特征 x 深度,则轴的值为 -1。

同样,如果索引是矩阵,则您可以选择以下方向


(批次表示索引中的行)3

batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0