PyTorch 实战(模型训练、模型加载、模型测试)

摘要:
这一次,我们将使用Python的一个实际项目来记录这个过程:自定义数据集-˃数据加载-˃构建神经网络-˃迁移学习-˃保存模型-˃加载模型-˃测试模型自定义数据集。请参阅我的上一篇博客:用户定义的数据集处理数据加载默认伙伴对深度学习框架有一定的了解,因此我们在这里不做过多解释。Kaiming He的“图像识别的深度反应式学习”获得了CVPR的最佳论文。因此,何开明等人在普通普通网络的基础上增加了一条捷径,形成了一个剩余块。

    本次将一个使用Pytorch的一个实战项目,记录流程:自定义数据集->数据加载->搭建神经网络->迁移学习->保存模型->加载模型->测试模型

    自定义数据集
    参考我的上一篇博客:自定义数据集处理

    数据加载
    默认小伙伴有对深度学习框架有一定的了解,这里就不做过多的说明了。
    好吧,还是简单的说一下吧:
    我们在做好了自定义数据集之后,其实数据的加载和MNSIT 、CIFAR-10 、CIFAR-100等数据集的都是相似的,过程如下所示:
        导入必要的包

import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader

    1
    2
    3
    4

    加载数据
    可以发现和MNIST 、CIFAR的加载基本上是一样的

train_db = Pokemon('pokeman', 224, mode='train')
val_db = Pokemon('pokeman', 224, mode='val')
test_db = Pokemon('pokeman', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

    1
    2
    3
    4
    5
    6
    7

    搭建神经网络
    ResNet-18网络结构:
    在这里插入图片描述
    ResNet全名Residual Network残差网络。Kaiming He 的《Deep Residual Learning for Image Recognition》获得了CVPR最佳论文。他提出的深度残差网络在2015年可以说是洗刷了图像方面的各大比赛,以绝对优势取得了多个比赛的冠军。而且它在保证网络精度的前提下,将网络的深度达到了152层,后来又进一步加到1000的深度。论文的开篇先是说明了深度网络的好处:特征等级随着网络的加深而变高,网络的表达能力也会大大提高。因此论文中提出了一个问题:是否可以通过叠加网络层数来获得一个更好的网络呢?作者经过实验发现,单纯的把网络叠起来的深层网络的效果反而不如合适层数的较浅的网络效果。因此何恺明等人在普通平原网络的基础上增加了一个shortcut, 构成一个residual block。此时拟合目标就变为F(x),F(x)就是残差:
    在这里插入图片描述
        训练模型

def evalute(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total
def main():
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            model.train()
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        if epoch % 1 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                viz.line([val_acc], [global_step], win='val_acc', update='append')
    print('best acc:', best_acc, 'best epoch:', best_epoch)
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')
    test_acc = evalute(model, test_loader)



    迁移学习
    提升模型的准确率:


    # model = ResNet18(5).to(device)
    trained_model=resnet18(pretrained=True)  # 此时是一个非常好的model
    model = nn.Sequential(*list(trained_model.children())[:-1],  # 此时使用的是前17层的网络 0-17  *:随机打散
                          Flatten(),
                          nn.Linear(512,5)
                          ).to(device)
    # x=torch.randn(2,3,224,224)
    # print(model(x).shape)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11

    保存、加载模型
    pytorch保存模型的方式有两种:
    第一种:将整个网络都都保存下来
    第二种:仅保存和加载模型参数(推荐使用这样的方法)

# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')

    1
    2
    3

# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

    1
    2
    3

可以看到这是我保存的模型:
其中best.mdl是第二中方法保存的
model.pkl则是第一种方法保存的
在这里插入图片描述

    测试模型
    这里是训练时的情况
    在这里插入图片描述
    看这个数据准确率还是不错的,但是还是需要实际的测试这个模型,看它到底学到东西了没有,接下来简单的测试一下:


import torch
from PIL import Image
from torchvision import transforms
device = torch.device('cuda')
transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])
                            ])
def prediect(img_path):
    net=torch.load('model.pkl')
    net=net.to(device)
    torch.no_grad()
    img=Image.open(img_path)
    img=transform(img).unsqueeze(0)
    img_ = img.to(device)
    outputs = net(img_)
    _, predicted = torch.max(outputs, 1)
    # print(predicted)
    print('this picture maybe :',classes[predicted[0]])
if __name__ == '__main__':
    prediect('./test/name.jpg')



实际的测试结果:
在这里插入图片描述
在这里插入图片描述
效果还是可以的,完整的代码:
https://github.com/huzixuan1/Loader_DateSet
数据集下载链接:
https://pan.baidu.com/s/12-NQiF4fXEOKrXXdbdDPCg


免责声明:文章转载自《PyTorch 实战(模型训练、模型加载、模型测试)》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇(转)django使用django-celery与celeryJQuery点击行tr实现checkBox选中与未选中切换下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

用scikit-learn和pandas学习线性回归

对于想深入了解线性回归的童鞋,这里给出一个完整的例子,详细学完这个例子,对用scikit-learn来运行线性回归,评估模型不会有什么问题了。 1. 获取数据,定义问题 没有数据,当然没法研究机器学习啦。:) 这里我们用UCI大学公开的机器学习数据来跑线性回归。 数据的介绍在这:http://archive.ics.uci.edu/ml/datasets/...

智能客户端(SmartClient)

引文 http://dev.csdn.net/develop/article/16/16270.shtm  智能客户端(SmartClient)     本文主要讨论基于企业环境的客户端应用程序模型,由于本人曾经从事过传统的客户端/服务器两层结构应用程序和基于.net平台的多层结构应用程序的开发,因此本文将着重描述.net平台上的智能客户端应用程序模型,并...

零基础入门深度学习(5)

无论即将到来的是大数据时代还是人工智能时代,亦或是传统行业使用人工智能在云上处理大数据的时代,作为一个有理想有追求的程序员,不懂深度学习(Deep Learning)这个超热的技术,会不会感觉马上就out了?现在救命稻草来了,《零基础入门深度学习》系列文章旨在讲帮助爱编程的你从零基础达到入门级水平。零基础意味着你不需要太多的数学知识,只要会写程序就行了,...

【转】几款网络仿真软件的比较

转自: 网络仿真技术是一种通过建立网络设备和网络链路的统计模型, 并模拟网络流量的传输, 从而获取网络设计或优化所需要的网络性能数据的仿真技术。由于仿真不是基于数学计算, 而是基于统计模型,因此,统计复用的随机性被精确地再现。网络仿真技术具有以下特点:一, 全新的模拟实验机理使其具有在高度复杂的网络环境下得到高可信度结果的特点。二, 网络仿真的预测功能是其...

4G EPS 的架构模型

目录 文章目录 目录 前文列表 EPS 的架构 EPS 的架构模型 E-UTRAN UE eNodeB EPC MME(移动性管理) SGW(本地移动性锚点) PGW(业务锚点) HSS(用户认证及鉴权中心) PCRF(计费规则与策略) EPS 运行原理 上行传输 下行传输 前文列表 《4G EPS 第四代移动通信系统》 EPS 的架构...

ios开发网络学习二:URL转码以及字典转模型框架MJExtension的使用

一:url转码,当url中涉及到中文的时候,要考虑转码,用UTF8对中文的url进行转码 #import "ViewController.h" @interfaceViewController () @end @implementationViewController #pragma mark ---------------------- #pra...