tensorflow和pytorch模型之间转换

Posted qbdj

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow和pytorch模型之间转换相关的知识,希望对你有一定的参考价值。

参考链接:
https://github.com/bermanmaxim/jaccardSegment/blob/master/ckpt_to_dd.py

一. tensorflow模型转pytorch模型

import tensorflow as tf
import deepdish as dd
import argparse
import os
import numpy as np

def tr(v):
    # tensorflow weights to pytorch weights
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3,2,0,1))
    elif v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    return v

def read_ckpt(ckpt):
    # https://github.com/tensorflow/tensorflow/issues/1823
    reader = tf.train.NewCheckpointReader(ckpt)
    weights = n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().items()
    pyweights = k: tr(v) for (k, v) in weights.items()
    return pyweights
if __name__ == ‘__main__‘:
    parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")
    parser.add_argument("infile", type=str,
                        help="Path to the ckpt.")  # ***model.ckpt-22177***
    parser.add_argument("outfile", type=str, nargs=‘?‘, default=‘‘,
                        help="Output file (inferred if missing).")
    args = parser.parse_args()
    if args.outfile == ‘‘:
        args.outfile = os.path.splitext(args.infile)[0] + ‘.h5‘
    outdir = os.path.dirname(args.outfile)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    weights = read_ckpt(args.infile)
    dd.io.save(args.outfile, weights)

  

1.运行上述代码后会得到model.h5模型,如下:
备注:保持tensorflow和pytorch使用的python版本一致

技术图片

 

2.使用:在pytorch内加载改模型:
这里假设网络保存时参数命名一致

 

net = ...
import torch
import deepdish as dd
net = resnet50(..)
model_dict = net.state_dict()
#先将参数值numpy转换为tensor形式
pretrained_dict =  = dd.io.load(‘./model.h5‘)
new_pre_dict = 
for k,v in pretrained_dict.items():
    new_pre_dict[k] = torch.Tensor(v)
#更新
model_dict.update(new_pre_dict)
#加载
net.load_state_dict(model_dict)

  

二. pytorch转tensorflow(待续。。)

 

原文:https://blog.csdn.net/weixin_42699651/article/details/88932670

 

以上是关于tensorflow和pytorch模型之间转换的主要内容,如果未能解决你的问题,请参考以下文章

在将TensorFlow模型转换为Pytorch时出现大小不匹配的错误。

转载微软Facebook联手发布AI生态系统,CNTK+Caffe2+PyTorch挑战TensorFlow

比较 Conv2D 与 Tensorflow 和 PyTorch 之间的填充

pytorch 转tensorflow注意

是否可以使用 C++ 训练在 tensorflow 和 pytorch 中开发的 ONNX 模型?

承接TensorFlow深度学习代做pytorch图像处理