pytorch笔记:torch.sparse类

Posted UQI-LIUWJ

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch笔记:torch.sparse类相关的知识,希望对你有一定的参考价值。

1 构造稀疏矩阵

import torch
i = torch.LongTensor([[0, 1, 1],[2, 0, 2]])   #row, col
v = torch.FloatTensor([3, 4, 5])    #data
torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense()   #torch.Size
'''
tensor([[0., 0., 3.],
        [4., 0., 5.]])
'''

构造方法和 scipy笔记:scipy.sparse_UQI-LIUWJ的博客-CSDN博客 2.2 coo矩阵 的类似

2 稀疏矩阵的基本运算

先构造两个稀疏矩阵

import torch
i = torch.LongTensor([[0, 1, 1],[2, 0, 2]])   #row, col
v = torch.FloatTensor([3, 4, 5])    #data
x1=torch.sparse.FloatTensor(i, v, torch.Size([2,3]))

x1,x1.to_dense()  
'''
(tensor(indices=tensor([[0, 1, 1],
                        [2, 0, 2]]),
        values=tensor([3., 4., 5.]),
        size=(2, 3), nnz=3, layout=torch.sparse_coo),
 tensor([[0., 0., 3.],
         [4., 0., 5.]]))
'''

import torch
i = torch.LongTensor([[0, 1, 1],[1, 0, 1]])   #row, col
v = torch.FloatTensor([3, 4, 5])    #data
x2=torch.sparse.FloatTensor(i, v, torch.Size([3,2]))

x2,x2.to_dense() 
'''
(tensor(indices=tensor([[0, 1, 1],
                        [1, 0, 1]]),
        values=tensor([3., 4., 5.]),
        size=(3, 2), nnz=3, layout=torch.sparse_coo),
 tensor([[0., 3.],
         [4., 5.],
         [0., 0.]]))
'''

2.1 稀疏矩阵的乘法

2.1.1 torch.mm

只支持第二个参数是dense(即dense*dense,或者sparse*dense)

dense*dense
dense*sparse
sparse*sparse
sparse*dense

 2.1.2 torch.sparse.mm

 同样地,只支持第二个参数是dense(即dense*dense,或者sparse*dense) 

dense*dense
dense*sparse
sparse*sparse
sparse*dense

2.2 转置

 t()即可

x2,x2.to_dense()
'''
(tensor(indices=tensor([[0, 1, 1],
                        [1, 0, 1]]),
        values=tensor([3., 4., 5.]),
        size=(3, 2), nnz=3, layout=torch.sparse_coo),
 tensor([[0., 3.],
         [4., 5.],
         [0., 0.]]))
'''

x2.t(),x2.t().to_dense()

'''
(tensor(indices=tensor([[1, 0, 1],
                        [0, 1, 1]]),
        values=tensor([3., 4., 5.]),
        size=(2, 3), nnz=3, layout=torch.sparse_coo),
 tensor([[0., 4., 0.],
         [3., 5., 0.]]))
'''

 2.3 索引

稀疏矩阵支持整行索引,支持Sparse.matrix[row_index];

x2,x2.to_dense()
'''
(tensor(indices=tensor([[0, 1, 1],
                        [1, 0, 1]]),
        values=tensor([3., 4., 5.]),
        size=(3, 2), nnz=3, layout=torch.sparse_coo),
 tensor([[0., 3.],
         [4., 5.],
         [0., 0.]]))
'''

x2[1],x2[1].to_dense()
'''
(tensor(indices=tensor([[0, 1]]),
        values=tensor([4., 5.]),
        size=(2,), nnz=2, layout=torch.sparse_coo),
 tensor([4., 5.]))
'''

稀疏矩阵不支持具体位置位置索引Sparse.matrix[row_index,col_index]

x2[1][1],x2[1][1].to_dense()

 2.4 相加

a = torch.sparse.FloatTensor(
    torch.tensor([[0,1,2],[2,3,4]]), 
    torch.tensor([1,1,1]), 
    torch.Size([5,5]))
a.to_dense()
'''
tensor([[0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]])
'''

a1=torch.sparse.FloatTensor(
    torch.tensor([[0,3,2],[2,3,2]]), 
    torch.tensor([1,1,1]), 
    torch.Size([5,5]))
a1.to_dense()
'''
tensor([[0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0]])
'''

 只支持sparse+sparse

torch.add(a,a1) ,torch.add(a,a1).to_dense()
'''
(tensor(indices=tensor([[0, 1, 2, 3, 2],
                        [2, 3, 4, 3, 2]]),
        values=tensor([2, 1, 1, 1, 1]),
        size=(5, 5), nnz=5, layout=torch.sparse_coo),
 tensor([[0, 0, 2, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 1, 0, 1],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0]]))
'''

a.add(a1),a.add(a1).to_dense()
#同理

以上是关于pytorch笔记:torch.sparse类的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch 1.0 中文文档:torch.sparse

我是土堆 - Pytorch教程 知识点 学习总结笔记

Pytorch学习笔记——Sequential类参数管理与GPU

Pytorch学习笔记(9) 通过DataSet、DatasetLoader构建模型输入数据集

PYTORCH 笔记 DILATE 代码解读

学习打卡07 可解释机器学习笔记之Shape+Lime代码实战