tensorflow :ckpt模型转换为pytorch : hdf5模型

Posted Richal Wang

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow :ckpt模型转换为pytorch : hdf5模型相关的知识,希望对你有一定的参考价值。

参考链接:https://github.com/bermanmaxim/jaccardSegment/blob/master/ckpt_to_dd.py

import tensorflow as tf
import deepdish as dd
import argparse
import os
import numpy as np

def tr(v):
    # tensorflow weights to pytorch weights
    if v.ndim == 4:
        return np.ascontiguousarray(v.transpose(3,2,0,1))
    elif v.ndim == 2:
        return np.ascontiguousarray(v.transpose())
    return v

def read_ckpt(ckpt):
    # https://github.com/tensorflow/tensorflow/issues/1823
    reader = tf.train.NewCheckpointReader(ckpt)
    weights = {n: reader.get_tensor(n) for (n, _) in reader.get_variable_to_shape_map().iteritems()}
    pyweights = {k: tr(v) for (k, v) in weights.items()}
    return pyweights

if __name__ == __main__:
    parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")
    parser.add_argument("infile", type=str,
                        help="Path to the ckpt.")
    parser.add_argument("outfile", type=str, nargs=?, default=‘‘,
                        help="Output file (inferred if missing).")
    args = parser.parse_args()
    if args.outfile == ‘‘:
        args.outfile = os.path.splitext(args.infile)[0] + .h5
    outdir = os.path.dirname(args.outfile)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    weights = read_ckpt(args.infile)
    dd.io.save(args.outfile, weights)
    weights2 = dd.io.load(args.outfile)

 

以上是关于tensorflow :ckpt模型转换为pytorch : hdf5模型的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow ckpt模型转saved_model格式并进行模型预测

tensorflow模型ckpt转pb以及其遇到的问题

DL之GRU(Tensorflow框架):基于茅台股票数据集利用GRU算法实现回归预测(保存模型.ckpt.index.ckpt.data文件)

TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式

tensorflow 将训练模型保存为pd文件

TensorFlow:有没有办法将冻结图转换为检查点模型?