PocketFlow ChannelPrune代码详解
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了markdown PocketFlow ChannelPrune代码详解相关的知识,希望对你有一定的参考价值。
#/home/mars/hewu/tensorflow/PocketFlow/main.py
from nets.resnet_at_cifar10 import ModelHelper
from learners.learner_utils import create_learner
# Step 1: create the model helper and the learner.
model_helper = ModelHelper()  # wraps the network definition and the dataset (ResNet on CIFAR-10)
# NOTE: fixed typo `sw_writer` -> `sm_writer`, matching the summary-writer name
# used everywhere else in this write-up (ChannelPrunedLearner.__init__, ChannelPruner).
learner = create_learner(sm_writer, model_helper)  # dispatches to the chosen compression algorithm's learner
# Step 2: enter training, or evaluation.
learner.train()
learner.evaluate()
#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/learner.py
from learners.distillation_helper import DistillationHelper #蒸馏相关
from learners.abstract_learner import AbstractLearner
from learners.channel_pruning.model_wrapper import Model #模型相关
from learners.channel_pruning.channel_pruner import ChannelPruner #裁剪相关
from rl_agents.ddpg.agent import Agent as DdpgAgent #强化学习代理DDPG
# Inherits from AbstractLearner (excerpt; in the real source these lines live inside __init__)
class ChannelPrunedLearner(AbstractLearner):
# Initialize the base class (summary writer + model helper)
super(ChannelPrunedLearner,self).__init__(sm_writer,model_helper)
# In-class initialization
# Distillation helper (knowledge distillation support during fine-tuning)
self.learner_dst = DistillationHelper(sm_writer,model_helper)
# Build
# Builds the input pipeline, defines the model, computes the pruning upper/lower bounds, etc.
self.__build(is_train=True)
# 1. The train() entry point
def train(self):
# Download the pre-trained model, restore its weights, and create the pruner
#...
self.create_pruner()
# Select the pruning strategy: 'list', 'auto', or 'uniform'
if FLAGS.cp_prune_option == 'list':
self.__prune_and_finetune_list()
#self.__prune_and_finetune_auto()     # branch taken for the 'auto' strategy
#self.__prune_and_finetune_uniform()  # branch taken for the 'uniform' strategy
#2
def create_pruner(self):
#...
self.model = Model(self.sess_train)
self.pruner = ChannlPruner(
self.model,#模型
images=train_images,
labels=train_labels,
mem_images=mem_images,
mem_labels=mem_labels,
metrics=metrics,#度量,loss,accuracy
lbound=self.lbound,#裁剪保留通道比例
summary_op=summary_op,
sm_writer=self.sm_writer)
# 3. The 'auto' strategy as an example of the concrete pruning flow
def __prune_and_finetune_auto(self):
self.__prune_rl()  # initialize the RL agent and prune (via compress) to learn the best pruning scheme
while not done:  # loop until pruning and fine-tuning are both complete
done = self.__prune_list_layers(queue, [FLAGS.cp_list_group])
def __prune_rl(self):
# RL searches for the pruning strategy
# 5. Both __prune_rl() and __prune_list_layers() end up calling compress()
def compress(self, c_ratio):
# During pruning, the selected channels' weights are only zeroed out — not physically removed
self.prune_kernel(conv_op,c_ratio)  # channel-selection step (pruning strategy, e.g. LASSO)
self.prune_W1(father_conv, idxs)  # prune the parent conv's output channels (= this conv's input channels)
self.prune_W2(conv_op, idxs, W2)  # prune the current conv's input channels
def prune_kernel(self, op, nb_channel_new):  # concrete pruning steps for one conv op
# For the current conv: post-pruning channel count, newX = input feature map, Y = target values, W2 = weights
nb_channel_new = max(int(np.around(c * nb_channel_new)), 1)  # number of channels to keep (at least 1)
newX = self.__extract_input(op)
Y = self.feats_dict[outname]
W2 = self._model.param_data(op)
# LASSO-based pruning: yields the new weights newW2 and a boolean channel mask idxs (True = keep)
idxs, newW2 = self.compute_pruned_kernel(newX, W2, Y, c_new=nb_channel_new)
def compute_pruned_kernel(
self,
X,
W2,
Y,
alpha=1e-4,
c_new=None,
tolerance=0.02):
# Fix beta, optimize W — i.e. solve for W
while True:
_, tmp, coef = solve(right)
...
# Fix W, optimize beta — i.e. solve for beta (the idxs mask IS beta)
while True:
idxs, tmp, coef = solve(alpha)
...
#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/model_wrapper.py
def get_Add_if_is_first_after_resblock(self, op):
# Returns the Add op on the output side (first op after a residual block)
def get_Add_if_is_last_in_resblock(cls, op):
# Returns the Add op on the input side (last op inside a residual block)
def is_W1_prunable(self, conv):
# Whether this layer's W1 (parent output channels) may be pruned
#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/channel_pruner.py
from sklearn.linear_model import LassoLars
from sklearn.linear_model import LinearRegression
class ChannelPruner(object):
def __init__(self,...):
self._model = model
self.thisconvs = self._model.get_operations_by_type()  # the network's convolution layers
self.__build()
def __build(self):
self.__extract_output_of_conv_and_sum()  # collect conv and Add ops into the self.names list
self.__create_extractor()  # build the extractor that fetches each conv's input feature map
self.initialize_state()  # initialize state: which layers are prunable, pruning ratios, etc.
def initialize_state(self):
# Maps op name -> allowed preserve-ratio range; e.g. the first and last conv are
# never pruned, so their range is [1.0, 1.0]
self.max_strategy_dict = {} # collection of initial max [inp preserve, out preserve]
# Maps op name -> (input-channel list, output-channel list); True = keep that channel, False = prune it
self.fake_pruning_dict = {} # collection of fake pruning indices
# Per-layer state vector: layer n c H W stride maxreduce layercomp
# (index, output channels, input channels, height, width, stride, max reduction, layer compute)
# each column is normalized by dividing by that column's maximum value
#/home/mars/hewu/tensorflow/PocketFlow/rl_agents/ddpg/agent.py
1. resnet20裁剪:
权值维度:[KH,KW,Cin,Cout]
当前卷积都是裁剪输入通道
父卷积除了DepthwiseConv2dNative类型,其他都是裁剪输出通道
depthwise conv可以往前递推,直至找到一个普通的Conv2D OP,因为depthwise conv中不同channel之间没有dependency
|裁剪的当前卷积|裁剪的父卷积|
|:--:|:--:|
|conv2d_1|conv2d|
|conv2d_2|conv2d|
|conv2d_3|conv2d_2|
|conv2d_5|conv2d_4|
|conv2d_7|conv2d_6|
|conv2d_10|conv2d_9|
|conv2d_12|conv2d_11|
|conv2d_14|conv2d_13|
|conv2d_17|conv2d_16|
|conv2d_19|conv2d_18|
|conv2d_4(可以裁剪输入通道,但是转pb时需要在其前面插入tf.gather)|Add|
|conv2d_6|Add|
|conv2d_8|Add|
|conv2d_9|Add|
|conv2d_11|Add|
|conv2d_13|Add|
|conv2d_15|Add|
|conv2d_16|Add|
|conv2d_18|Add|
|conv2d_20|Add|
最后一个卷积不裁剪
conv2d_21
以上是关于markdown PocketFlow ChannelPrune代码详解的主要内容,如果未能解决你的问题,请参考以下文章