用 Pytorch 实现 FFT

Implementing FFT with Pytorch

我正在尝试使用 Pytorch 中提供的 conv1d 函数来实现 FFT。

正在生成人工信号

import numpy as np
import torch
from torch.autograd import Variable
from torch.nn.functional import conv1d

from scipy import fft, fftpack

import matplotlib.pyplot as plt

%matplotlib inline

# Creating filters

d = 4096 # size of windows

def create_filters(d):
    x = np.arange(0, d, 1)
    wsin = np.empty((d,1,d), dtype=np.float32)
    wcos = np.empty((d,1,d), dtype=np.float32)
    window_mask = 1.0-1.0*np.cos(x)
    for ind in range(d):
        wsin[ind,0,:] = np.sin(2*np.pi*((ind+1)/d)*x)
        wcos[ind,0,:] = np.cos(2*np.pi*((ind+1)/d)*x)

    return wsin,wcos

wsin, wcos = create_filters(d)
wsin_var = Variable(torch.from_numpy(wsin), requires_grad=False)
wcos_var = Variable(torch.from_numpy(wcos),requires_grad=False)

# Creating signal

t = np.linspace(0,1,4096)
x = np.sin(2*np.pi*100*t)+np.sin(2*np.pi*200*t)+np.random.normal(scale=5,size=(4096))

plt.plot(x) 

使用 Pytorch 进行 FFT

signal_input = torch.from_numpy(x.reshape(1,-1),)[:,None,:4096]

signal_input = signal_input.float()

zx = conv1d(signal_input, wsin_var, stride=1).pow(2)+conv1d(signal_input, wcos_var, stride=1).pow(2)

FFT 与 Scipy

fig = plt.figure(figsize=(20,5))
plt.plot(np.abs(fft(x).reshape(-1))[:500])

我的问题

如您所见,两个输出在峰值特性方面非常相似。这意味着我的实施并非完全错误。 然而,也有一些微妙之处,例如频谱的尺度和信噪比。我无法弄清楚这里缺少什么以获得完全相同的结果。

你计算的是功率而不是振幅。 您只需添加行 zx = zx.pow(0.5) 即可取平方根以获得振幅。

从版本 1,8 开始,PyTorch 有一个本地实现 torch.fft:

torch.fft.fft(x)