How to manually obtain the same output as model.predict() in keras
I am trying to reproduce with NumPy the output I would get from Keras's model.predict(). My Keras model's layers are as follows:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
main_input (InputLayer) (None, 10, 76) 0
_________________________________________________________________
masking (Masking) (None, 10, 76) 0
_________________________________________________________________
rnn (SimpleRNN) [(None, 64), (None, 64)] 9024
_________________________________________________________________
dropout_15 (Dropout) (None, 64) 0
_________________________________________________________________
dense1 (Dense) (None, 64) 4160
_________________________________________________________________
denseoutput (Dense) (None, 1) 65
=================================================================
Total params: 13,249
Trainable params: 13,249
Non-trainable params: 0
The second output of the SimpleRNN layer is the state returned by return_state=True.
I tried two different approaches. First, I computed W·Xt + U·s + b, where W is the kernel, Xt is the input, U is the recurrent kernel, s is the state obtained via return_state=True, and b is the bias. This returned an output similar to the one obtained with predict() (function mult_1). After that, I tried a similar approach with function mult_2, but got worse results than with mult_1.
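For reference, the recurrence a Keras SimpleRNN applies at each timestep t (with its default tanh activation) is h_t = tanh(x_t·W + h_{t-1}·U + b), where W is the kernel (76×64 here), U the recurrent kernel (64×64) and b the bias (64); that is exactly the 9,024 parameters reported in the summary above (76·64 + 64·64 + 64). A minimal NumPy sketch of a single step (simple_rnn_step is a hypothetical helper, not from the code below):

    import numpy as np

    def simple_rnn_step(x_t, h_prev, W, U, b):
        # One SimpleRNN step: h_t = tanh(x_t @ W + h_prev @ U + b)
        return np.tanh(np.dot(x_t, W) + np.dot(h_prev, U) + b)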
import numpy as np
import numpy.ma as ma

def mult_1(X):
    X = ma.masked_values(X, -99)
    s = model.predict(X)[1]       # state returned via return_state=True
    W = model.get_weights()[0]    # kernel
    U = model.get_weights()[1]    # recurrent kernel
    b = model.get_weights()[2]    # bias
    Wx = np.dot(X[:, -1, :], W)   # last timestep only
    Us = np.dot(s, U)
    output = Wx + Us + b
    return np.tanh(output)
def mult_2(X):
    max_habitantes = X.shape[1]      # number of timesteps (10)
    s_0 = np.ones((X.shape[0], 64))  # initial state
    X = ma.masked_values(X, -99)
    W = model.get_weights()[0]       # kernel
    U = model.get_weights()[1]       # recurrent kernel
    b = model.get_weights()[2]       # bias
    i = 0
    while i < max_habitantes:
        Xt = X[:, i, :]
        if i == 0:
            s = s_0
        else:
            s = output
        Wx = np.dot(Xt, W)
        Us = np.dot(s, U)
        output = np.tanh(Wx + Us + b)
        i = i + 1
    return output
The predictions are somewhat off, although not wildly different from those of predict(). Am I doing one of the multiplications wrong?
You should use an array of zeros as the initial state of the RNN in mult_2: Keras RNN layers initialize their hidden state to zeros by default.
The following two pieces of code will give you the same result:
import numpy as np

x = np.random.rand(1, 10, 76)
Using Keras model.predict():
from keras.layers import Input, SimpleRNN, Dropout, Dense
from keras.models import Model

inputs = Input(shape=(10, 76), dtype=np.float32)
_, state = SimpleRNN(units=64, return_state=True)(inputs)
out_drop = Dropout(0.2)(state)
out_d1 = Dense(64, activation='tanh')(out_drop)
out = Dense(1, activation='tanh')(out_d1)
model = Model(inputs, out)
In [1]: model.predict(x)
Out[1]: array([[-0.82426485]])
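A quick note on why the NumPy reproduction below can ignore the Dropout layer entirely: dropout is only active during training, so at inference time (which is what predict() does) it is an identity map. A minimal sanity check, assuming a TF2-style model object:

    # With training=False (the predict() behaviour) repeated calls agree;
    # training=True would resample the dropout mask on every call.
    y1 = model(x, training=False).numpy()
    y2 = model(x, training=False).numpy()
    assert (y1 == y2).all()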
Predicting with NumPy functions:
def rnn_pred(X):
    """
    Same as your mult_2, but with zero initialization for the RNN state.
    """
    W = model.get_weights()[0]       # kernel
    U = model.get_weights()[1]       # recurrent kernel
    b = model.get_weights()[2]       # bias
    max_habitantes = X.shape[1]
    s_0 = np.zeros((X.shape[0], 64)) # initial state
    i = 0
    while i < max_habitantes:
        Xt = X[:, i, :]
        if i == 0:
            s = s_0
        else:
            s = output
        Wx = np.dot(Xt, W)
        Us = np.dot(s, U)
        output = np.tanh(Wx + Us + b)
        i = i + 1
    return output
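One caveat: this test model has no Masking layer, while the model in the question masks the value -99. To reproduce the masked model, the loop above would also need to carry the previous state through masked timesteps, roughly like this (a sketch, assuming Keras's convention that a timestep is masked when all of its features equal the mask value):

    # Inside the timestep loop, instead of unconditionally updating the state:
    masked = (Xt == -99).all(axis=1)              # per sample: is step i fully masked?
    new_s = np.tanh(np.dot(Xt, W) + np.dot(s, U) + b)
    output = np.where(masked[:, None], s, new_s)  # masked samples keep the old state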
def dense_pred(rnn_out):
    U_d1 = model.get_weights()[3]  # Dense(64) weights
    b_d1 = model.get_weights()[4]  # Dense(64) bias
    U_d2 = model.get_weights()[5]  # Dense(1) weights
    b_d2 = model.get_weights()[6]  # Dense(1) bias
    out1 = np.dot(rnn_out, U_d1) + b_d1
    out1 = np.tanh(out1)
    out2 = np.dot(out1, U_d2) + b_d2
    out2 = np.tanh(out2)
    return out2
In [2]: dense_pred(rnn_pred(x))
Out[2]: array([[-0.82426485]])
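As a side note, indexing model.get_weights() by position is brittle if the architecture changes. An equivalent, more readable approach is to fetch each layer's weights by name via the standard get_layer API (layer names taken from the summary in the question):

    W, U, b = model.get_layer('rnn').get_weights()             # kernel, recurrent kernel, bias
    U_d1, b_d1 = model.get_layer('dense1').get_weights()       # Dense(64)
    U_d2, b_d2 = model.get_layer('denseoutput').get_weights()  # Dense(1)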