Code Snippets That Work Very Well in Experiments
Posted gelthin2017
1. Substantially speed up PyTorch training
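import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark its convolution algorithms and pick the fastest one.
# This helps most when input shapes stay fixed between iterations.
torch.backends.cudnn.benchmark = True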
2. Rename the existing log file with a .bak suffix instead of overwriting it
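# from the co-teaching training code
import datetime
import os

# save_dir, model_str and args are defined earlier in the training script
txtfile = save_dir + "/" + model_str + "_%s.txt" % str(args.optimizer)
nowTime = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if os.path.exists(txtfile):
    # back up the previous log file under a timestamped name
    os.system('mv %s %s' % (txtfile, txtfile + ".bak-%s" % nowTime))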
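The same backup step can be wrapped in a small helper that avoids shelling out to `mv`. This is a minimal sketch, not the co-teaching code itself: the function name `backup_if_exists` and the colon-free timestamp format are assumptions chosen for portability.

```python
import datetime
import os
import shutil


def backup_if_exists(path):
    """Rename an existing file to '<path>.bak-<timestamp>' and return the new name.

    Returns None when there is nothing to back up.
    (Hypothetical helper; the name and timestamp format are illustrative.)
    """
    if not os.path.exists(path):
        return None
    # No colons in the timestamp, so the name is also valid on Windows.
    stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")
    backup_path = "%s.bak-%s" % (path, stamp)
    shutil.move(path, backup_path)
    return backup_path
```

Calling `backup_if_exists(txtfile)` just before opening `txtfile` for writing would give the same effect as the snippet above.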
3. accuracy() returns a list; unpack the values directly at the call site instead of keeping the list
# from the co-teaching code; the MixMatch_pytorch code has the same function
import torch.nn.functional as F

def accuracy(logit, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    output = F.softmax(logit, dim=1)
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest-scoring classes; shape (maxk, batch_size) after t()
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: correct[:k] is not contiguous after the transpose
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

prec1, = accuracy(logits, labels, topk=(1,))  # the trailing comma unpacks the one-element list
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
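The unpacking idiom is independent of PyTorch. The sketch below uses a pure-Python stand-in, `topk_accuracy` (a hypothetical name, with made-up scores), to show the difference between keeping the returned list and unpacking it at the call site.

```python
def topk_accuracy(scores, targets, topk=(1,)):
    """Pure-Python top-k accuracy in percent; returns one value per k, as a list.

    scores:  list of per-class score lists, one per sample
    targets: list of integer class labels
    (Illustrative stand-in for the tensor-based accuracy() function.)
    """
    results = []
    for k in topk:
        hits = 0
        for row, target in zip(scores, targets):
            # Class indices sorted from highest to lowest score.
            ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
            if target in ranked[:k]:
                hits += 1
        results.append(100.0 * hits / len(targets))
    return results


scores = [[0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.2, 0.3, 0.5]]
targets = [1, 2, 2]

as_list = topk_accuracy(scores, targets, topk=(1,))       # a one-element list
top1, = topk_accuracy(scores, targets, topk=(1,))         # trailing comma unpacks the float
top1, top2 = topk_accuracy(scores, targets, topk=(1, 2))  # one name per requested k
```

Unpacking also fails loudly with a `ValueError` if the number of names does not match the number of requested k values, which catches mismatches early.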
4. Use a logger file to record the metrics of every epoch
# from the Pytorch_MixMatch code
import matplotlib.pyplot as plt
import numpy as np

class Logger(object):
    '''Save training process to log file with simple plot function.'''
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                # reload the header and all previous rows, then reopen in append mode
                self.file = open(fpath, 'r')
                name = self.file.readline()
                self.names = name.rstrip().split('\t')
                self.numbers = {}
                for _, name in enumerate(self.names):
                    self.numbers[name] = []
                for numbers in self.file:
                    numbers = numbers.rstrip().split('\t')
                    for i in range(0, len(numbers)):
                        self.numbers[self.names[i]].append(float(numbers[i]))
                self.file.close()
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        if self.resume:
            # header and history were already loaded from the existing file
            return
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for _, name in enumerate(self.names):
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # one tab-separated row per call, flushed immediately so a crash loses nothing
        for index, num in enumerate(numbers):
            self.file.write("{0:.4f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        names = self.names if names is None else names
        numbers = self.numbers
        for _, name in enumerate(names):
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        if self.file is not None:
            self.file.close()

# usage (new_folder, data_type, title and the val_acc values come from the training script)
logger = Logger(new_folder + '/log_for_%s_WebVision1M.txt' % data_type, title=title)
logger.set_names(['epoch', 'val_acc', 'val_acc_ImageNet'])
for epoch in range(100):
    logger.append([epoch, val_acc, val_acc_ImageNet])
logger.close()
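Because the log is plain text with one header row and one row per epoch, it can be inspected later without the Logger class. The following is a minimal sketch, assuming tab-separated columns and a `val_acc` column as in the usage above; the helper name `best_epoch` is hypothetical and the delimiter is a parameter in case the file uses a different separator.

```python
import csv


def best_epoch(log_path, metric="val_acc", delimiter="\t"):
    """Return (row, value) for the row with the highest `metric` in a Logger-style file.

    Assumes a header line followed by one delimiter-separated row per epoch.
    """
    best_row, best_value = None, float("-inf")
    with open(log_path, newline="") as f:
        reader = csv.reader(f, delimiter=delimiter)
        # A trailing delimiter on each line yields an empty last field; drop it.
        header = [name for name in next(reader) if name]
        col = header.index(metric)
        for fields in reader:
            if len(fields) <= col or not fields[col]:
                continue  # skip blank or truncated lines
            value = float(fields[col])
            if value > best_value:
                best_row, best_value = dict(zip(header, fields)), value
    return best_row, best_value
```

The returned row is a dict of strings keyed by column name, so `row["epoch"]` gives the epoch at which the metric peaked.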
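5. Use the argparse command-line module to refactor the code: one script adapts to different datasets, optimizers and settings through its arguments, instead of several highly redundant copies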
# To be continued
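The idea in this heading could look roughly like the sketch below. It is not the author's code: every flag name, choice and default is a hypothetical example, and the `DATASET_CONFIG` table is just one way to keep per-dataset constants out of duplicated scripts.

```python
import argparse


def build_parser():
    """Build a parser whose flags select dataset, optimizer and hyperparameters.

    All flag names, choices and defaults here are hypothetical examples.
    """
    parser = argparse.ArgumentParser(description="One training script, many settings")
    parser.add_argument("--dataset", choices=["mnist", "cifar10", "cifar100"], default="cifar10")
    parser.add_argument("--optimizer", choices=["sgd", "adam"], default="sgd")
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--save-dir", default="results")
    return parser


# Per-dataset constants live in one table instead of in copied scripts.
DATASET_CONFIG = {
    "mnist": {"num_classes": 10, "input_channels": 1},
    "cifar10": {"num_classes": 10, "input_channels": 3},
    "cifar100": {"num_classes": 100, "input_channels": 3},
}

# An explicit argument list is passed here so the example runs anywhere.
args = build_parser().parse_args(["--dataset", "cifar100", "--optimizer", "adam", "--lr", "0.001"])
config = DATASET_CONFIG[args.dataset]
```

In a real training script, `parse_args()` would be called without arguments so that the values come from the command line, and `args.optimizer` could also feed into the log file name as in tip 2.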