代码训练：图像分类通用测试代码
Posted Lf&x&my
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了代码训练：图像分类通用测试代码相关的知识，希望对你有一定的参考价值。
图像分类通用测试代码（PyTorch 训练脚本：在 flower_data 数据集上训练 AlexNet，并按验证集准确率保存最优权重）
# NOTE(review): this script relies on names not defined in this snippet:
# `os`, `json`, `time`, `torch`, `nn` (torch.nn), `optim` (torch.optim),
# torchvision's `transforms` and `datasets`, and the project's `AlexNet`
# model class. They are presumably imported elsewhere — confirm.


def _build_transforms():
    """Return the train/val preprocessing pipelines keyed by split name."""
    # Both splits normalize each RGB channel with mean 0.5 / std 0.5.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return {
        # Training uses random crop + horizontal flip as augmentation.
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        # Validation is deterministic: a fixed resize to 224x224.
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }


def _write_class_indices(class_to_idx, path="class_indices.json"):
    """Invert *class_to_idx* (name -> index) and write it to *path* as indented JSON."""
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    with open(path, "w") as json_file:
        json_file.write(json.dumps(idx_to_class, indent=4))


def _train_one_epoch(net, loader, loss_function, optimizer, device):
    """Run one optimization pass over *loader* and return the mean per-batch loss.

    Prints an in-place progress bar; returns 0.0 if *loader* yields no batches.
    """
    net.train()
    running_loss = 0.0
    num_batches = len(loader)
    for step, (images, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        # Progress bar: '*' for completed fraction, '.' for the remainder.
        rate = (step + 1) / num_batches
        done = "*" * int(rate * 50)
        todo = "." * int((1 - rate) * 50)
        # '\r' rewrites the same console line each batch.
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), done, todo, batch_loss), end="")
    print()
    # Divide by the number of batches (not the last 0-based step index).
    return running_loss / max(num_batches, 1)


def _evaluate(net, loader, device):
    """Return the count of samples in *loader* whose arg-max prediction matches the label."""
    net.eval()
    correct = 0
    with torch.no_grad():
        for val_images, val_labels in loader:
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            correct += (predict_y == val_labels.to(device)).sum().item()
    return correct


def main(*, epochs=10, batch_size=32, lr=0.0002, num_classes=5):
    """Train AlexNet on an ImageFolder dataset and keep the best weights.

    The dataset is read from ``<cwd>/../../data_set/flower_data/{train,val}``.
    Side effects: writes ``class_indices.json`` (index -> class name) and saves
    the state_dict with the highest validation accuracy to ``./AlexNet.pth``.

    Args:
        epochs: number of training epochs.
        batch_size: training batch size (also caps the DataLoader worker count).
        lr: Adam learning rate.
        num_classes: number of output classes passed to ``AlexNet``.

    Raises:
        AssertionError: if the dataset directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = _build_transforms()
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    _write_class_indices(train_dataset.class_to_idx)

    # os.cpu_count() may return None, which min() cannot compare.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    # Evaluation order does not affect accuracy, so no shuffling is needed.
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4,
                                                  shuffle=False, num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    net = AlexNet(num_classes=num_classes, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    save_path = './AlexNet.pth'
    best_acc = 0.0
    for epoch in range(epochs):
        t1 = time.perf_counter()
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer, device)
        print(time.perf_counter() - t1)  # seconds spent on this training epoch

        val_accurate = _evaluate(net, validate_loader, device) / val_num
        # Checkpoint only when validation accuracy strictly improves.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' % (epoch + 1, train_loss, val_accurate))

    print('Finished Training')
以上是关于代码训练：图像分类通用测试代码的主要内容。如果未能解决你的问题，请参考以下文章。