代码训练,图像分类通用测试代码

Posted Lf&x&my

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了代码训练、图像分类通用测试代码的相关知识，希望对你有一定的参考价值。

图像分类通用测试代码

def _run_train_epoch(net, loader, loss_function, optimizer, device):
    """Run one training pass over *loader* and return the mean per-batch loss.

    Prints an in-place console progress bar while training, then a newline.
    """
    net.train()
    running_loss = 0.0
    num_batches = len(loader)
    for step, (images, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # "\r" returns the cursor to the start of the line so the bar redraws in place.
        rate = (step + 1) / num_batches
        done = "*" * int(rate * 50)
        todo = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), done, todo, loss.item()), end="")
    print()
    # Divide by the number of batches, not the last step index (which is one
    # less, and 0 for a single-batch loader).
    return running_loss / max(num_batches, 1)


def _evaluate(net, loader, device):
    """Return how many samples in *loader* the network classifies correctly."""
    net.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = net(images.to(device))
            # Index of the highest score per row is the predicted class index.
            predicted = torch.max(outputs, dim=1)[1]
            correct += (predicted == labels.to(device)).sum().item()
    return correct


def main():
    """Train AlexNet on the flower dataset and keep the best checkpoint.

    Reads ImageFolder-style ``train`` and ``val`` directories under
    ``<cwd>/../../data_set/flower_data``, writes the index -> class-name
    mapping to ``class_indices.json``, trains for 10 epochs with Adam
    (lr=2e-4), and saves the weights with the best validation accuracy to
    ``./AlexNet.pth``.

    Relies on module-level names not defined in this block: ``torch``,
    ``nn``, ``optim``, ``transforms``, ``datasets``, ``AlexNet``, ``os``,
    ``json`` and ``time``.

    Raises:
        FileNotFoundError: if the dataset directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Random crop/flip augmentation for training; deterministic resize for validation.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # Persist index -> class name (presumably read back by a separate
    # prediction script — verify against the caller).
    cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))

    batch_size = 32
    # os.cpu_count() may return None; fall back to 1 so min() does not fail.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    # Sample order does not affect accuracy, so validation is not shuffled.
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4,
                                                  shuffle=False, num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Size the classifier head from the discovered classes instead of hard-coding it.
    net = AlexNet(num_classes=len(cla_dict), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = './AlexNet.pth'
    best_acc = 0.0
    for epoch in range(10):
        t1 = time.perf_counter()
        train_loss = _run_train_epoch(net, train_loader, loss_function, optimizer, device)
        print(time.perf_counter() - t1)  # wall-clock seconds for this epoch's training pass

        val_accurate = _evaluate(net, validate_loader, device) / val_num
        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' % (epoch + 1, train_loss, val_accurate))
    print('Finished Training')

以上是关于代码训练、图像分类通用测试代码的主要内容。如果未能解决你的问题，请参考以下文章：

MixNet实战:使用MixNet实现图像分类

用Pytorch训练分类模型

MaxViT实战:使用MaxViT实现图像分类任务

MicroNet实战:使用MicroNet实现图像分类

MicroNet实战:使用MicroNet实现图像分类

RepVgg实战:使用RepVgg实现图像分类