TensorFlow 中 pytorch NN.module 的别名是什么?
What is the alias to pytorch NN.module in TensorFlow?
我正在尝试在 TensorFlow 中实现三元组注意力。我面临的问题之一是在 TensorFlow
中用什么代替 NN.module
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)
我在 nn.Module
的位置放什么?
在本例中,nn.Module
用于创建自定义图层。 TensorFlow 有这方面的教程,请 take a look。简而言之,实现它的一种方法是使用 tf.keras.layers.Layer
,其中 call
相当于 PyTorch 中的 forward
:
class ChannelPool(tf.keras.layers.Layer):
def call(self, inputs):
return tf.concat((tf.reduce_max(inputs, axis=1, keepdims=True), tf.reduce_mean(inputs, axis=1, keepdims=True)), axis=1)
你可以这样检查它们是否等价:
import torch
from torch import nn
import tensorflow as tf
import numpy as np
class PyTorch_ChannelPool(nn.Module):
def forward(self, x):
return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
class TensorFlow_ChannelPool(tf.keras.layers.Layer):
def call(self, inputs):
return tf.concat((tf.reduce_max(inputs, axis=1, keepdims=True), tf.reduce_mean(inputs, axis=1, keepdims=True)), axis=1)
np.random.seed(2021)
x = np.random.random((1,2,3,4)).astype(np.float32)
a = PyTorch_ChannelPool()
b = TensorFlow_ChannelPool()
pytorch_output = a(torch.from_numpy(x)).numpy()
tensorflow_output = b(x).numpy()
np.all(pytorch_output == tensorflow_output)
# >>> True
我正在尝试在 TensorFlow 中实现三元组注意力。我面临的问题之一是在 TensorFlow
中用什么代替NN.module
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)
我在 nn.Module
的位置放什么?
在本例中,nn.Module
用于创建自定义图层。 TensorFlow 有这方面的教程,请 take a look。简而言之,实现它的一种方法是使用 tf.keras.layers.Layer
,其中 call
相当于 PyTorch 中的 forward
:
class ChannelPool(tf.keras.layers.Layer):
def call(self, inputs):
return tf.concat((tf.reduce_max(inputs, axis=1, keepdims=True), tf.reduce_mean(inputs, axis=1, keepdims=True)), axis=1)
你可以这样检查它们是否等价:
import torch
from torch import nn
import tensorflow as tf
import numpy as np
class PyTorch_ChannelPool(nn.Module):
def forward(self, x):
return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
class TensorFlow_ChannelPool(tf.keras.layers.Layer):
def call(self, inputs):
return tf.concat((tf.reduce_max(inputs, axis=1, keepdims=True), tf.reduce_mean(inputs, axis=1, keepdims=True)), axis=1)
np.random.seed(2021)
x = np.random.random((1,2,3,4)).astype(np.float32)
a = PyTorch_ChannelPool()
b = TensorFlow_ChannelPool()
pytorch_output = a(torch.from_numpy(x)).numpy()
tensorflow_output = b(x).numpy()
np.all(pytorch_output == tensorflow_output)
# >>> True