How to translate TF Dense layer to PyTorch?

I was wondering if someone could help me understand how to translate a short TF model into Torch.

Consider this TF setup:

import tensorflow as tf
from tensorflow.keras import layers, Model

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)
start = tf.squeeze(start, axis = -1)
end = tf.squeeze(end, axis = -1)
model = Model(inputs = inp, outputs = [start, end])

Specifically, I'm not sure which Torch command would turn my data from (386, 1024, 1) into (386, 1024, 2), and I also don't understand what this is doing: Model(inputs = inp, outputs = [start, end])

Is:

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]

equivalent to:

import torch

X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # torch.Size([386, 1024, 2])

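Translating a model from TF to Torch is mostly straightforward: you can usually find the Torch counterpart of a TF function in the PyTorch documentation. Below is an example that converts your TF code, starting from the TF version: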

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

inp = layers.Input(shape=(386, 1024, 1), dtype=tf.float32)
x = layers.Dense(2)(inp)              # [None, 386, 1024, 2]
start, end = tf.split(x, 2, axis=-1)  # 2 pieces, each [None, 386, 1024, 1]
start = tf.squeeze(start, axis=-1)    # [None, 386, 1024]
end = tf.squeeze(end, axis=-1)        # [None, 386, 1024]
model = models.Model(inputs=inp, outputs=[start, end])

X = np.random.randn(3, 386, 1024, 1).astype(np.float32)  # match the Input dtype
output = model(X)
print(output[0].shape, output[1].shape)

# Outputs: (3, 386, 1024) (3, 386, 1024)
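To see why the shape goes from (386, 1024, 1) to (386, 1024, 2): with its default linear activation, a Keras Dense layer computes x @ kernel + bias along the last axis only, where kernel has shape (input_dim, units). The NumPy sketch below mimics the model above under that assumption; kernel and bias are random stand-ins for the learned weights, and dense_model is a hypothetical helper, not a Keras API.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for the weights Dense(2) would learn: kernel is (input_dim, units) = (1, 2)
kernel = rng.standard_normal((1, 2)).astype(np.float32)
bias = rng.standard_normal(2).astype(np.float32)

def dense_model(x):
    """Hypothetical helper mimicking Dense(2) -> tf.split(.., 2) -> tf.squeeze."""
    y = x @ kernel + bias                 # matmul acts on the last axis: (..., 1) -> (..., 2)
    start, end = np.split(y, 2, axis=-1)  # 2 equal pieces, each (..., 1)
    return np.squeeze(start, axis=-1), np.squeeze(end, axis=-1)

X = rng.standard_normal((3, 386, 1024, 1)).astype(np.float32)
start, end = dense_model(X)
print(start.shape, end.shape)  # (3, 386, 1024) (3, 386, 1024)
```

Because the input's last dimension is 1, each output channel j is simply X[..., 0] * kernel[0, j] + bias[j].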

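The equivalent Torch code is as follows. What Model(inputs=inp, outputs=[start, end]) does in Keras (tying an input to a list of outputs) is done in PyTorch by an nn.Module subclass whose forward method returns [start, end]: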

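import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 2)  # counterpart of layers.Dense(2) for an input whose last dim is 1

    def forward(self, x):
        x = self.fc(x)                          # [N, 386, 1024, 2]
        start, end = torch.split(x, 1, dim=-1)  # split size 1 -> 2 tensors of [N, 386, 1024, 1]
        start = torch.squeeze(start, dim=-1)    # [N, 386, 1024]
        end = torch.squeeze(end, dim=-1)        # [N, 386, 1024]
        return [start, end]                     # same outputs as Model(..., outputs=[start, end])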

net = Net()

X = torch.randn(3, 386, 1024, 1)
output = net(X)
print(output[0].size(), output[1].size())

# Outputs: torch.Size([3, 386, 1024]) torch.Size([3, 386, 1024])
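One detail in this translation: the second argument means different things in the two libraries. In tf.split an integer is the number of pieces, while in torch.split it is the size of each piece, so torch.split(x, 2, dim=-1) would return a single piece here and the unpacking into start, end would fail. torch.chunk takes the number of pieces and is the more literal counterpart. A quick check on a small tensor (the shape is chosen only for illustration):

```python
import torch

x = torch.arange(12.0).reshape(2, 3, 2)  # last dim has size 2, like the Dense(2) output

# torch.split: second argument is the size of each chunk
s_start, s_end = torch.split(x, 1, dim=-1)
# torch.chunk: second argument is the number of chunks (like tf.split(x, 2, axis=-1))
c_start, c_end = torch.chunk(x, 2, dim=-1)

print(s_start.shape, c_start.shape)  # torch.Size([2, 3, 1]) torch.Size([2, 3, 1])
print(torch.equal(s_start, c_start) and torch.equal(s_end, c_end))  # True

# unbind splits along a dim and drops it, i.e. split + squeeze in one call
u_start, u_end = x.unbind(dim=-1)
print(torch.equal(u_start, s_start.squeeze(-1)))  # True
```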

And the following TF code:

inp = layers.Input(shape = (386, 1024, 1), dtype = tf.float32)
x = layers.Dense(2)(inp)  # [None, 386, 1024, 2]

is not equivalent to the following Torch code:

X = torch.randn(386, 1024, 1)
X = X.expand(386, 1024, 2)
X.shape  # torch.Size([386, 1024, 2])

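because layers.Dense in TF corresponds to nn.Linear in Torch, which applies a learned weight and bias to the last dimension. expand only repeats the existing values along that dimension and has no learnable parameters.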
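To make the difference concrete: nn.Linear(1, 2) computes x @ weight.T + bias on the last dimension, with weight of shape (out_features, in_features) = (2, 1), whereas Keras stores the Dense kernel as (input_dim, units), i.e. transposed. The sketch below assumes made-up "trained" Dense values (keras_kernel, keras_bias are hypothetical, not read from a real model), copies them into nn.Linear, and compares the result with expand.

```python
import numpy as np
import torch
from torch import nn

# Hypothetical trained Dense(2) weights in Keras layout: kernel (input_dim, units), bias (units,)
keras_kernel = np.array([[0.5, -2.0]], dtype=np.float32)
keras_bias = np.array([0.1, 0.3], dtype=np.float32)

fc = nn.Linear(1, 2)
with torch.no_grad():
    # nn.Linear.weight is (out_features, in_features), so transpose the Keras kernel
    fc.weight.copy_(torch.from_numpy(keras_kernel).T)
    fc.bias.copy_(torch.from_numpy(keras_bias))

X = torch.randn(386, 1024, 1)
linear_out = fc(X)                   # learned affine map applied to the last dim
expand_out = X.expand(386, 1024, 2)  # a view that repeats X; no weights involved

print(linear_out.shape, expand_out.shape)  # torch.Size([386, 1024, 2]) torch.Size([386, 1024, 2])

# Linear: channel j is X * kernel[0, j] + bias[j]; expand: both channels equal X
print(torch.allclose(linear_out[..., 1], X[..., 0] * -2.0 + 0.3, atol=1e-6))  # True
print(torch.equal(expand_out[..., 0], expand_out[..., 1]))                    # True
```

Both produce a (386, 1024, 2) tensor, but only the Linear output depends on parameters that training can change.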