获取 torchvision 预训练网络的分类标签
Posted
技术标签:
【中文标题】获取 torchvision 预训练网络的分类标签【英文标题】:getting the classification labels for torchvision's pretrained networks 【发布时间】:2020-06-17 14:09:25 【问题描述】:Pytorch 的torchvision
包提供pre-trained neural networks 用于图像分类。我一直在使用以下代码使用 Alexnet 对图像进行分类(注意:其中一些代码来自 this webpage):
from PIL import Image
import torch
from torchvision import transforms
from torchvision import models
# function to transform image
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
# image
img = Image.open('/path/to/image.jpg')
img = transform(img)
img = torch.unsqueeze(img, 0)
# alexnet
alexnet = models.alexnet(pretrained=True)
alexnet.eval()
out = alexnet(img)
percents = torch.nn.functional.softmax(out, dim=1)[0] * 100
top5_vals, top5_inds = percents.topk(5)
共有 1,000 个类,top5_inds
变量为我提供了前 5 个类的索引。但是我如何获得相关的标签(例如蜗牛、篮球、香蕉)?我似乎找不到任何类型的列表作为 Pytorch 文档或 alexnet
变量的一部分。
【问题讨论】:
【参考方案1】:Torchvision 模型在 ImageNet 数据集上进行了预训练。由于其全面性和规模,ImageNet 是预训练和迁移学习最常用的数据集。正如您所指出的,它有 1000 个类。完整的类列表可以搜索,也可以参考GitHub上的这个列表:https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
【讨论】:
以上是关于获取 torchvision 预训练网络的分类标签的主要内容,如果未能解决你的问题,请参考以下文章
[Pytorch系列-43]:工具集 - torchvision预训练模型参数的导入(以ResNet为例)
[Pytorch系列-42]:工具集 - torchvision常见预训练模型的下载地址
[Pytorch系列-39]:工具集 - torchvision搭建AlexNet/VGG/Resnet等网络并训练CFAR10分类数据