Pythontorch.where()解析
Posted 笃℃
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pythontorch.where()解析相关的知识,希望对你有一定的参考价值。
【Python】torch.where()解析
文章目录
1. 介绍
torch.where(condition, x, y)
- 函数功能:
- 将指定 tensor 的满足条件位置设置为想要的数值。
- 参数:
- condition:判断条件
- x:若满足条件,则为 x 中元素
- y:若不满足条件,则为 y 中元素
2. API
(function)
'''
四种调用方法如下:
'''
def where(
condition: Tensor,
input: Tensor,
other: Tensor,
*,
out: Tensor | None = None
) -> Tensor: ...
def where(
condition: Tensor,
self: Number,
other: Tensor
) -> Tensor: ...
def where(
condition: Tensor,
input: Tensor,
other: Number
) -> Tensor: ...
def where(
condition: Tensor,
self: Number,
other: Number
) -> Tensor: ...
3. 示例
import torch
# 条件
condition = torch.rand(3, 2)
print(condition)
# 满足条件则取x中对应元素
x = torch.ones(3, 2)
print(x)
# 不满足条件则取y中对应元素
y = torch.zeros(3, 2)
print(y)
# 条件判断后的结果
result = torch.where(condition > 0.5, x, y)
print(result)
输出如下:
tensor([[0.3224, 0.5789],
[0.8341, 0.1673],
[0.1668, 0.4933]])
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
tensor([[0., 0.],
[0., 0.],
[0., 0.]])
tensor([[0., 1.],
[1., 0.],
[0., 0.]])
以上是关于Pythontorch.where()解析的主要内容,如果未能解决你的问题,请参考以下文章