如何修改pytorch中的rnn单元?

how to modify rnn cells in pytorch?

如果我想更改 RNN 单元(例如 GRU 单元)中的计算规则,我应该怎么做?
考虑到效率问题,我不想通过for或while循环来实现。
我查看了 pytorch 的源代码,但似乎 rnn 单元的主要组件是用我无法找到和修改的 c 代码实现的。 你可以通过一个例子来回答这个问题:implement GRU cell without the existing version

谢谢~

是的,你实现了它"via for or while loop"。 由于 Pytorch 1.0 有 JIT https://pytorch.org/docs/stable/jit.html 工作得很好(由于最近对 JIT 的改进,使用最新的 git 版本的 PyTorch 可能更好),并且取决于您的网络和实现速度原生 PyTorch C++ 实现(但仍然比 CuDNN 慢)。

您可以在 https://github.com/pytorch/benchmark/blob/master/rnns/fastrnns/custom_lstms.py

查看示例实现