如何将pytorch中的标签转换为onehot

Posted

技术标签:

【中文标题】如何将pytorch中的标签转换为onehot【英文标题】:How to transform labels in pytorch to onehot 【发布时间】:2020-11-30 04:09:18 【问题描述】:

如何给target_transform一个函数来改变标签为onehot编码?

例如torchvision中的MNIST数据集:

train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', 
                                           train=True,
                                           download=True,
                                           transform=train_transform,
                                           target_transform=<????>)

试过F.onehot(),但没用。

【问题讨论】:

F.onehot() 有什么问题? torch.nn.functional.one_hot(torch.tensor(2),5).type(torch.cuda.FloatTensor) 对我来说工作正常 【参考方案1】:

这就是我实现它的方式。不确定是否有更清洁的方法。

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True,
                                 transform=torchvision.transforms.ToTensor(),
                                 target_transform=torchvision.transforms.Compose([
                                 lambda x:torch.LongTensor([x]), # or just torch.tensor
                                 lambda x:F.one_hot(x,10)]),
                                 download=True)

必须是index tensor?即 int64

不能使用torchvision.ToTensor,因为它不是图片

此外,torch.LongTensortorch.tensorint 输入的行为不同

需要提供类数

【讨论】:

【参考方案2】:

使用 lambda 用户定义函数将整数转换为 one-hot 编码张量。

train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', train=True, 
    download=True, transform=train_transform, 
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
它首先创建一个大小为 10 的零张量(我们数据集中的标签数量) 并调用 scatter_,它在标签 y 给定的索引上分配 value=1。

【讨论】:

以上是关于如何将pytorch中的标签转换为onehot的主要内容,如果未能解决你的问题,请参考以下文章

将 pytorch float Sigmoid 结果转换为标签

PyTorch——Tensor_把索引标签转换成one-hot标签表示

如何在pytorch中计算BCEWithLogitsLoss的不平衡权重

如何在 PyTorch 中将 RGB 图像编码为 n_class One 热张量

如何将 caffe prototxt 转换为 pytorch 模型?

如何将 pytorch 张量转换为 numpy 数组?