如何将pytorch中的标签转换为onehot
Posted
技术标签:
【中文标题】如何将pytorch中的标签转换为onehot【英文标题】:How to transform labels in pytorch to onehot 【发布时间】:2020-11-30 04:09:18 【问题描述】:如何给target_transform
一个函数来改变标签为onehot编码?
例如torchvision中的MNIST数据集:
train_dataset = torchvision.datasets.MNIST(root='./mnist_data/',
train=True,
download=True,
transform=train_transform,
target_transform=<????>)
试过F.onehot()
,但没用。
【问题讨论】:
F.onehot() 有什么问题?torch.nn.functional.one_hot(torch.tensor(2),5).type(torch.cuda.FloatTensor)
对我来说工作正常
【参考方案1】:
这就是我实现它的方式。不确定是否有更清洁的方法。
train_dataset = torchvision.datasets.MNIST(root='./data/', train=True,
transform=torchvision.transforms.ToTensor(),
target_transform=torchvision.transforms.Compose([
lambda x:torch.LongTensor([x]), # or just torch.tensor
lambda x:F.one_hot(x,10)]),
download=True)
必须是index tensor
?即 int64
不能使用torchvision.ToTensor
,因为它不是图片
此外,torch.LongTensor
和 torch.tensor
与 int
输入的行为不同
需要提供类数
【讨论】:
【参考方案2】:使用 lambda 用户定义函数将整数转换为 one-hot 编码张量。
train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', train=True,
download=True, transform=train_transform,
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
它首先创建一个大小为 10 的零张量(我们数据集中的标签数量)
并调用 scatter_,它在标签 y 给定的索引上分配 value=1。
【讨论】:
以上是关于如何将pytorch中的标签转换为onehot的主要内容,如果未能解决你的问题,请参考以下文章
将 pytorch float Sigmoid 结果转换为标签
PyTorch——Tensor_把索引标签转换成one-hot标签表示
如何在pytorch中计算BCEWithLogitsLoss的不平衡权重
如何在 PyTorch 中将 RGB 图像编码为 n_class One 热张量