N维GP回归
N-dimensional GP Regression
我正在尝试使用 GPflow 进行多维回归。但我对均值和方差的形状感到困惑。
例如:假设预测形状为 (20,20) 的二维输入 space X。我的训练样本的形状为 (8,2),这意味着两个维度总共有 8 个训练样本。 y 值的形状为 (8,1),这当然意味着 2 个输入维度的每个组合都有一个真实值。
如果我现在使用 model.predict_y(X),我希望得到一个平均形状 (20,20) 但获得一个 (20,1) 的形状。方差也是如此。我认为这个问题来自于 y 值的形状,但我不知道如何解决它。
bound = 3
num = 20
X = np.random.uniform(-bound, bound, (num,num))
print(X_sample.shape) # (8,2)
print(Y_sample.shape) # (8,1)
k = gpflow.kernels.RBF(input_dim=2)
m = gpflow.models.GPR(X_sample, Y_sample, kern=k)
m.likelihood.variance = sigma_n
m.compile()
gpflow.train.ScipyOptimizer().minimize(m)
mean, var = m.predict_y(X)
print(mean.shape) # (20, 1)
print(var.shape) # (20, 1)
听起来你可能对输入位置网格的形状和 numpy 数组的形状感到困惑:如果你想在二维的 20 x 20 网格上进行预测,你总共有 400 个点,每个都有 2 个值。所以 X(您传递给 m.predict_y()
的那个)应该具有形状 (400, 2)。 (请注意,第二个维度需要与 X_sample 具有相同的形状!)
要构建形状为 (400,2) 的数组,您可以使用 np.meshgrid
(例如,参见 )。
m.predict_y(X)
仅预测每个测试点的 marginal 方差,因此 returned mean
和 var
两者具有形状 (400,1)(与 X 的长度相同)。您当然可以将它们重塑为网格上的 20 x 20 值。
(也可以计算完整的协方差,对于潜在的 f 这被实现为 m.predict_f_full_cov
,对于形状为 (400,2) 的 X 将 return 一个 400x400 矩阵。如果你想从全科医生那里得到一致的样本,这是相关的,但我怀疑这远远超出了这个问题。)
我确实犯了一个错误,没有展平 return 中产生错误的数组。感谢 STJ 的快速回复!
下面是工作代码的示例:
# Generate data
bound = 3.
x1 = np.linspace(-bound, bound, num)
x2 = np.linspace(-bound, bound, num)
x1_mesh,x2_mesh = np.meshgrid(x1, x2)
X = np.dstack([x1_mesh, x2_mesh]).reshape(-1, 2)
z = f(x1_mesh, x2_mesh) # evaluation of the function on the grid
# Draw samples from feature vectors and function by a given index
size = 2
np.random.seed(1991)
index = np.random.choice(range(len(x1)), size=(size,X.ndim), replace=False)
samples = utils.sampleFeature([x1,x2], index)
X1_sample = samples[0]
X2_sample = samples[1]
X_sample = np.column_stack((X1_sample, X2_sample))
Y_sample = utils.samplefromFunc(f=z, ind=index)
# Change noise parameter
sigma_n = 0.0
# Construct models with initial guess
k = gpflow.kernels.RBF(2,active_dims=[0,1], lengthscales=1.0,ARD=True)
m = gpflow.models.GPR(X_sample, Y_sample, kern=k)
m.likelihood.variance = sigma_n
m.compile()
#print(X.shape)
mean, var = m.predict_y(X)
mean_square = mean.reshape(x1_mesh.shape) # Shape: (num,num)
var_square = var.reshape(x1_mesh.shape) # Shape: (num,num)
# Plot mean
fig = plt.figure(figsize=(16, 12))
ax = plt.axes(projection='3d')
ax.plot_surface(x1_mesh, x2_mesh, mean_square, cmap=cm.viridis, linewidth=0.5, antialiased=True, alpha=0.8)
cbar = ax.contourf(x1_mesh, x2_mesh, mean_square, zdir='z', offset=offset, cmap=cm.viridis, antialiased=True)
ax.scatter3D(X1_sample, X2_sample, offset, marker='o',edgecolors='k', color='r', s=150)
fig.colorbar(cbar)
for t in ax.zaxis.get_major_ticks(): t.label.set_fontsize(fontsize_ticks)
ax.set_title("$\mu(x_1,x_2)$", fontsize=fontsize_title)
ax.set_xlabel("\n$x_1$", fontsize=fontsize_label)
ax.set_ylabel("\n$x_2$", fontsize=fontsize_label)
ax.set_zlabel('\n\n$\mu(x_1,x_2)$', fontsize=fontsize_label)
plt.xticks(fontsize=fontsize_ticks)
plt.yticks(fontsize=fontsize_ticks)
plt.xlim(left=-bound, right=bound)
plt.ylim(bottom=-bound, top=bound)
ax.set_zlim3d(offset,np.max(z))
这导致(红点是从函数中提取的样本点)。注意:代码从未重构过什么:)
我正在尝试使用 GPflow 进行多维回归。但我对均值和方差的形状感到困惑。 例如:假设预测形状为 (20,20) 的二维输入 space X。我的训练样本的形状为 (8,2),这意味着两个维度总共有 8 个训练样本。 y 值的形状为 (8,1),这当然意味着 2 个输入维度的每个组合都有一个真实值。 如果我现在使用 model.predict_y(X),我希望得到一个平均形状 (20,20) 但获得一个 (20,1) 的形状。方差也是如此。我认为这个问题来自于 y 值的形状,但我不知道如何解决它。
bound = 3
num = 20
X = np.random.uniform(-bound, bound, (num,num))
print(X_sample.shape) # (8,2)
print(Y_sample.shape) # (8,1)
k = gpflow.kernels.RBF(input_dim=2)
m = gpflow.models.GPR(X_sample, Y_sample, kern=k)
m.likelihood.variance = sigma_n
m.compile()
gpflow.train.ScipyOptimizer().minimize(m)
mean, var = m.predict_y(X)
print(mean.shape) # (20, 1)
print(var.shape) # (20, 1)
听起来你可能对输入位置网格的形状和 numpy 数组的形状感到困惑:如果你想在二维的 20 x 20 网格上进行预测,你总共有 400 个点,每个都有 2 个值。所以 X(您传递给 m.predict_y()
的那个)应该具有形状 (400, 2)。 (请注意,第二个维度需要与 X_sample 具有相同的形状!)
要构建形状为 (400,2) 的数组,您可以使用 np.meshgrid
(例如,参见
m.predict_y(X)
仅预测每个测试点的 marginal 方差,因此 returned mean
和 var
两者具有形状 (400,1)(与 X 的长度相同)。您当然可以将它们重塑为网格上的 20 x 20 值。
(也可以计算完整的协方差,对于潜在的 f 这被实现为 m.predict_f_full_cov
,对于形状为 (400,2) 的 X 将 return 一个 400x400 矩阵。如果你想从全科医生那里得到一致的样本,这是相关的,但我怀疑这远远超出了这个问题。)
我确实犯了一个错误,没有展平 return 中产生错误的数组。感谢 STJ 的快速回复!
下面是工作代码的示例:
# Generate data
bound = 3.
x1 = np.linspace(-bound, bound, num)
x2 = np.linspace(-bound, bound, num)
x1_mesh,x2_mesh = np.meshgrid(x1, x2)
X = np.dstack([x1_mesh, x2_mesh]).reshape(-1, 2)
z = f(x1_mesh, x2_mesh) # evaluation of the function on the grid
# Draw samples from feature vectors and function by a given index
size = 2
np.random.seed(1991)
index = np.random.choice(range(len(x1)), size=(size,X.ndim), replace=False)
samples = utils.sampleFeature([x1,x2], index)
X1_sample = samples[0]
X2_sample = samples[1]
X_sample = np.column_stack((X1_sample, X2_sample))
Y_sample = utils.samplefromFunc(f=z, ind=index)
# Change noise parameter
sigma_n = 0.0
# Construct models with initial guess
k = gpflow.kernels.RBF(2,active_dims=[0,1], lengthscales=1.0,ARD=True)
m = gpflow.models.GPR(X_sample, Y_sample, kern=k)
m.likelihood.variance = sigma_n
m.compile()
#print(X.shape)
mean, var = m.predict_y(X)
mean_square = mean.reshape(x1_mesh.shape) # Shape: (num,num)
var_square = var.reshape(x1_mesh.shape) # Shape: (num,num)
# Plot mean
fig = plt.figure(figsize=(16, 12))
ax = plt.axes(projection='3d')
ax.plot_surface(x1_mesh, x2_mesh, mean_square, cmap=cm.viridis, linewidth=0.5, antialiased=True, alpha=0.8)
cbar = ax.contourf(x1_mesh, x2_mesh, mean_square, zdir='z', offset=offset, cmap=cm.viridis, antialiased=True)
ax.scatter3D(X1_sample, X2_sample, offset, marker='o',edgecolors='k', color='r', s=150)
fig.colorbar(cbar)
for t in ax.zaxis.get_major_ticks(): t.label.set_fontsize(fontsize_ticks)
ax.set_title("$\mu(x_1,x_2)$", fontsize=fontsize_title)
ax.set_xlabel("\n$x_1$", fontsize=fontsize_label)
ax.set_ylabel("\n$x_2$", fontsize=fontsize_label)
ax.set_zlabel('\n\n$\mu(x_1,x_2)$', fontsize=fontsize_label)
plt.xticks(fontsize=fontsize_ticks)
plt.yticks(fontsize=fontsize_ticks)
plt.xlim(left=-bound, right=bound)
plt.ylim(bottom=-bound, top=bound)
ax.set_zlim3d(offset,np.max(z))
这导致(红点是从函数中提取的样本点)。注意:代码从未重构过什么:)