PyTorch Lightning 0.7.1 发布

Posted 专知

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch Lightning 0.7.1 发布相关的知识,希望对你有一定的参考价值。

【导读】PyTorch Lightning是一种组织PyTorch代码以使研究代码(如网络定义)与工程代码(如数据加载,模型训练等)分离的方法,主要面向深度学习研究者,博士生等。 近日,Pytorch Lightning 发布0.7.1版本,增加了TPU支持,Profiling,Callbacks等功能。



Pytorch Ligntning简介

Pytorch Ligntning 使用LightingModule,将模型的优化训练也归到模块中,使得模型的训练更加便捷。


快速上手

定义LightningModule

class LitSystem(pl.LightningModule):
def __init__(self):super().__init__()# not the best model...self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_idx): ...

使用Trainer训练

from pytorch_lightning import Trainer

model = LitSystem()

# most basic trainer, uses good defaults
trainer = Trainer()
trainer.fit(model)


0.7.1版本新增的功能

增加TPU支持

此版本的最大功能之一可能是对TPU的支持。现在,无需更改代码,可以在CPU,GPU和TPU上运行相同的Lightning代码。


Profiling

简单来说就是统计模型运行中每步的执行时间,用于找出模型的瓶颈,便于优化。

用法:

trainer=Trainer(...,profiler=True)

模型训练结束后,会输出代码的执行时间

PyTorch Lightning 0.7.1 发布

AdvanceProfiler 可以打印出更多细节

profiler=AdvanceProfiler()trainer=Trainer(...,profiler=profiler)


回调

回调可以在训练的不同时间执行任意功能。这是在不污染研究代码的情况下将共享功能封装在单个类中的一种好方法。

在Lightning中,我们通过3种方式来考虑深度学习代码:

  • 研究代码(LightningModule)

  • 工程代码(训练过程)

  • 不必要的研究代码(回调)

比如,可以使用回调功能进行模型训练时的输出功能。

import pytorch_lightning as pl
class MyPrintingCallback(pl.Callback):
def on_init_start(self, trainer): print('Starting to init trainer!')
def on_init_end(self, trainer): print('trainer is init now')
def on_train_end(self, trainer, pl_module): print('do something when training ends')
# pass to trainertrainer = pl.Trainer(callbacks=[MyPrintingCallback()])

或者是log功能

import pytorch_lightning as pl
class MyLoggingCallback(pl.Callback):
def on_init_start(self, trainer): trainer.logger.experiment.log_tensorboard_images(...)
def on_init_end(self, trainer): trainer.logger.experiment.save_or_something(...)
def on_train_end(self, trainer, pl_module): trainer.logger.experiment.log_something_else(...)
# pass to trainertrainer = pl.Trainer(callbacks=[MyPrintingCallback()])

或是与服务器沟通

import pytorch_lightning as pl
class MyAPICallback(pl.Callback):
def on_init_start(self, trainer): requests.post('model started')
def on_init_end(self, trainer):

def on_train_end(self, trainer, pl_module): data = requests.get('/new_data/or/something') save(data) ...
# pass to trainertrainer = pl.Trainer(callbacks=[MyPrintingCallback()])

原文链接:https://medium.com/pytorch/pytorch-lightning-0-7-1-release-and-venture-funding-dd12b2e75fb3


PyTorch Lightning 0.7.1 发布
专知,专业可信的人工智能知识分发,让认知协作更快更好!欢迎注册登录专知www.zhuanzhi.ai,获取5000+AI主题干货知识资料!
欢迎微信扫一扫加入专知人工智能知识星球群,获取最新AI专业干货知识教程资料和与专家交流咨询!
点击“阅读原文”,了解使用专知,查看获取5000+AI主题知识资源

以上是关于PyTorch Lightning 0.7.1 发布的主要内容,如果未能解决你的问题,请参考以下文章

pytorch-lightning入门—— 初了解

PyTorch Lightning 是不是在整个时期内平均指标?

Pytorch-Lightning 是不是具有多处理(或 Joblib)模块?

pytorch-lightning 中的正态分布采样

如何从Pytorch 到 Pytorch Lightning | 简要介绍

pytorch lightning 手写数字分类实例