使用 CNN 提取特征进行图像相似度对比:PyTorch 推理时报内存不足(OOM)的问题
Posted 东东就是我
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用cnn提取特征,图像相似度对比。pytorch 推理的时候报内存不足的问题相关的知识,希望对你有一定的参考价值。
解决方法:推理时将前向计算放在 `with torch.no_grad():` 中,这样 PyTorch 不会为反向传播保存计算图,从而减少显存占用。
参考:https://blog.csdn.net/CRDarwin/article/details/119943128
# coding: utf-8
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset,DataLoader
def cos_sim(a, b):
    """Return the cosine similarity between two vectors as a Python float.

    ``a`` and ``b`` may be flat sequences or nested ones such as ``[[x0, x1, ...]]``
    (what ``tensor.tolist()`` yields for a ``(1, D)`` tensor). Both are flattened
    before comparison, so they must contain the same number of elements.

    Returns 0.0 when either vector has zero norm, instead of dividing by zero.

    Raises:
        ValueError: if the flattened vectors have different lengths.
    """
    # np.asarray + ravel replaces the deprecated np.mat (removed in NumPy 2.0).
    vec_a = np.asarray(a, dtype=float).ravel()
    vec_b = np.asarray(b, dtype=float).ravel()
    denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if denom == 0:
        return 0.0
    return float(np.dot(vec_a, vec_b) / denom)
class MyDataset(Dataset):
    """Dataset over the regular files of one directory, yielding ``(image, file_name)``.

    Args:
        file_path: directory to scan (not recursive).
        transform: optional callable applied to each RGB PIL image.

    Attributes:
        imgs: full paths of the files, sorted by file name.
        img_names: bare file names, aligned with ``imgs``.
    """

    def __init__(self, file_path, transform=None):
        # Sort so the index -> file mapping is reproducible (os.listdir order is
        # arbitrary) and skip sub-directories, which Image.open cannot read.
        file_names = sorted(
            name for name in os.listdir(file_path)
            if os.path.isfile(os.path.join(file_path, name))
        )
        self.img_names = file_names
        self.imgs = [os.path.join(file_path, name) for name in file_names]
        self.transform = transform

    def __getitem__(self, index):
        """Load item ``index`` as RGB, apply ``transform`` if set, return ``(img, name)``."""
        fn = self.imgs[index]
        img_name = self.img_names[index]
        # convert() returns a new image, so the file handle can be closed
        # immediately instead of leaking until garbage collection.
        with Image.open(fn) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, img_name

    def __len__(self):
        return len(self.imgs)
# Input directory of images and output directory for the de-duplicated subset.
FILE_PATH = "F:/my_code/2021/sf/data/ICR_EXT/"
SAVE_PATH = "F:/my_code/2021/sf/data/s/"
# Cosine similarity strictly above which two images count as duplicates.
SIMILARITY_THRESHOLD = 0.93


def extract_features(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and collect the outputs on the CPU.

    Args:
        model: module mapping an image batch to a batch of feature rows.
        loader: iterable of ``(images, names)`` batches, e.g. a ``DataLoader``
            over ``MyDataset``.
        device: device the images are moved to before the forward pass.

    Returns:
        ``(features, names)``. ``features`` holds the concatenated model outputs,
        one row per image, or ``torch.empty(0)`` if the loader was empty.
        ``names`` is the flat list of image names in the same order.
    """
    names = []
    batches = []
    # no_grad avoids storing the autograd graph. .cpu() keeps already-processed
    # outputs out of device memory, so GPU usage does not grow with every batch.
    with torch.no_grad():
        for images, batch_names in loader:
            outputs = model(images.to(device, torch.float))
            batches.append(outputs.cpu())
            names.extend(batch_names)
            print(len(names))
    features = torch.cat(batches, dim=0) if batches else torch.empty(0)
    return features, names


def dedup_by_cosine(features, names, threshold=SIMILARITY_THRESHOLD):
    """Greedily drop near-duplicate images, keeping the first image of each group.

    Rows are visited in order. Each row not yet dropped is kept, and every later
    row whose cosine similarity to it exceeds ``threshold`` is dropped.

    Args:
        features: tensor with one row per image (trailing dims are flattened).
        names: image names aligned with the rows of ``features``.
        threshold: similarity strictly above which a later image is dropped.

    Returns:
        List of kept names, in input order.

    Raises:
        ValueError: if ``features`` and ``names`` differ in length.
    """
    count = len(names)
    if features.shape[0] != count:
        raise ValueError(f"got {features.shape[0]} feature rows for {count} names")
    if count == 0:
        return []
    feats = features.reshape(count, -1).double()
    # Unit-normalise rows. Zero rows stay zero (similarity 0) instead of becoming nan.
    feats = feats / feats.norm(dim=1, keepdim=True).clamp_min(1e-12)
    alive = torch.ones(count, dtype=torch.bool)
    kept = []
    for i in range(count):
        if not alive[i]:
            continue
        kept.append(names[i])
        sims = feats[i + 1:] @ feats[i]
        alive[i + 1:] &= ~(sims > threshold)
    return kept


def main(file_path=FILE_PATH, save_path=SAVE_PATH,
         threshold=SIMILARITY_THRESHOLD, batch_size=32):
    """Extract ShuffleNetV2 features for each image, drop near duplicates, and copy the rest.

    Files are read from ``file_path``. Surviving files are copied unchanged into
    ``save_path``, which is created if it does not exist.
    """
    # NOTE(review): no ImageNet mean/std normalisation is applied here. The
    # threshold was presumably tuned on these features, so adding it means re-tuning.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((512, 640)),
    ])
    dataset = MyDataset(file_path, transform=preprocess)
    # Which image of a duplicate group survives depends on iteration order, so
    # a fixed order (no shuffling) makes the output reproducible.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = torchvision.models.shufflenet_v2_x0_5(pretrained=True)
    # An empty Sequential passes its input through unchanged, so the model now
    # returns the features that were fed into its final fc layer.
    model.fc = nn.Sequential()
    model.to(device)
    model.eval()

    features, names = extract_features(model, loader, device)
    result = dedup_by_cosine(features, names, threshold)

    os.makedirs(save_path, exist_ok=True)
    for image_name in result:
        # Copy the bytes rather than decode and re-save, so JPEGs are not re-compressed.
        shutil.copy2(os.path.join(file_path, image_name),
                     os.path.join(save_path, image_name))
    print(len(result))


if __name__ == "__main__":
    main()
以上是关于“使用 CNN 提取特征进行图像相似度对比:PyTorch 推理时报内存不足的问题”的主要内容。如果未能解决你的问题,请参考以下文章:
数据挖掘视觉模式挖掘:Hog特征+余弦相似度/k-means聚类