PyTorch Lightning 0.7.1 发布
Posted 专知
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch Lightning 0.7.1 发布相关的知识,希望对你有一定的参考价值。
【导读】PyTorch Lightning是一种组织PyTorch代码以使研究代码(如网络定义)与工程代码(如数据加载,模型训练等)分离的方法,主要面向深度学习研究者,博士生等。 近日,Pytorch Lightning 发布0.7.1版本,增加了TPU支持,Profiling,Callbacks等功能。
Pytorch Ligntning简介
Pytorch Ligntning 使用LightingModule,将模型的优化训练也归到模块中,使得模型的训练更加便捷。
快速上手
定义LightningModule
class LitSystem(pl.LightningModule):
def __init__(self):
super().__init__()
# not the best model...
self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_idx):
...
使用Trainer训练
from pytorch_lightning import Trainer
model = LitSystem()
# most basic trainer, uses good defaults
trainer = Trainer()
trainer.fit(model)
0.7.1版本新增的功能
增加TPU支持
此版本的最大功能之一可能是对TPU的支持。现在,无需更改代码,可以在CPU,GPU和TPU上运行相同的Lightning代码。
Profiling
简单来说就是统计模型运行中每步的执行时间,用于找出模型的瓶颈,便于优化。
用法:
trainer=Trainer(...,profiler=True)
模型训练结束后,会输出代码的执行时间
AdvanceProfiler 可以打印出更多细节
profiler=AdvanceProfiler()
trainer=Trainer(...,profiler=profiler)
回调
回调可以在训练的不同时间执行任意功能。这是在不污染研究代码的情况下将共享功能封装在单个类中的一种好方法。
在Lightning中,我们通过3种方式来考虑深度学习代码:
研究代码(LightningModule)
工程代码(训练过程)
不必要的研究代码(回调)
比如,可以使用回调功能进行模型训练时的输出功能。
import pytorch_lightning as pl
class MyPrintingCallback(pl.Callback):
def on_init_start(self, trainer):
print('Starting to init trainer!')
def on_init_end(self, trainer):
print('trainer is init now')
def on_train_end(self, trainer, pl_module):
print('do something when training ends')
# pass to trainer
trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
或者是log功能
import pytorch_lightning as pl
class MyLoggingCallback(pl.Callback):
def on_init_start(self, trainer):
trainer.logger.experiment.log_tensorboard_images(...)
def on_init_end(self, trainer):
trainer.logger.experiment.save_or_something(...)
def on_train_end(self, trainer, pl_module):
trainer.logger.experiment.log_something_else(...)
# pass to trainer
trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
或是与服务器沟通
import pytorch_lightning as pl
class MyAPICallback(pl.Callback):
def on_init_start(self, trainer):
requests.post('model started')
def on_init_end(self, trainer):
def on_train_end(self, trainer, pl_module):
data = requests.get('/new_data/or/something')
save(data)
...
# pass to trainer
trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
原文链接:https://medium.com/pytorch/pytorch-lightning-0-7-1-release-and-venture-funding-dd12b2e75fb3
以上是关于PyTorch Lightning 0.7.1 发布的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch Lightning 是不是在整个时期内平均指标?
Pytorch-Lightning 是不是具有多处理(或 Joblib)模块?