判断一个模型是pytorch模型还是tensorflow模型还是scikit模型

Posted

技术标签:

【中文标题】判断一个模型是pytorch模型还是tensorflow模型还是scikit模型【英文标题】:Determine whether a model is pytorch model or a tensorflow model or scikit model 【发布时间】:2021-02-27 03:36:00 【问题描述】:

如果我想确定模型的类型,即它是从哪个框架以编程方式制作的,有没有办法做到这一点? 我有一个以某种序列化方式的模型(例如泡菜文件)。为简单起见,假设我的模型可以是 tensorflow、pytorch 或 scikit learn 的。如何以编程方式确定这 3 个中的哪一个?

【问题讨论】:

【参考方案1】:

AFAIK,我从未听说过要使用 pickle 或 joblib 保存的 Tensorflow/Keras 和 Pytorch 模型 - 这些框架提供了自己的保存和加载模型的功能:请参阅 SO 线程 Tensorflow: how to save/restore a model? 和 Best way to save a trained model in PyTorch?。此外,Github thread 在尝试使用 pickle 和 joblib 保存 TensorFlow 模型时报告各种问题。

鉴于此,如果您加载了一个模型,比如说,pickle,那么查看它使用的是什么类型是微不足道的 type(model)model。以下是 scikit-learn 线性回归模型的简短演示:

import numpy as np
from sklearn.linear_model import LinearRegression

X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
y = np.dot(X, np.array([1, 2])) + 3
reg = LinearRegression()
reg.fit(X, y)

# save it

import pickle

filename = 'model1.pkl'
pickle.dump(reg, open(filename, 'wb'))

现在,加载模型:

loaded_model = pickle.load(open(filename, 'rb'))

type(loaded_model)
# sklearn.linear_model._base.LinearRegression

loaded_model
# LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)

这也适用于 XGBoost、LightGBM、CatBoost 等框架。

【讨论】:

以上是关于判断一个模型是pytorch模型还是tensorflow模型还是scikit模型的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch保留验证集上最好的模型

PyTorch模型定义

PyTorch模型定义

Pytorch模型训练&保存/加载(搭建完整流程)

PyTorch模型定义

Pytorch学习记录-TextMatching几个经典模型