Pythontorch.where()解析

Posted 笃℃

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pythontorch.where()解析相关的知识,希望对你有一定的参考价值。

【Python】torch.where()解析

文章目录

1. 介绍

torch.where(condition, x, y)

  • 函数功能:
    • 将指定 tensor 的满足条件位置设置为想要的数值。
  • 参数:
    • condition:判断条件
    • x:若满足条件,则为 x 中元素
    • y:若不满足条件,则为 y 中元素

2. API

(function)
'''
四种调用方法如下:
'''
def where(
    condition: Tensor,
    input: Tensor,
    other: Tensor,
    *,
    out: Tensor | None = None
) -> Tensor: ...

def where(
    condition: Tensor,
    self: Number,
    other: Tensor
) -> Tensor: ...

def where(
    condition: Tensor,
    input: Tensor,
    other: Number
) -> Tensor: ...

def where(
    condition: Tensor,
    self: Number,
    other: Number
) -> Tensor: ...

3. 示例

import torch
 
# 条件
condition = torch.rand(3, 2)
print(condition)
# 满足条件则取x中对应元素
x = torch.ones(3, 2)
print(x)
# 不满足条件则取y中对应元素
y = torch.zeros(3, 2)
print(y)
# 条件判断后的结果
result = torch.where(condition > 0.5, x, y)
print(result)

输出如下:

tensor([[0.3224, 0.5789],
        [0.8341, 0.1673],
        [0.1668, 0.4933]])
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
tensor([[0., 1.],
        [1., 0.],
        [0., 0.]])

以上是关于Pythontorch.where()解析的主要内容,如果未能解决你的问题,请参考以下文章

Python CSV 解析,转义引号字符

aliyun域名解析python api

Python中的urlparseurllib抓取和解析网页

使用 xml.etree.ElementTree 在 python 中解析 XML

在FTP失败连接Pyth3.6中循环

Pyth-Solana链上联通现实的桥梁