开源图像分类工具箱MMClassification安装及使用示例

Posted fengbingchun

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了开源图像分类工具箱MMClassification安装及使用示例相关的知识,希望对你有一定的参考价值。

      MMClassification是一个基于PyTorch的开源图像分类工具箱,是OpenMMLab项目的一部分,源码在 https://github.com/open-mmlab/mmclassification,最新发布版本为v0.23.2,License为Apache-2.0。它支持在Windows、Linux和Mac上运行。
      1.安装:使用conda安装
      (1).创建openmmlab虚拟环境:

conda create -n openmmlab python=3.8
conda activate openmmlab

      (2).安装PyTorch:这里PyTorch使用1.11.0版本,CUDA使用10.2版本,此CUDA版本对PyTorch各版本都支持

conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=10.2 -c pytorch

      (3).安装MMCV:MMCV有两个版本,这里安装带CUDA的mmcv-full
      1).mmcv-full: 完整版,包含所有的特性以及丰富的开箱即用的CUDA算子,安装此版本需要较长时间。
      2).mmcv:精简版,不包含CUDA算子但包含其余所有特性和功能,类似MMCV 1.0之前的版本。
      不要在同一个环境中安装两个版本,否则可能会遇到类似ModuleNotFound的错误。在安装一个版本之前,需要先卸载另一个:

pip uninstall mmcv-full
pip uninstall mmcv

      注意:这里mmcv-full使用1.5.3版本。CUDA版本和PyTorch版本与安装PyTorch时保持一致

pip install mmcv-full==1.5.3 -f https://download.openmmlab.com/mmcv/dist/cu102/torch1.11.0/index.html

      (4).安装MMClassification:没有通过源码安装

pip install mmcls==0.23.2

      2.测试:论文:《Very Deep Convolutional Networks for Large-Scale Image Recognition》
      ImageNet数据集:是根据WordNet层次结构组织的图像数据集,ImageNet_1000_label中给出了1000类别中label对应的id值。

      (1).下载模型(checkpoint):

def download_checkpoint(path, name, url):
	if os.path.isfile(path+name) == False:
		print("checkpoint(model) file does not exist, now download ...")
		subprocess.run(["wget", "-P", path, url])

path = "../../data/model/"
checkpoint = "vgg19_batch256_imagenet_20210208-e6920e4a.pth"
url = "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth"
download_checkpoint(path, checkpoint, url)

      (2).根据配置文件和checkpoint文件构建模型:

config = "../../src/mmclassification/configs/vgg/vgg19_8xb32_in1k.py"
model = init_model(config, path+checkpoint, device)

      (3).准备测试图像:原始图像来自网络

image_path = "../../data/image/"
image_name = "6.jpg"

      (4).进行推理:

result = inference_model(model, image)
print(mmcv.dump(result, file_format='json', indent=4))
# show_result_pyplot(model, image, result)

       执行结果如下图所示:

      GitHub: https://github.com/fengbingchun/PyTorch_Test

以上是关于开源图像分类工具箱MMClassification安装及使用示例的主要内容,如果未能解决你的问题,请参考以下文章

计算机视觉框架OpenMMLab开源学习:图像分类实战

计算机视觉框架OpenMMLab开源学习:图像分类

OpenMMLab 实战营打卡 - 第 三 课

图像分类/识别开源库

PyTorch开源图像分类算法框架

10 个 Python 图像编辑工具