如何将 Pandas DataFrame 加载到 LSTM 模型中?
Posted
技术标签:
【中文标题】如何将 Pandas DataFrame 加载到 LSTM 模型中?【英文标题】:How can I load a Pandas DataFrame into a LSTM model? 【发布时间】:2020-12-30 12:16:19 【问题描述】:我只是在玩 RNN,无法将我的数据转换为适合我的模型的正确格式。我有以下数据框:
Apple Pears Oranges ID
0 1.00 2.09 4.11 0
1 1.38 1.73 5.13 1
2 1.68 2.28 6.91 2
3 1.50 2.69 8.93 3
4 1.35 2.63 12.25 4
5 1.52 3.09 12.20 5
6 1.63 3.63 13.68 6
7 2.01 4.92 16.21 7
8 2.52 4.01 18.79 8
9 3.10 5.49 24.05 9
ID
是我的数据的顺序/时间步长。
我运行了这个命令来尝试将它加载到时间序列数据集中:
Dataset = keras.preprocessing.timeseries_dataset_from_array(priceHistorydf, basketHistorydf, sequence_length=10)
但是当我尝试在此基础上训练模型时,出现以下错误:
from tensorflow import keras
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import *
X_train = priceHistorydf
y_train = basketHistorydf
model = Sequential()
model.add(TimeDistributed(Dense(10), input_shape=(X_train.shape[1:])))
model.add(Bidirectional(LSTM(8)))
model.add(Dense(8, activation='tanh'))
model.add(Dense(8, activation='tanh'))
model.add(Dense(y_train.shape[-1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer="adam")
# history = model.fit(X_train, y_train, epochs=2, batch_size=8)
history = model.fit(Dataset, epochs=2, batch_size=8)
错误:
ValueError: `TimeDistributed` Layer should be passed an `input_shape ` with at least 3 dimensions, received: [None, 4]
我只是在猜测,但我意识到我没有明确让模型知道ID
是时间步长;但我不确定如何使用我的数据框将其传递给模型。
有什么建议吗?
【问题讨论】:
【参考方案1】:主要问题是您错误地设置了input_shape
参数(即X_train
是原始数据,而不是生成的时间序列;因此X_train.shape[1:]
作为输入形状不正确)。由于您使用了sequence_lenght=10
,并且每个时间步都有 3 个特征,因此我们应该有 input_shape=(10,3)
(当然,假设您首先从数据中删除 ID
列,因为这不是一个特征)。
附带说明:Dense(...)
和TimeDistributed(Dense(...))
完全相同,因为Dense
层默认应用于最后一个轴。更多信息和解释请参见this answer。
【讨论】:
谢谢。有些教程对我来说不是很清楚。时间序列的顺序是由数据帧的顺序推断的还是我需要明确说明顺序/时间步长? @Lostsoul 时间步的顺序与数据框中的行顺序相同。阅读此函数的documentation 了解更多信息及其工作原理。以上是关于如何将 Pandas DataFrame 加载到 LSTM 模型中?的主要内容,如果未能解决你的问题,请参考以下文章
Pandas df.describe() - 如何将值提取到 Dataframe 中?
将 CSV 加载到 Pandas MultiIndex DataFrame
将 CSV 数据文件上传到 Pandas Dataframe 时如何分配标签和特征
通过 BS4 将 Scraped Table 加载到 Pandas Dataframe 中