markdown PocketFlow ChannelPrune代码详解

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了markdown PocketFlow ChannelPrune代码详解相关的知识,希望对你有一定的参考价值。

#/home/mars/hewu/tensorflow/PocketFlow/main.py
from nets.resnet_at_cifar10 import ModelHelper
from learners.learner_utils import create_learner

#1创建模型helper和learner
model_helper = ModelHelper()#网络和数据集的类
learner = create_learner(sw_writer,model_helper)#跳转到不同的压缩算法learner

#2进入训练,或者评估
learner.train()
learner.evaluate()


#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/learner.py
from learners.distillation_helper import DistillationHelper #蒸馏相关
from learners.abstract_learner import AbstractLearner
from learners.channel_pruning.model_wrapper import Model #模型相关
from learners.channel_pruning.channel_pruner import ChannelPruner #裁剪相关
from rl_agents.ddpg.agent import Agent as DdpgAgent #强化学习代理DDPG

#继承自AbstractLearner
class ChannelPrunedLearner(AbstractLearner):
  #继承初始化
  super(ChannelPrunedLearner,self).__init__(sm_writer,model_helper)
  
  #类内初始化
  #蒸馏类初始化
  self.learner_dst = DistillationHelper(sm_writer,model_helper)
  
  #构建
  #构建输入数据,模型定义,计算裁剪上下限等
  self.__build(is_train=True)
  
  #1train函数
  def train(self):
    #下载预训练模型,恢复权重,创建裁剪者pruner
    #...
    self.create_pruner()
    #选择裁剪策略:list,auto,uniform
    if FLAGS.cp_prune_option == 'list':
      self.__prune_and_finetune_list()
      #self.__prune_and_finetune_auto()
      #self.__prune_and_finetune_uniform()
  #2
  def create_pruner(self):
    #...
    self.model = Model(self.sess_train)
    self.pruner = ChannlPruner(
      self.model,#模型
      images=train_images,
      labels=train_labels,
      mem_images=mem_images,
      mem_labels=mem_labels,
      metrics=metrics,#度量,loss,accuracy
      lbound=self.lbound,#裁剪保留通道比例
      summary_op=summary_op,
      sm_writer=self.sm_writer)
  
  #3以auto策略为例介绍具体裁剪方法
  def __prune_and_finetune_auto(self):
    self.__prune_rl()#初始化RL类并进行裁剪(调用compress),学习最佳裁剪方法
    while not done:#完成prune和finetune
      done = self.__prune_list_layers(queue, [FLAGS.cp_list_group])
     
  def __prune_rl(self):
    #RL学习搜索裁剪策略
    
  #5__prune_rl()和__prune_list_layers()中都会调用compress
  def compress(self, c_ratio): 
    #裁剪时,只把选中的裁剪通道的权值置0,并没有真的裁剪掉
    self.prune_kernel(conv_op,c_ratio)#裁剪策略lasso等
    self.prune_W1(father_conv, idxs)#裁剪父conv的输出通道数(即当前conv的输入通道数)
    self.prune_W2(conv_op, idxs, W2)#裁剪当前conv的输入通道数
    
  def  prune_kernel(self, op, nb_channel_new): #裁剪的具体步骤
    #当前卷积:裁剪后通道数,newX输入feature map,Y目标值,W2权值
    nb_channel_new = max(int(np.around(c * nb_channel_new)), 1)#hw new channel number
    newX = self.__extract_input(op)
    Y = self.feats_dict[outname]
    W2 = self._model.param_data(op)
    #lasso裁剪,得到新的权值newW2,以及通道索引(True/False)
    idxs, newW2 = self.compute_pruned_kernel(newX, W2, Y, c_new=nb_channel_new)
    
  def compute_pruned_kernel(
      self,
      X,
      W2,
      Y,
      alpha=1e-4,
      c_new=None,
      tolerance=0.02):
        
          #固定beta,优化W,即求解W
          while True:
            _, tmp, coef = solve(right)
            ...
          #固定W,优化beta,即求解beta(idxs索引就是beta)
          while True:
            idxs, tmp, coef = solve(alpha)
            ...
          
        
    



#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/model_wrapper.py
  def get_Add_if_is_first_after_resblock(self, op):
    #Add的输出层
   
   
  def get_Add_if_is_last_in_resblock(cls, op):
    #Add的输入层
    
  def is_W1_prunable(self, conv):
    #可以裁剪的层
#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/channel_pruner.py
from sklearn.linear_model import LassoLars
from sklearn.linear_model import LinearRegression


class ChannelPruner(object):
  def __init__(self,...):
    self._model = model
    self.thisconvs = self._model.get_operations_by_type()#网络中的卷积层
    self.__build()
  def __build(self):
    self.__extract_output_of_conv_and_sum()#获取conv和add op,存入self.names列表
    self.__create_extractor()#创建用于获取卷积输入feature map的extractor
    self.initialize_state()#初始化状态:主要是确定哪些能裁剪,裁剪率等
    
  def initialize_state(self):
    #op名,对应裁剪保留范围:[] 例如第一个和最后一个卷积不裁剪,则范围为[1.0, 1.0]
    self.max_strategy_dict = {} # collection of intilial max [inp preserve, out preserve]
    #op名,对应输入通道列表和输出通道列表,里面的值为True保留这个通道,False裁剪这个通道
    self.fake_pruning_dict = {} # collection of fake pruning indices
    #layer   n          c    H  W  stride  maxreduce  layercomp
    #状态  输出通道 输入通道 高 宽 stride   最大缩减   层计算量    都是除以每一列最大值后的归一化结果
    
    
    
    
#/home/mars/hewu/tensorflow/PocketFlow/rl_agents/ddpg/agent.py
1. resnet20裁剪:
权值维度:[KH,KW,Cin,Cout]
当前卷积都是裁剪输入通道
父卷积除了DepthwiseConv2dNative类型,其他都是裁剪输出通道
depthwise conv可以往前递推,直至找到一个普通的Conv2D OP,因为depthwise conv中不同channel之间没有dependency


|裁剪的当前卷积|裁剪的父卷积|
|:--:|:--:|
|conv2d_1|conv2d|
|conv2d_2|conv2d|
|conv2d_3|conv2d_2|
|conv2d_5|conv2d_4|
|conv2d_7|conv2d_6|
|conv2d_10|conv2d_9|
|conv2d_12|conv2d_11|
|conv2d_14|conv2d_13|
|conv2d_17|conv2d_16|
|conv2d_19|conv2d_18|
|conv2d_4可以裁剪输入通道,但是转pb时需要在其前面插入tf.gather|Add|
|conv2d_6|Add|
|conv2d_8|Add|
|conv2d_9|Add|
|conv2d_11|Add|
|conv2d_13|Add|
|conv2d_15|Add|
|conv2d_16|Add|
|conv2d_18|Add|
|conv2d_20|Add|

最后一个卷积不裁剪
conv2d_21

以上是关于markdown PocketFlow ChannelPrune代码详解的主要内容,如果未能解决你的问题,请参考以下文章

markdown PocketFlow ChannelPrune代码详解

markdown PocketFlow压缩框架

go的有缓冲chann和无缓冲chan的区别

chan array初始化

java filebytebuff

java处理文件的复制