vgg的grad作为激活值来展示图片物体
Posted waldenlake
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了vgg的grad作为激活值来展示图片物体相关的知识,希望对你有一定的参考价值。
import torch import numpy import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision import transforms import torchvision.models as models normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],#这是imagenet std=[0.229, 0.224, 0.225]) tran=transforms.Compose([ transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) im=‘./1.jpeg‘ # im=‘./2.jpg‘ im=Image.open(im) im=tran(im) im.unsqueeze_(dim=0) im=torch.autograd.Variable(im,requires_grad=True) vgg = models.vgg16() pre=torch.load(‘/home/qk/.torch/models/vgg16-397923af.pth‘) vgg.load_state_dict(pre) out=vgg(im) outnp=out.data[0] ind=int(numpy.argmax(outnp)) out[0][ind].backward() grad=im.grad grad.squeeze_(0) grad=grad*grad grad=grad.sum(keepdim=False,dim=0) grad=torch.sqrt(grad) rg=torch.max(grad) grad=grad/rg*255. #太神奇了,为什么有uint8就能出来激活图,没有就不行!!!!!这要是不搜索一下的话,怎么可能debug出来呢? # im = Image.fromarray(grad.numpy(),‘L‘) # .eval() tensor->numpy array im = Image.fromarray(numpy.uint8(grad.numpy()),‘L‘) # .eval() tensor->numpy array im.save(‘grey.png‘) # input() from cls import d print(d[ind])
以上是关于vgg的grad作为激活值来展示图片物体的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch RuntimeError:张量的元素 0 不需要 grad 并且没有 grad_fn