进行一步预测
Making One Step Forecast Predictions
我正在制作一个 LSTM 模型,并在我在 kaggle 上找到的 TSLA 数据集上对其进行训练。所以我的问题是,当我打电话给 model.predict 时,这个预测会给我第二天的股票价格吗?这是一步预测吗?当我打印 model.predict 时,我得到了一个巨大的列表,所以我使用 numpy argmax 函数给我一个数字。这是代码:
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense, Dropout, Input, GlobalMaxPooling1D
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.optimizers import Adam
df = pd.read_csv('TSLA.csv')
series = df['Close'].values.reshape(-1, 1)
scaler = StandardScaler()
scaler.fit(series[:len(series)//2])
series = scaler.transform(series).flatten()
X = []
Y = []
T = 10
D = 1
for t in range(len(series) - T):
X.append(series[t:t+T])
Y.append(series[t+T])
X = np.array(X).reshape(-1, T, D)
Y = np.array(Y)
N = len(X)
print(X.shape, Y.shape)
model = tf.keras.Sequential([
Input(shape=(T, D)),
LSTM(50),
Dense(100, activation='relu'),
Dropout(0.25),
Dense(1)
])
model.compile(optimizer=Adam(lr=0.01), loss='mse')
r = model.fit(X[:-N//2], Y[:-N//2], validation_data=(X[-N//2:], Y[-N//2:]), epochs=200)
plt.plot(r.history['loss'])
plt.plot(r.history['val_loss'])
plt.show()
preds = model.predict(X)
outs = preds[:,0]
print(outs)
print(np.argmax(outs))
Argmax 在这里没有意义。这 90 个值是训练集第二天的 90 个预测值。当你 运行 这个:
preds = model.predict(X)
它会为您提供训练集所有 90 个数据点的第二天值。这一行:
print(np.argmax(outs))
没有意义。
顺便说一下,您可以使用 Python 获取股票价格,您不需要 CSV。
pip install pandas-datareader
from pandas_datareader import data as wb
ticker=wb.DataReader('TSLA',start='2015-1-1',data_source='yahoo')
print(ticker)
High Low ... Volume Adj Close
Date ...
2015-01-02 44.650002 42.652000 ... 23822000.0 43.862000
2015-01-05 43.299999 41.431999 ... 26842500.0 42.018002
2015-01-06 42.840000 40.841999 ... 31309500.0 42.256001
2015-01-07 42.956001 41.956001 ... 14842000.0 42.189999
2015-01-08 42.759998 42.001999 ... 17212500.0 42.124001
... ... ... ... ...
2020-10-13 448.890015 436.600006 ... 34463700.0 446.649994
2020-10-14 465.899994 447.350006 ... 48045400.0 461.299988
2020-10-15 456.570007 442.500000 ... 35672400.0 448.880005
2020-10-16 455.950012 438.850006 ... 32620000.0 439.670013
2020-10-19 447.000000 437.649994 ... 9422697.0 442.840607
我正在制作一个 LSTM 模型,并在我在 kaggle 上找到的 TSLA 数据集上对其进行训练。所以我的问题是,当我打电话给 model.predict 时,这个预测会给我第二天的股票价格吗?这是一步预测吗?当我打印 model.predict 时,我得到了一个巨大的列表,所以我使用 numpy argmax 函数给我一个数字。这是代码:
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense, Dropout, Input, GlobalMaxPooling1D
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.optimizers import Adam
df = pd.read_csv('TSLA.csv')
series = df['Close'].values.reshape(-1, 1)
scaler = StandardScaler()
scaler.fit(series[:len(series)//2])
series = scaler.transform(series).flatten()
X = []
Y = []
T = 10
D = 1
for t in range(len(series) - T):
X.append(series[t:t+T])
Y.append(series[t+T])
X = np.array(X).reshape(-1, T, D)
Y = np.array(Y)
N = len(X)
print(X.shape, Y.shape)
model = tf.keras.Sequential([
Input(shape=(T, D)),
LSTM(50),
Dense(100, activation='relu'),
Dropout(0.25),
Dense(1)
])
model.compile(optimizer=Adam(lr=0.01), loss='mse')
r = model.fit(X[:-N//2], Y[:-N//2], validation_data=(X[-N//2:], Y[-N//2:]), epochs=200)
plt.plot(r.history['loss'])
plt.plot(r.history['val_loss'])
plt.show()
preds = model.predict(X)
outs = preds[:,0]
print(outs)
print(np.argmax(outs))
Argmax 在这里没有意义。这 90 个值是训练集第二天的 90 个预测值。当你 运行 这个:
preds = model.predict(X)
它会为您提供训练集所有 90 个数据点的第二天值。这一行:
print(np.argmax(outs))
没有意义。
顺便说一下,您可以使用 Python 获取股票价格,您不需要 CSV。
pip install pandas-datareader
from pandas_datareader import data as wb
ticker=wb.DataReader('TSLA',start='2015-1-1',data_source='yahoo')
print(ticker)
High Low ... Volume Adj Close
Date ...
2015-01-02 44.650002 42.652000 ... 23822000.0 43.862000
2015-01-05 43.299999 41.431999 ... 26842500.0 42.018002
2015-01-06 42.840000 40.841999 ... 31309500.0 42.256001
2015-01-07 42.956001 41.956001 ... 14842000.0 42.189999
2015-01-08 42.759998 42.001999 ... 17212500.0 42.124001
... ... ... ... ...
2020-10-13 448.890015 436.600006 ... 34463700.0 446.649994
2020-10-14 465.899994 447.350006 ... 48045400.0 461.299988
2020-10-15 456.570007 442.500000 ... 35672400.0 448.880005
2020-10-16 455.950012 438.850006 ... 32620000.0 439.670013
2020-10-19 447.000000 437.649994 ... 9422697.0 442.840607