pytorch PyTorch 1.1.0 源码解析--运行机制
Posted leimu
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch PyTorch 1.1.0 源码解析--运行机制相关的知识,希望对你有一定的参考价值。
原文来自知乎,现摘录与此
首先这是一段mnist数据集的基本代码。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-6, momentum=0.5)
train_loader = []
model.train()
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
1.初始化(nn.Module.__init__())
init中主要初始化了很多参数,比如buffers,hook等等。根据Net类的代码,它会依次初始化各个层。nn.Module.__init__()
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
关于hook: PyTorch 设计了两种 hook:register_forward_hook 和 register_backward_hook,
分别用来获取正/反向传播时,中间层模块输入和输出的 feature/gradient,大大降低了获取模型内部信息流的难度。
register_forward_hook的作用是获取前向传播过程中,各个网络模块的输入和输出。对于模块 module,其使用方式为:module.register_forward_hook(hook_fn) 。
register_backward_hook 的作用是获取神经网络反向传播过程中,各个模块输入端和输出端的梯度值。
2.前向传播真正的计算入口点(nn.Module.__call__())
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values(): #在执行到这一条语句之前,计算实际上是没有发生的
hook_result = hook(self, input, result)
if hook_result is not None:
raise RuntimeError(
"forward hooks should never return any values, but ‘{}‘"
"didn‘t return None".format(hook))
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
for hook in self._forward_hooks.values(): 在执行到这一条语句之前,计算实际上是没有发生的。这一行会在执行forward之前进行,处理预设的hook。
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
这个地方实现了在不写C的情况下直接执行forward,有一些自定义操作没有C,会直接调用python的版本。
这一步开始,调用了forward方法,首先会调用Net类的forward方法,然后会以此调用Conv2d的__call__()方法等。
当调用Conv2d()的forward方法,其forward方法写在了torch._C下:
ensor Conv2dImpl::forward(const Tensor& input) {
if (options.transposed_) {
return torch::conv_transpose2d(
input,
weight,
bias,
options.stride_,
options.padding_,
options.output_padding_,
options.groups_,
options.dilation_);
}
return torch::conv2d(
input,
weight,
bias,
options.stride_,
options.padding_,
options.dilation_,
options.groups_);
然而这依然是一个wrapper,这部分逻辑代码最终由aten/c10定义 https://zhuanlan.zhihu.com/p/55966063
最终计算在:
CPU: legacy::cpu::_thnn_conv2d_forward
CUDA: legacy::cuda::_thnn_conv2d_forward
到这里,一个卷积层的forward操作就结束了,其他层的forward同理。
Conv2d的forward方法执行完成之后接着进行forward_hook和backward_hook的步骤,与之前的forward_pre_hook相似。
到这里,Conv2d的__call__()方法执行完毕,接下来执行relu之类的逻辑,直到return。
调用栈返回Net的forward的返回值,得到loss。
到这里,前向传播完成。
3.反向传播(loss.backward())
loss.backward() 只执行一次,计算完成所有的梯度。
首先,所有的requires_grad为True的张量都会被记录并被添加进Engine::ready_queue_by_index中,这些tensor都会被以FunctionTask的结构体记录在ReadyQueue中。
首先在前向传播的时候,所有requiresgrad==True的对象都会被添加进一个容器中,然后在backward执行之前,首先启动一个处理引擎,在做了初始化和读取相关的记录(包括之前的哪个容器)后调用了run_backward方法,然后统一计算出梯度,并返回loss的梯度。
以上是关于pytorch PyTorch 1.1.0 源码解析--运行机制的主要内容,如果未能解决你的问题,请参考以下文章
01 Pytorch和CUDA对应的版本及Pytorch和Python对应的版本及Python与Anaconda的对应关系
[源码解析] PyTorch 分布式之弹性训练---Rendezvous 引擎
PyTorch笔记 - MAE(Masked Autoencoders) PyTorch源码
PyTorch笔记 - MAE(Masked Autoencoders) PyTorch源码