How to vectorize indexing and computation when indexed tensors have different dimensions?

I'm trying to vectorize the following for-loop in PyTorch. I'd be happy to vectorize just the inner for-loop, but doing the whole batch would also be great.

# B: the batch size
# N: the number of training examples 
# dim: the dimension of each feature vector
# K: the number of discrete labels. each vector has a single label
# delta: margin for hinge loss

batch_data = torch.tensor(...)  # Tensor of shape [B x N x dim]
batch_labels = torch.tensor(...)  # Tensor of shape [B x N x 1], each element is one of K labels (ints)

batch_losses = []  # Ultimately should be [B x 1]
batch_centroids = []  # Ultimately should be [B x K_i x dim]
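for i in range(B):
    data = batch_data[i]                  # [N x dim] features of this batch element
    labels = batch_labels[i].squeeze(-1)  # [N] labels of this batch element

    centroids = []  # Keep track of the means for each class.
    classes = torch.unique(labels)  # Get the unique labels for the classes.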

    # NOTE: The number of classes K for each item in the batch might actually
    # be different. This may complicate batch-level operations.

    total_loss = 0

    # For each class independently. This is the part I want to vectorize.
    for cl in classes:
        # Take the subset of training examples with that label.
        subset = data[torch.where(labels == cl)]

        # Find the centroid of that subset.
        centroid = subset.mean(dim=0)
        centroids.append(centroid)
  
        # Get the distance between each point in the subset and the centroid.
        dists = subset - centroid
        norm = torch.linalg.norm(dists, dim=1)

        # The loss is the mean of the hinge loss across the subset.
        margin = norm - delta
        hinge = torch.clamp(margin, min=0.0) ** 2

        total_loss += hinge.mean()

    # Keep track of everything. If it's too hard to keep track of centroids, that's also OK.
    loss = total_loss.mean()
    batch_losses.append(loss)
    batch_centroids.append(centroids)
   
   

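I've been racking my brain over how to handle the ragged tensors. The number of classes K_i differs for each batch element, and the size of each subset differs as well.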

It turns out it actually is possible to vectorize across ragged arrays. I'll use numpy, but the code should translate directly to torch. The key techniques are:

  1. Sort by ragged array membership
  2. Perform a cumulative sum
  3. Find the boundary indices and compute adjacent differences
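
As a rough illustration of these three steps in isolation, here is a minimal sketch that computes per-group sums and means on toy data. The array names (`values`, `group`, `bounds`) are made up for this example, and it assumes every group has at least one member; the full function further down also handles empty groups.

```python
import numpy as np

# Toy data: 6 values belonging to 3 ragged groups (illustrative only).
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
group = np.array([2, 0, 1, 0, 2, 2])

# 1. Sort by group membership so each group becomes one contiguous run.
order = np.argsort(group, kind="stable")
sorted_values = values[order]
sorted_group = group[order]

# 2. Cumulative sum with a leading zero: csum[j] is the sum of the first j values.
csum = np.concatenate(([0.0], np.cumsum(sorted_values)))

# 3. Boundary indices: the start of each run, plus the end of the array.
starts = np.flatnonzero(np.diff(sorted_group, prepend=sorted_group[0] - 1))
bounds = np.append(starts, len(values))

# Adjacent differences at the boundaries give per-group sums and counts.
group_sums = np.diff(csum[bounds])
group_counts = np.diff(bounds)
group_means = group_sums / group_counts
```

No Python-level loop over groups is needed; the work is done by one sort, one cumulative sum and two `np.diff` calls.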

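For a single (non-batched) input consisting of an n x d matrix X and a length-n array label, the following returns the k x d centroids and the length-n squared distances from each point to its centroid: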

import numpy as np

def vcentroids(X, label):
    """
    Vectorized version of centroids.
    """        
    # order points by cluster label
    ix = np.argsort(label)
    label = label[ix]
    Xz = X[ix]
    
    # compute pos where pos[i]:pos[i+1] is span of cluster i
    d = np.diff(label, prepend=0) # jump in label between consecutive sorted points (0 where unchanged)
    pos = np.flatnonzero(d) # indices where labels change
    pos = np.repeat(pos, d[pos]) # repeat for 0-length clusters
    pos = np.append(np.insert(pos, 0, 0), len(X))
    
    Xz = np.concatenate((np.zeros_like(Xz[0:1]), Xz), axis=0)
    Xsums = np.cumsum(Xz, axis=0)
    Xsums = np.diff(Xsums[pos], axis=0)
    counts = np.diff(pos)
    c = Xsums / np.maximum(counts, 1)[:, np.newaxis]
    
    repeated_centroids = np.repeat(c, counts, axis=0)
    # np.argsort(ix) is the inverse permutation of ix, so this undoes the sort
    aligned_centroids = repeated_centroids[np.argsort(ix)]
    dist = np.sum((X - aligned_centroids) ** 2, axis=1) # squared Euclidean distance
    
    return c, dist
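
Since the question is about PyTorch, here is one possible translation of the function above. This is a sketch rather than a tested drop-in: the name `vcentroids_torch` is made up, it assumes `label` is a 1-D integer tensor with non-negative values, and it swaps the repeat-plus-inverse-permutation step for a direct lookup `c[label]`, which yields the same aligned centroids.

```python
import torch

def vcentroids_torch(X, label):
    """Hypothetical torch port: X is [n, d] float, label is [n] int64 (>= 0)."""
    # Order points by cluster label.
    ix = torch.argsort(label)
    label_sorted = label[ix]
    Xz = X[ix]

    # pos[i]:pos[i+1] is the span of cluster i in the sorted order.
    d = torch.diff(label_sorted, prepend=label_sorted.new_zeros(1))
    pos = torch.nonzero(d).flatten()
    pos = torch.repeat_interleave(pos, d[pos])  # repeat for 0-length clusters
    pos = torch.cat((pos.new_zeros(1), pos, pos.new_tensor([X.shape[0]])))

    # Per-cluster sums via cumulative sum and adjacent differences.
    csum = torch.cat((torch.zeros_like(Xz[:1]), torch.cumsum(Xz, dim=0)))
    sums = torch.diff(csum[pos], dim=0)
    counts = torch.diff(pos)
    c = sums / counts.clamp(min=1).unsqueeze(1)

    # Row i of c is the centroid of cluster i, so index it by the original labels.
    dist = ((X - c[label]) ** 2).sum(dim=1)  # squared Euclidean distance
    return c, dist

# Toy input with an empty cluster (label 1 never occurs).
X = torch.tensor([[0.0, 0.0], [2.0, 0.0], [1.0, 1.0], [3.0, 3.0]])
label = torch.tensor([0, 0, 2, 2])
c, dist = vcentroids_torch(X, label)
```

One caveat worth keeping in mind: a float32 cumulative sum over many points can lose precision, so computing the sums in float64 may be safer for large n.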

Batching needs almost no special handling. For an input B x n x d array batch_X with B x n batch labels batch_labels, create unique labels for each batch element:

batch_k = batch_labels.max(axis=1) + 1   # label count per batch element
base = np.cumsum(batch_k) - batch_k      # exclusive prefix sum: first label of each element
batch_labels = batch_labels + base[:, np.newaxis]

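So now each batch element has a unique, contiguous range of labels. That is, the n labels of the first batch element fall in the range [0, k0), where k0 is its label count, those of the second fall in [k0, k0 + k1), where k1 is the second element's label count, and so on.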

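Then just flatten the B x n x d input to B*n x d, flatten the labels to length B*n, and call the same vectorized method. Your loss function can be derived from the resulting distances using the same position-array-based reduction technique.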
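
To make that last step concrete, here is a minimal sketch of the loss reduction on toy numbers. It assumes the per-point Euclidean distances are already available (take a square root first if you only have squared distances), that labels have already been shifted to globally unique ranges, and that no cluster is empty. The helper `segment_sums` and all variable names are hypothetical.

```python
import numpy as np

delta = 1.0  # hinge margin, as in the question

# Toy batch: B = 2 elements with n = 3 points each (values are illustrative).
batch_dist = np.array([[0.5, 2.0, 3.0],
                       [1.5, 0.2, 4.0]])
# Globally unique labels: element 0 owns {0, 1}, element 1 owns {2, 3}.
batch_labels = np.array([[0, 1, 1],
                         [2, 3, 3]])
batch_k = np.array([2, 2])  # clusters per batch element

# Flatten B x n to B*n.
dist = batch_dist.reshape(-1)
label = batch_labels.reshape(-1)

# Sort by cluster and locate the cluster boundaries.
ix = np.argsort(label, kind="stable")
label_sorted = label[ix]
starts = np.flatnonzero(np.diff(label_sorted, prepend=label_sorted[0] - 1))
pos = np.append(starts, len(label))

def segment_sums(values, bounds):
    """Sum values[bounds[i]:bounds[i+1]] for every i via cumsum + diff."""
    csum = np.concatenate(([0.0], np.cumsum(values)))
    return np.diff(csum[bounds])

# Mean squared hinge per cluster.
hinge = np.clip(dist[ix] - delta, 0.0, None) ** 2
cluster_loss = segment_sums(hinge, pos) / np.diff(pos)

# Sum the cluster losses within each batch element with the same trick.
batch_bounds = np.concatenate(([0], np.cumsum(batch_k)))
batch_loss = segment_sums(cluster_loss, batch_bounds)
```

Here `batch_loss` has one entry per batch element and matches the question's loop, which sums the per-class mean hinge losses.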

For a detailed explanation of how the vectorization works, see my blog post.

If you use a one-hot encoding for the classes and the pairwise distance trick for the norms, you can vectorize the whole thing:

import torch

B = 32
N = 1000
dim = 50
K = 25

batch_data = torch.randn((B, N, dim))
batch_labels = torch.randint(0, K, size=(B, N))
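# Pass num_classes so the class dimension is always K, even if a label never occurs.
batch_one_hot = torch.nn.functional.one_hot(batch_labels, num_classes=K)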

centroids = torch.matmul(
    batch_one_hot.transpose(-1, 1).type(batch_data.dtype),
    batch_data
) / batch_one_hot.sum(1)[..., None]

# [B, N, K]: distance from every point to every centroid
norms = torch.linalg.norm(batch_data[:, :, None] - centroids[:, None], dim=-1)

# Compute the rest of your loss
# ...

A few things to note:

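  1. For any batch element that is missing a class, you will get a division by zero. You can handle this by first computing the class sums (using matmul) and the counts (summing the one-hot tensor over axis 1) separately. Then mask the sums where count == 0 and divide the remaining sums by their class counts.
  2. If you have a large number of classes, this will cause memory issues because the one-hot tensor will be too large. In that case, @VF1's answer probably makes more sense.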
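
A possible way to put the first note into practice and finish the loss is sketched below. This is an illustration, not the answer's exact method: instead of an explicit mask it divides by `counts.clamp(min=1)`, which gives the same centroids because the sum for an empty class is already zero, and it uses `torch.gather` to pick each point's own centroid rather than building the full [B, N, K, dim] difference tensor. The small sizes and the seed are arbitrary.

```python
import torch
import torch.nn.functional as F

B, N, dim, K = 4, 6, 3, 5
delta = 1.0  # hinge margin, as in the question

torch.manual_seed(0)
batch_data = torch.randn(B, N, dim)
batch_labels = torch.randint(0, K, size=(B, N))

# Fix the class dimension at K so absent classes still get a (zero) column.
one_hot = F.one_hot(batch_labels, num_classes=K).to(batch_data.dtype)  # [B, N, K]

# Class sums and counts, computed separately.
sums = one_hot.transpose(1, 2) @ batch_data  # [B, K, dim]
counts = one_hot.sum(dim=1)                  # [B, K]

# Empty classes have zero sums, so dividing by max(count, 1) yields 0 instead of NaN.
centroids = sums / counts.clamp(min=1).unsqueeze(-1)

# Distance from each point to its own centroid only.
index = batch_labels.unsqueeze(-1).expand(-1, -1, dim)  # [B, N, dim]
own = torch.gather(centroids, 1, index)                 # [B, N, dim]
norms = torch.linalg.norm(batch_data - own, dim=-1)     # [B, N]

# Mean squared hinge per class, summed over classes for each batch element.
hinge = torch.clamp(norms - delta, min=0.0) ** 2        # [B, N]
class_hinge = (one_hot * hinge.unsqueeze(-1)).sum(dim=1) / counts.clamp(min=1)  # [B, K]
batch_losses = class_hinge.sum(dim=1)                   # [B]
```

Absent classes contribute exactly zero to `batch_losses`, so the result should agree with the nested loop in the question, which only iterates over the classes that are present.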