tensorflow 的 conv1d 和 pytorch 的 conv1d 之间的差异
Discrepancy between tensorflow's conv1d and pytorch's conv1d
我正在尝试将一些 pytorch 代码导入 tensorflow,我开始知道 torch.nn.functional.conv1d() 是 tf.nn.conv1d() 但恐怕仍然存在一些差异tf 的版本。具体来说,我在 tf.conv1d 中找不到组参数。例如:以下代码输出两个不同的结果:
火炬:
inputs = torch.Tensor([[[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]]]) #batch_sizex seq_length x embed_dim,
inputs = inputs.transpose(2,1) #batch_size x embed_dim x seq_length
batch_size, embed_dim, seq_length = inputs.size()
kernel_size = 3
in_channels = 2
out_channels = in_channels
weight = torch.ones(out_channels, 1, kernel_size)
inputs = inputs.contiguous().view(-1, in_channels, seq_length) #batch_size*embed_dim/in_channels x in_channels x seq_length
inputs = F.pad(inputs, (kernel_size-1,0), 'constant', 0)
output = F.conv1d(inputs, weight, padding=0, groups=in_channels)
output = output.contiguous().view(batch_size, embed_dim, seq_length).transpose(2,1)
输出:
tensor([[[1., 1., 1., 1.],
[3., 3., 3., 3.],
[6., 6., 6., 6.]]])
张量流:
inputs = tf.constant([[[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]]], dtype=tf.float32) #batch_sizex seq_length x embed_dim
inputs = tf.transpose(inputs, perm=[0,2,1])
batch_size, embed_dim, seq_length = inputs.get_shape()
print(batch_size, seq_length, embed_dim)
kernel_size = 3
in_channels = 2
out_channels = in_channels
weight = tf.ones([kernel_size, in_channels, out_channels])
inputs = tf.reshape(inputs, [(batch_size*embed_dim)//in_channels, in_channels, seq_length], name='inputs')
inputs = tf.transpose(inputs, perm=[0, 2, 1])
padding = [[0, 0], [(kernel_size - 1), 0], [0, 0]]
padded = tf.pad(inputs, padding)
res = tf.nn.conv1d(padded, weight, 1, 'VALID')
res = tf.transpose(res, perm=[0, 2, 1])
res = tf.reshape(res, [batch_size, embed_dim, seq_length])
res = tf.transpose(res, perm=[0, 2, 1])
print(res)
输出:
[[[ 2. 2. 2. 2.]
[ 6. 6. 6. 6.]
[12. 12. 12. 12.]]], shape=(1, 3, 4), dtype=float32)
不同的结果
这些版本之间没有差异,您只是设置了不同的东西。要获得与 Tensorflow 中完全相同的结果,请将指定权重的行更改为:
weight = torch.ones(out_channels, 2, kernel_size)
,因为你的输入有两个输入通道,正如你在TF中正确声明的那样:
weight = tf.ones([kernel_size, in_channels, out_channels])
组参数
您误解了 groups
参数在 pytorch
中的作用。它限制了每个过滤器使用的通道数(在这种情况下只有一个,因为 2 input_channels
除以 2 给我们一个)。
有关 2D
卷积的更直观解释,请参阅 here。
我正在尝试将一些 pytorch 代码导入 tensorflow,我开始知道 torch.nn.functional.conv1d() 是 tf.nn.conv1d() 但恐怕仍然存在一些差异tf 的版本。具体来说,我在 tf.conv1d 中找不到组参数。例如:以下代码输出两个不同的结果:
火炬:
inputs = torch.Tensor([[[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]]]) #batch_sizex seq_length x embed_dim,
inputs = inputs.transpose(2,1) #batch_size x embed_dim x seq_length
batch_size, embed_dim, seq_length = inputs.size()
kernel_size = 3
in_channels = 2
out_channels = in_channels
weight = torch.ones(out_channels, 1, kernel_size)
inputs = inputs.contiguous().view(-1, in_channels, seq_length) #batch_size*embed_dim/in_channels x in_channels x seq_length
inputs = F.pad(inputs, (kernel_size-1,0), 'constant', 0)
output = F.conv1d(inputs, weight, padding=0, groups=in_channels)
output = output.contiguous().view(batch_size, embed_dim, seq_length).transpose(2,1)
输出:
tensor([[[1., 1., 1., 1.],
[3., 3., 3., 3.],
[6., 6., 6., 6.]]])
张量流:
inputs = tf.constant([[[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]]], dtype=tf.float32) #batch_sizex seq_length x embed_dim
inputs = tf.transpose(inputs, perm=[0,2,1])
batch_size, embed_dim, seq_length = inputs.get_shape()
print(batch_size, seq_length, embed_dim)
kernel_size = 3
in_channels = 2
out_channels = in_channels
weight = tf.ones([kernel_size, in_channels, out_channels])
inputs = tf.reshape(inputs, [(batch_size*embed_dim)//in_channels, in_channels, seq_length], name='inputs')
inputs = tf.transpose(inputs, perm=[0, 2, 1])
padding = [[0, 0], [(kernel_size - 1), 0], [0, 0]]
padded = tf.pad(inputs, padding)
res = tf.nn.conv1d(padded, weight, 1, 'VALID')
res = tf.transpose(res, perm=[0, 2, 1])
res = tf.reshape(res, [batch_size, embed_dim, seq_length])
res = tf.transpose(res, perm=[0, 2, 1])
print(res)
输出:
[[[ 2. 2. 2. 2.]
[ 6. 6. 6. 6.]
[12. 12. 12. 12.]]], shape=(1, 3, 4), dtype=float32)
不同的结果
这些版本之间没有差异,您只是设置了不同的东西。要获得与 Tensorflow 中完全相同的结果,请将指定权重的行更改为:
weight = torch.ones(out_channels, 2, kernel_size)
,因为你的输入有两个输入通道,正如你在TF中正确声明的那样:
weight = tf.ones([kernel_size, in_channels, out_channels])
组参数
您误解了 groups
参数在 pytorch
中的作用。它限制了每个过滤器使用的通道数(在这种情况下只有一个,因为 2 input_channels
除以 2 给我们一个)。
有关 2D
卷积的更直观解释,请参阅 here。