Pytorch torch.save() 保存特征向量

Posted LiQiang33

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch torch.save() 保存特征向量相关的知识,希望对你有一定的参考价值。

文章目录

1 需求

存取上述特征向量

2 实现

  • 数据结构: 使用list存储这些向量,[(r_emb, query), ...]
  • 工具: torch.save()tensor保存为.pth,存取对象是字典
"""
保存特征向量,推荐使用torch保存,直接保存为tensor
"""
import torch


def save_feature(feature_list, feature_path):
    feature = 

    
    for i, (r_emb, query) in enumerate(feature_list):
        feature[f"r_emb_i"] = r_emb
        feature[f"query_i"] = query

    torch.save(feature, feature_path)
    pass

def load_feature(feature_path):
    feature = torch.load(feature_path)
    feature_list = []
    for i in range(len(feature.keys()) // 2):
        r_emb = feature[f"r_emb_i"]
        query = feature[f"query_i"]
        feature_list.append((r_emb, query))
        ...
    return feature_list
    ...

if __name__ == "__main__":
    r_emb_1 = torch.randn((32, 75, 512))
    query_1 = torch.randn((32, 22, 512))

    r_emb_2 = torch.randn((32, 75, 512))
    query_2 = torch.randn((32, 26, 512))

    feature_list = [(r_emb_1, query_1), (r_emb_2, query_2)]
    feature_path = "./save_feature.pth"

    # save_feature(feature_list, feature_path)
    feature = load_feature(feature_path)
    print("query_1 shape:", feature[0][1].shape)
    pass


以上是关于Pytorch torch.save() 保存特征向量的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch中通过torch.save保存模型和torch.load加载模型介绍

pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件

pytorch保存模型遇到点问题

PyTorch保存和加载模型

[Pytorch]Pytorch 保存模型与加载模型(转)

PyTorch学习网络的保存与提取