One-hot的使用
Posted MartinRY
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了One-hot的使用相关的知识,希望对你有一定的参考价值。
pytorch官方已经提供了具体one-hot函数,可以直接使用
from torch.nn.functional import one_hot
result = one_hot(data,num_class)
首先介绍下,one_hot可以运用在当调用损失函数时,所放入数据尺寸不等的时候
例如:
raise ValueError("Target size () must be the same as input size ()".format(target.size(), input.size()))
ValueError: Target size (torch.Size([2, 256, 256])) must be the same as input size (torch.Size([2, 2, 256, 256]))
可以看到上面输入的预测是一个[2, 2, 256, 256] 的tensor(分别是nchw,即batch_size, channels, height, width)
而label的则是[2, 256, 256] ,缺少了通道数(因为label是单通道)
因此可以使用one_hot进行转换,将tensor数据直接放入one_hot中的第一个参数的位置,因为只有2类,所以num_class=2
from torch.nn.functional import one_hot
result = one_hot(label,2)
但直接放有时会遇到数据类型的问题,由于one_hot只接受int类型,所以若label数据是float类型,需要先将类型转换为int
label = label.to(dtype=torch.int64)
然后再放进one_hot
最后,由于模型的bias比较常为float类型,所以直接把输出的one_hot放进去会报数据类型错误
,因此在one_hot输出后,还需要再转换为float类型
one_hot_output = one_hot_output.type(torch.float32)
最后,把one_hot输出和pred放进计算loss的函数进行计算
以上是关于One-hot的使用的主要内容,如果未能解决你的问题,请参考以下文章