Code snippets that proved very useful in experiments

Posted gelthin2017


1. Speed up PyTorch training with cuDNN auto-tuning (helps most when input sizes do not change between iterations)

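import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# let cuDNN benchmark its convolution algorithms and reuse the fastest one for each input shape
torch.backends.cudnn.benchmark = True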
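A minimal sketch of how `device` and the benchmark flag are typically used in a training loop. The convolution layer, optimizer and batch shape are hypothetical and only serve as illustration. The flag only affects cuDNN on the GPU; on a CPU-only machine the loop runs the same way without any speed-up. Because the search is repeated for every new input shape, inputs whose size changes at every step can make training slower rather than faster.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only affects cuDNN convolutions on the GPU; has no effect on CPU-only runs
torch.backends.cudnn.benchmark = True

# Hypothetical toy model and optimizer, used only for illustration
model = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(3):
    # Keep the input shape fixed so the algorithm picked on the first step is reused
    x = torch.randn(4, 3, 32, 32, device=device)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

out_shape = tuple(model(x).shape)
print(out_shape)  # (4, 8, 32, 32)
```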

2. Rename an existing log file with a .bak suffix instead of overwriting it

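# from the co-teaching training code
import datetime
import os

# save_dir, model_str and args come from the surrounding training script
txtfile = save_dir + "/" + model_str + "_%s.txt" % str(args.optimizer)  # good practice: never overwrite an old log
nowTime = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
if os.path.exists(txtfile):
    # back up the old file as <txtfile>.bak-<timestamp>
    os.system("mv %s %s" % (txtfile, txtfile + ".bak-%s" % nowTime))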
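The same idea as a small reusable helper, sketched with `os.replace` instead of calling `mv` through the shell. This is a swapped-in variant, not the co-teaching code: it avoids a shell process, works with paths containing spaces, and also runs on Windows. The helper name `backup_if_exists` and the demo file name are hypothetical. The timestamp uses `-` instead of `:` because colons are not allowed in Windows file names.

```python
import datetime
import os
import tempfile


def backup_if_exists(path):
    """Rename an existing file to <path>.bak-<timestamp>; return the new name, or None if absent."""
    if not os.path.exists(path):
        return None
    now_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    backup = "%s.bak-%s" % (path, now_time)
    os.replace(path, backup)  # plain rename, no shell involved
    return backup


# Demo in a temporary directory (hypothetical log name)
log_dir = tempfile.mkdtemp()
txtfile = os.path.join(log_dir, "model_adam.txt")
with open(txtfile, "w") as f:
    f.write("old run\n")

moved = backup_if_exists(txtfile)  # old log -> model_adam.txt.bak-<timestamp>
with open(txtfile, "w") as f:      # the new run writes to a fresh file
    f.write("new run\n")
```

Call the helper once, right before opening the log file in `"w"` mode, so every run starts with an empty log while earlier runs are kept.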

3. Compute accuracy as a list (one value per k) and unpack the values directly at the call site instead of keeping the list

# from the co-teaching code; the MixMatch_pytorch code has the same function
import torch.nn.functional as F

def accuracy(logit, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    output = F.softmax(logit, dim=1)
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: correct is non-contiguous after .t(), so view(-1) can fail
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

# the trailing comma unpacks the one-element list returned for topk=(1,)
prec1, = accuracy(logits, labels, topk=(1,))
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
# each entry is a one-element tensor; call prec1.item() to get a Python float
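A compact variant, sketched under the assumption of a recent PyTorch version. It skips the softmax (softmax is monotonic, so the top-k indices of the logits and of the probabilities are identical) and returns Python floats, so the caller can unpack plain numbers. The function name `topk_accuracy` and the toy batch are hypothetical.

```python
import torch


def topk_accuracy(logits, target, topk=(1,)):
    """Return top-k accuracy in percent as Python floats, one per k."""
    maxk = max(topk)
    # indices of the maxk largest logits per sample, shape (batch, maxk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    # correct[i, j] is True if the j-th ranked prediction of sample i equals its label
    correct = pred.eq(target.view(-1, 1))
    batch_size = target.size(0)
    return [correct[:, :k].any(dim=1).float().sum().item() * 100.0 / batch_size
            for k in topk]


# Toy batch: 4 samples, 3 classes
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.1, 2.0, 1.0],
                       [1.0, 0.1, 2.0],
                       [2.0, 1.0, 0.1]])
labels = torch.tensor([0, 2, 1, 1])

top1, = topk_accuracy(logits, labels)  # trailing comma: unpack the single value
prec1, prec2 = topk_accuracy(logits, labels, topk=(1, 2))
print(prec1, prec2)  # 25.0 75.0
```

Only the first sample is right at top-1, while three of the four labels appear among the two highest logits, hence 25.0 and 75.0.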

4. Use a logger file to record the results of every epoch

# from the Pytorch_MixMatch code
import matplotlib.pyplot as plt
import numpy as np


class Logger(object):
    """Save training process to log file with simple plot function."""
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                # reload the existing log: the first line holds the column names
                self.file = open(fpath, 'r')
                name = self.file.readline()
                self.names = name.rstrip().split('\t')
                self.numbers = {}
                for name in self.names:
                    self.numbers[name] = []

                for numbers in self.file:
                    numbers = numbers.rstrip().split('\t')
                    for i in range(0, len(numbers)):
                        self.numbers[self.names[i]].append(float(numbers[i]))
                self.file.close()
                # reopen in append mode so new epochs are added at the end
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        if self.resume:
            # header and past values were already loaded in __init__
            return
        # initialize numbers as empty lists and write the tab-separated header
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.file.write("{0:.4f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        if self.file is not None:
            self.file.close()

# usage (new_folder, data_type, title, val_acc and val_acc_ImageNet come from the training script)
logger = Logger(new_folder + '/log_for_%s_WebVision1M.txt' % data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
    # ... train and evaluate to obtain val_acc and val_acc_ImageNet ...
    logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()
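Because the Logger writes a plain tab-separated file, the results can also be read back without the class, for example to pick the best epoch after training. A small standard-library sketch, assuming exactly the format produced by `set_names`/`append` (a header row, then one row per epoch, every field followed by a tab). The function name `read_log` and the accuracy values are made up for illustration.

```python
import os
import tempfile


def read_log(fpath):
    """Parse a tab-separated log written by Logger into {column name: [floats]}."""
    with open(fpath) as f:
        # rstrip() also drops the trailing tab after the last field
        names = f.readline().rstrip().split('\t')
        columns = {name: [] for name in names}
        for line in f:
            for name, value in zip(names, line.rstrip().split('\t')):
                columns[name].append(float(value))
    return columns


# Hypothetical log in the Logger format, with made-up accuracies
fpath = os.path.join(tempfile.mkdtemp(), "log.txt")
with open(fpath, "w") as f:
    f.write("epoch\tval_acc\t\n")
    for epoch, acc in enumerate([61.2, 67.5, 66.9]):
        f.write("{0:.4f}\t{1:.4f}\t\n".format(epoch, acc))

log = read_log(fpath)
best = max(range(len(log["val_acc"])), key=lambda i: log["val_acc"][i])
print(int(log["epoch"][best]), log["val_acc"][best])  # 1 67.5
```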

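5. Use the argparse command-line module to restructure the code: select the dataset, optimizer and other settings through arguments, instead of maintaining several nearly identical copies of the same script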
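A minimal sketch of this pattern with the standard `argparse` module. All option names, choices and defaults here are hypothetical and should be adapted to the actual experiments. Settings that differ per dataset are kept in a single table, so one script covers every case.

```python
import argparse


def build_parser():
    # Hypothetical options; rename and extend them for your own experiments
    parser = argparse.ArgumentParser(description="One training script for several settings")
    parser.add_argument("--dataset", choices=["cifar10", "cifar100", "webvision"], default="cifar10")
    parser.add_argument("--optimizer", choices=["sgd", "adam"], default="sgd")
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--save_dir", default="results")
    return parser


# Per-dataset settings live in one table instead of in separate copies of the script
NUM_CLASSES = {"cifar10": 10, "cifar100": 100, "webvision": 1000}

# An explicit list is parsed here for demonstration; a real script calls parse_args()
# with no argument so that it reads sys.argv
args = build_parser().parse_args(["--dataset", "cifar100", "--optimizer", "adam", "--lr", "0.001"])
num_classes = NUM_CLASSES[args.dataset]
print(args.optimizer, args.lr, num_classes)  # adam 0.001 100
```

The parsed values can then feed the other snippets, for example the log file name built from `args.optimizer` in item 2.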

 

 

# To be continued
