Pytorch实现对卷积的可插拔reparameterization
Posted mistariano
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch实现对卷积的可插拔reparameterization相关的知识,希望对你有一定的参考价值。
需要实现对卷积层的重参数化reparameterization
但是代码里卷积前weight并没有hook,很难在原本的卷积类上用pure oo的方式实现
目前的解决方案是继承原本的卷积,挂载一个weight module替代原本的weight parameter。需要hack一下getattr
大致代码:
class ReparamLayer(nn.module):
def __init__(self, weight:nn.Parameter):
self.weight = weight
def forward(self):
reparam = self.weight
# do something
# reparam = fn(reparam)
return reparam
@property
def data(self):
return self.forward() # hack
class ReparamConv2d(nn.Conv2d):
def __init__(self, *args, **kwargs):
self._inited = False
super().__init__
w = self.weight
self._inited = True
self._weight = ReparamLayer(w) # reparam weights here
del self._parameters['weight']
def __getattr__(self, item):
if self._inited and item == 'weight':
return self._weight # hack
else:
return super().__getattr__(item)
以上是关于Pytorch实现对卷积的可插拔reparameterization的主要内容,如果未能解决你的问题,请参考以下文章
带你手写基于 Spring 的可插拔式 RPC 框架通信协议模块
带你手写基于 Spring 的可插拔式 RPC 框架注册中心