加载器的无效数据类型 - Pytorch Lightning DataModule

Posted

技术标签:

【中文标题】加载器的无效数据类型 - Pytorch Lightning DataModule【英文标题】:Invalid Datatype for loaders - Pytorch Lightning DataModule 【发布时间】:2022-01-21 02:30:03 【问题描述】:

我正在尝试进行文本摘要练习,并且我已经训练和测试了包含两列文本和摘要(标签)的数据集。我正在使用 T5、Pytorch 和 Lightning 包装器,并且我有一个 Pytorch 数据集类,我可以确认它工作正常,并将以下内容作为文本字典返回,并将 id、标签和掩码作为张量返回。

return dict(
    text=text,
    summary = data_row['summary'],
    text_input_ids = text_encoding['input_ids'].flatten(),
    text_attention_mask = text_encoding['attention_mask'].flatten(),
    labels = labels.flatten(),
    labels_attention_mask = summary_encoding['attention_mask'].flatten()
)

然后我有一个 Lightning 数据模块类,它将数据帧转换为 PyTorch 数据集,并将它们适合数据加载器,返回训练、验证和测试数据加载器

class TextSummaryDataModule(pl.LightningModule):
  def __init__(
      self, 
      train_df: pd.DataFrame, 
      test_df: pd.DataFrame, 
      tokenizer: T5Tokenizer, 
      batch_size: int=8, 
      text_max_token_len: int=512, 
      summary_max_token_len: int=128
    ):
    
      super().__init__()
      
      self.train_df = train_df
      self.test_df = test_df

      self.tokenizer = tokenizer
      self.batch_size = batch_size
      self.text_max_token_len = text_max_token_len
      self.summary_max_token_len = summary_max_token_len

  def setup(self):
    self.train_dataset = TextSummaryDataset(
        self.train_df,
        self.tokenizer,
        self.text_max_token_len,
        self.summary_max_token_len
    )

    self.test_dataset = TextSummaryDataset(
        self.test_df,
        self.tokenizer,
        self.text_max_token_len,
        self.summary_max_token_len
    )

  def train_dataloader(self):
    return DataLoader(
        self.train_dataset,
        batch_size = self.batch_size,
        shuffle=True,
        num_workers=2
    )

  def val_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size = self.batch_size,
        shuffle=False,
        num_workers=2
    )

  def test_dataloader(self):
    return DataLoader(
        self.test_dataset,
        batch_size = self.batch_size,
        shuffle=False,
        num_workers=2
    )

一切正常,直到我尝试执行模型并收到以下警告和错误

    用户警告:您定义了一个验证步骤,但没有验证数据加载器。跳过验证循环 - 我已经在数据模块中明确定义并返回了它

    加载程序的数据类型无效:TextSummaryDataModule - 我已确认我正在返回一个包含文本和摘要的标记、注意掩码和标签的字典

【问题讨论】:

【参考方案1】:

可耻的是我在这里使用了 pl.LightningModule 而不是 DataModule...

【讨论】:

以上是关于加载器的无效数据类型 - Pytorch Lightning DataModule的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch数据加载

使用flask在heroku bert pytorch模型上部署:错误:_pickle.UnpicklingError:无效的加载键,'v'

pytorch-lightning train_dataloader 用完数据

Pytorch 基本使用(数据加载,类型转换)

pytorch 加载数据集

Pytorch应用:构建分类器