七月在线 《关键点检测概览与环境配置》
Posted 刘润森!
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了七月在线 《关键点检测概览与环境配置》相关的知识,希望对你有一定的参考价值。
七月在线 课程:https://www.julyedu.com/course/getDetail/262
什么是关键点?
关键点定义:关键点也称为兴趣点,它是2D图像、3D点云或曲面模型上,可以通过定义检测标准来获取的具有稳定性、区别性的点集。关键点检测涉及同时检测人和定位他们的关键点。关键点与兴趣点相同。它们是空间位置或图像中的点,它们定义了图像中有趣或突出的内容。它们对图像旋转、收缩、平移、失真等是不变的。
关键点的意义?
加快后续识别、追踪等数据的处理速度。
环境配置
nvidia GPU 配置:
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html
code : MNIST
MNIST实战!
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import torchvision
import numpy as np
from torch.autograd import Variable
import random
%matplotlib inline
transform = transforms.Compose([
transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
data_train = datasets.MNIST(root = "./data/",
transform=transform,
train = True,
download = True)
data_test = datasets.MNIST(root="./data/",
transform = transform,
train = False)
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
batch_size = 64,
shuffle = True,
num_workers=2)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
batch_size = 64,
shuffle = True,
num_workers=2)
images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1,2,0)
std = [0.5,0.5,0.5]
mean = [0.5,0.5,0.5]
img = img*std+mean
print([labels[i] for i in range(64)])
plt.imshow(img)
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(stride=2,kernel_size=2))
self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128,1024),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(1024, 10))
def forward(self, x):
x = self.conv1(x)
x = x.view(-1, 14*14*128)
x = self.dense(x)
return x
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
n_epochs = 5
for epoch in range(n_epochs):
running_loss = 0.0
running_correct = 0
print("Epoch /".format(epoch, n_epochs))
print("-"*10)
for data in data_loader_train:
X_train, y_train = data
X_train, y_train = Variable(X_train), Variable(y_train)
outputs = model(X_train)
_,pred = torch.max(outputs.data, 1)
optimizer.zero_grad()
loss = cost(outputs, y_train)
loss.backward()
optimizer.step() #进行单次优化
running_loss += loss.data
running_correct += torch.sum(pred == y_train.data)
testing_correct = 0
for data in data_loader_test:
X_test, y_test = data
X_test, y_test = Variable(X_test), Variable(y_test)
outputs = model(X_test)
_, pred = torch.max(outputs.data, 1)
testing_correct += torch.sum(pred == y_test.data)
print("Loss is::.4f, Train Accuracy is::.4f%, Test Accuracy is::.4f".format(running_loss/len(data_train),
100*running_correct/len(data_train),
100*testing_correct/len(data_test)))
torch.save(model.state_dict(), "model_parameter.pkl")
reference resources
- https://paperswithcode.com/sota/keypoint-detection-on-coco-test-dev
以上是关于七月在线 《关键点检测概览与环境配置》的主要内容,如果未能解决你的问题,请参考以下文章