Pytorch实现对卷积的可插拔reparameterization

Posted mistariano

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch实现对卷积的可插拔reparameterization相关的知识,希望对你有一定的参考价值。

需要实现对卷积层的重参数化reparameterization

但是代码里卷积前weight并没有hook,很难在原本的卷积类上用pure oo的方式实现

目前的解决方案是继承原本的卷积,挂载一个weight module替代原本的weight parameter。需要hack一下getattr

大致代码:

class ReparamLayer(nn.module):
    def __init__(self, weight:nn.Parameter):
        self.weight = weight

    def forward(self):
        reparam = self.weight
        # do something
        # reparam = fn(reparam)
        return reparam
    
    @property
    def data(self):
        return self.forward()  # hack

class ReparamConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        self._inited = False
        super().__init__
        w = self.weight
        self._inited = True
        self._weight = ReparamLayer(w)  # reparam weights here
        del self._parameters['weight']
    
    def __getattr__(self, item):
        if self._inited and item == 'weight':
            return self._weight  # hack
        else:
            return super().__getattr__(item)

以上是关于Pytorch实现对卷积的可插拔reparameterization的主要内容,如果未能解决你的问题,请参考以下文章

Servlet3.0的可插拔功能

Clojure 中的可插拔向量处理单元

带你手写基于 Spring 的可插拔式 RPC 框架通信协议模块

带你手写基于 Spring 的可插拔式 RPC 框架注册中心

带你手写基于 Spring 的可插拔式 RPC 框架代理类的注入与服务启动

带你手写基于 Spring 的可插拔式 RPC 框架整体结构