scipy.interpolate.RegularGridInterpolator 的正确用法
Correct usage of scipy.interpolate.RegularGridInterpolator
我对 documentation for scipy.interpolate.RegularGridInterpolator 有点困惑。
比如说我有一个函数 f: R^3 => R 是在单位立方体的顶点上采样的。我想插值以便在立方体中找到值。
import numpy as np
# Grid points / sample locations
X = np.array([[0,0,0], [0,0,1], [0,1,0], [0,1,1], [1,0,0], [1,0,1], [1,1,0], [1,1,1.]])
# Function values at the grid points
F = np.random.rand(8)
现在,RegularGridInterpolator
接受一个 points
个参数,一个 values
个参数。
points : tuple of ndarray of float, with shapes (m1, ), ..., (mn, )
The points defining the regular grid in n dimensions.
values : array_like, shape (m1, ..., mn, ...)
The data on the regular grid in n dimensions.
我将其解释为能够这样调用:
import scipy.interpolate as irp
rgi = irp.RegularGridInterpolator(X, F)
但是,当我这样做时,出现以下错误:
ValueError: There are 8 point arrays, but values has 1 dimensions
我在文档中误解了什么?
好吧,当我回答自己的问题时,我觉得自己很傻,但我在原始 regulargrid
库的文档的帮助下发现了我的错误:
https://github.com/JohannesBuchner/regulargrid
points
应该是一个数组列表,指定点沿每个轴的间隔方式。
比如上面的单位立方体,我应该设置:
pts = ( np.array([0,1.]), )*3
或者如果我有沿最后一个轴以更高分辨率采样的数据,我可能会设置:
pts = ( np.array([0,1.]), np.array([0,1.]), np.array([0,0.5,1.]) )
最后,values
的形状必须与 points
隐式布局的网格相对应。例如,
val_size = map(lambda q: q.shape[0], pts)
vals = np.zeros( val_size )
# make an arbitrary function to test:
func = lambda pt: (pt**2).sum()
# collect func's values at grid pts
for i in range(pts[0].shape[0]):
for j in range(pts[1].shape[0]):
for k in range(pts[2].shape[0]):
vals[i,j,k] = func(np.array([pts[0][i], pts[1][j], pts[2][k]]))
最后,
rgi = irp.RegularGridInterpolator(points=pts, values=vals)
根据需要运行和执行。
你的回答比较好,完全可以采纳。我只是将其添加为 "alternate" 编写脚本的方式。
import numpy as np
import scipy.interpolate as spint
RGI = spint.RegularGridInterpolator
x = np.linspace(0, 1, 3) # or 0.5*np.arange(3.) works too
# populate the 3D array of values (re-using x because lazy)
X, Y, Z = np.meshgrid(x, x, x, indexing='ij')
vals = np.sin(X) + np.cos(Y) + np.tan(Z)
# make the interpolator, (list of 1D axes, values at all points)
rgi = RGI(points=[x, x, x], values=vals) # can also be [x]*3 or (x,)*3
tst = (0.47, 0.49, 0.53)
print rgi(tst)
print np.sin(tst[0]) + np.cos(tst[1]) + np.tan(tst[2])
returns:
1.93765972087
1.92113615659
我对 documentation for scipy.interpolate.RegularGridInterpolator 有点困惑。
比如说我有一个函数 f: R^3 => R 是在单位立方体的顶点上采样的。我想插值以便在立方体中找到值。
import numpy as np
# Grid points / sample locations
X = np.array([[0,0,0], [0,0,1], [0,1,0], [0,1,1], [1,0,0], [1,0,1], [1,1,0], [1,1,1.]])
# Function values at the grid points
F = np.random.rand(8)
现在,RegularGridInterpolator
接受一个 points
个参数,一个 values
个参数。
points : tuple of ndarray of float, with shapes (m1, ), ..., (mn, ) The points defining the regular grid in n dimensions.
values : array_like, shape (m1, ..., mn, ...) The data on the regular grid in n dimensions.
我将其解释为能够这样调用:
import scipy.interpolate as irp
rgi = irp.RegularGridInterpolator(X, F)
但是,当我这样做时,出现以下错误:
ValueError: There are 8 point arrays, but values has 1 dimensions
我在文档中误解了什么?
好吧,当我回答自己的问题时,我觉得自己很傻,但我在原始 regulargrid
库的文档的帮助下发现了我的错误:
https://github.com/JohannesBuchner/regulargrid
points
应该是一个数组列表,指定点沿每个轴的间隔方式。
比如上面的单位立方体,我应该设置:
pts = ( np.array([0,1.]), )*3
或者如果我有沿最后一个轴以更高分辨率采样的数据,我可能会设置:
pts = ( np.array([0,1.]), np.array([0,1.]), np.array([0,0.5,1.]) )
最后,values
的形状必须与 points
隐式布局的网格相对应。例如,
val_size = map(lambda q: q.shape[0], pts)
vals = np.zeros( val_size )
# make an arbitrary function to test:
func = lambda pt: (pt**2).sum()
# collect func's values at grid pts
for i in range(pts[0].shape[0]):
for j in range(pts[1].shape[0]):
for k in range(pts[2].shape[0]):
vals[i,j,k] = func(np.array([pts[0][i], pts[1][j], pts[2][k]]))
最后,
rgi = irp.RegularGridInterpolator(points=pts, values=vals)
根据需要运行和执行。
你的回答比较好,完全可以采纳。我只是将其添加为 "alternate" 编写脚本的方式。
import numpy as np
import scipy.interpolate as spint
RGI = spint.RegularGridInterpolator
x = np.linspace(0, 1, 3) # or 0.5*np.arange(3.) works too
# populate the 3D array of values (re-using x because lazy)
X, Y, Z = np.meshgrid(x, x, x, indexing='ij')
vals = np.sin(X) + np.cos(Y) + np.tan(Z)
# make the interpolator, (list of 1D axes, values at all points)
rgi = RGI(points=[x, x, x], values=vals) # can also be [x]*3 or (x,)*3
tst = (0.47, 0.49, 0.53)
print rgi(tst)
print np.sin(tst[0]) + np.cos(tst[1]) + np.tan(tst[2])
returns:
1.93765972087
1.92113615659