Pytorch torch.save() 保存特征向量
Posted LiQiang33
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch torch.save() 保存特征向量相关的知识,希望对你有一定的参考价值。
文章目录
1 需求
存取上述特征向量
2 实现
- 数据结构: 使用
list
存储这些向量,[(r_emb, query), ...]
- 工具:
torch.save()
将tensor
保存为.pth
,存取对象是字典
"""
保存特征向量,推荐使用torch保存,直接保存为tensor
"""
import torch
def save_feature(feature_list, feature_path):
feature =
for i, (r_emb, query) in enumerate(feature_list):
feature[f"r_emb_i"] = r_emb
feature[f"query_i"] = query
torch.save(feature, feature_path)
pass
def load_feature(feature_path):
feature = torch.load(feature_path)
feature_list = []
for i in range(len(feature.keys()) // 2):
r_emb = feature[f"r_emb_i"]
query = feature[f"query_i"]
feature_list.append((r_emb, query))
...
return feature_list
...
if __name__ == "__main__":
r_emb_1 = torch.randn((32, 75, 512))
query_1 = torch.randn((32, 22, 512))
r_emb_2 = torch.randn((32, 75, 512))
query_2 = torch.randn((32, 26, 512))
feature_list = [(r_emb_1, query_1), (r_emb_2, query_2)]
feature_path = "./save_feature.pth"
# save_feature(feature_list, feature_path)
feature = load_feature(feature_path)
print("query_1 shape:", feature[0][1].shape)
pass
以上是关于Pytorch torch.save() 保存特征向量的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch中通过torch.save保存模型和torch.load加载模型介绍