用 Pytorch 实现 FFT
Implementing FFT with Pytorch
我正在尝试使用 Pytorch 中提供的 conv1d
函数来实现 FFT。
正在生成人工信号
import numpy as np
import torch
from torch.autograd import Variable
from torch.nn.functional import conv1d
from scipy import fft, fftpack
import matplotlib.pyplot as plt
%matplotlib inline
# Creating filters
d = 4096 # size of windows
def create_filters(d):
x = np.arange(0, d, 1)
wsin = np.empty((d,1,d), dtype=np.float32)
wcos = np.empty((d,1,d), dtype=np.float32)
window_mask = 1.0-1.0*np.cos(x)
for ind in range(d):
wsin[ind,0,:] = np.sin(2*np.pi*((ind+1)/d)*x)
wcos[ind,0,:] = np.cos(2*np.pi*((ind+1)/d)*x)
return wsin,wcos
wsin, wcos = create_filters(d)
wsin_var = Variable(torch.from_numpy(wsin), requires_grad=False)
wcos_var = Variable(torch.from_numpy(wcos),requires_grad=False)
# Creating signal
t = np.linspace(0,1,4096)
x = np.sin(2*np.pi*100*t)+np.sin(2*np.pi*200*t)+np.random.normal(scale=5,size=(4096))
plt.plot(x)
使用 Pytorch 进行 FFT
signal_input = torch.from_numpy(x.reshape(1,-1),)[:,None,:4096]
signal_input = signal_input.float()
zx = conv1d(signal_input, wsin_var, stride=1).pow(2)+conv1d(signal_input, wcos_var, stride=1).pow(2)
FFT 与 Scipy
fig = plt.figure(figsize=(20,5))
plt.plot(np.abs(fft(x).reshape(-1))[:500])
我的问题
如您所见,两个输出在峰值特性方面非常相似。这意味着我的实施并非完全错误。
然而,也有一些微妙之处,例如频谱的尺度和信噪比。我无法弄清楚这里缺少什么以获得完全相同的结果。
你计算的是功率而不是振幅。
您只需添加行 zx = zx.pow(0.5)
即可取平方根以获得振幅。
从版本 1,8 开始,PyTorch 有一个本地实现 torch.fft
:
torch.fft.fft(x)
我正在尝试使用 Pytorch 中提供的 conv1d
函数来实现 FFT。
正在生成人工信号
import numpy as np
import torch
from torch.autograd import Variable
from torch.nn.functional import conv1d
from scipy import fft, fftpack
import matplotlib.pyplot as plt
%matplotlib inline
# Creating filters
d = 4096 # size of windows
def create_filters(d):
x = np.arange(0, d, 1)
wsin = np.empty((d,1,d), dtype=np.float32)
wcos = np.empty((d,1,d), dtype=np.float32)
window_mask = 1.0-1.0*np.cos(x)
for ind in range(d):
wsin[ind,0,:] = np.sin(2*np.pi*((ind+1)/d)*x)
wcos[ind,0,:] = np.cos(2*np.pi*((ind+1)/d)*x)
return wsin,wcos
wsin, wcos = create_filters(d)
wsin_var = Variable(torch.from_numpy(wsin), requires_grad=False)
wcos_var = Variable(torch.from_numpy(wcos),requires_grad=False)
# Creating signal
t = np.linspace(0,1,4096)
x = np.sin(2*np.pi*100*t)+np.sin(2*np.pi*200*t)+np.random.normal(scale=5,size=(4096))
plt.plot(x)
使用 Pytorch 进行 FFT
signal_input = torch.from_numpy(x.reshape(1,-1),)[:,None,:4096]
signal_input = signal_input.float()
zx = conv1d(signal_input, wsin_var, stride=1).pow(2)+conv1d(signal_input, wcos_var, stride=1).pow(2)
FFT 与 Scipy
fig = plt.figure(figsize=(20,5))
plt.plot(np.abs(fft(x).reshape(-1))[:500])
我的问题
如您所见,两个输出在峰值特性方面非常相似。这意味着我的实施并非完全错误。 然而,也有一些微妙之处,例如频谱的尺度和信噪比。我无法弄清楚这里缺少什么以获得完全相同的结果。
你计算的是功率而不是振幅。
您只需添加行 zx = zx.pow(0.5)
即可取平方根以获得振幅。
从版本 1,8 开始,PyTorch 有一个本地实现 torch.fft
:
torch.fft.fft(x)