【图像分类】 基于Pytorch的多类别图像分类实战

欢迎大家来到图像分类专栏,本篇基于Pytorch完成一个多类别图像分类实战。

作者 | 郭冰洋

编辑 | 言有三

1 简介

640?wx_fmt=png

实现一个完整的图像分类任务,大致需要分为五个步骤:


1、选择开源框架

目前常用的深度学习框架主要包括tensorflow、caffe、pytorch、mxnet等;

2、构建并读取数据集

根据任务需求搜集相关图像搭建相应的数据集,常见的方式包括:网络爬虫、实地拍摄、公共数据使用等。随后根据所选开源框架读取数据集。

3、框架搭建

选择合适的网络模型、损失函数以及优化方式,以完成整体框架的搭建

4、训练并调试参数

通过训练选定合适超参数

5、测试准确率

在测试集上验证模型的最终性能

本文利用Pytorch框架,按照上述结构实现一个基本的图像分类任务,并详细阐述其中的细节及注意事项。

2 数据集

640?wx_fmt=png

本次实战选择的数据集为Kaggle竞赛中的细胞数据集,共包含9961个训练样本,2491个测试样本,可以分为嗜曙红细胞、淋巴细胞、单核细胞、中性白细胞4个类别,图片大小为320x240。


Pytorch中封装了相应的数据读取的类函数,通过调用torch.utils.data.Datasets函数,则可以实现读取功能。

640?wx_fmt=png

__init__()模块用来定义相关的参数,__len__()模块用来获取训练样本个数,__getitem__()模块则用来获取每张具体的图片,在读取图片时其可以通过opencv库、PIL库等进行读取,具体代码如下:

# 数据集

class dataset(data.Dataset):

   # 参数预定义

此外,需要定义图像增强模块,即上述代码中的transform,通常采取的操作为翻转、剪切等,关于图像增强的具体介绍可以参考公众号前作。

【技术综述】深度学习中的数据增强方法都有哪些?

需要特别强调的是对图像进行去均值处理,很多同学不明白为何要减去均值,其主要的原因是图像作为一种平稳的数据分布,通过减去数据对应维度的统计平均值,可以消除公共部分,以凸显个体之间的特征和差异。进行去均值前后操作后的图像对比如下:

640?wx_fmt=png

3 框架搭建

本次实战主要选取了VGG16、Resnet50、InceptionV4三个经典网络,也是对前篇文章的一个总结。

损失函数则选择交叉熵损失函数:【技术综述】一文道尽softmax loss及其变种

优化方式选择SGD、Adam优化两种:【模型训练】SGD的那些变种,真的比SGD强吗

完整代码获取方式:发送关键词“多类别分类”给公众号

4 训练及参数调试

初始学习率设置为0.01,batch size设置为8,衰减率设置为0.00001,迭代周期为15,在不同框架组合下的最佳准确率和最低loss如下图所示:

640?wx_fmt=png

640?wx_fmt=png

可以发现在验证集上Resnet-50+SGD+Cross Entropy的组合下取得了99%左右的准确率,相反VGG-16结果则稍微差一些。

最佳组合下的准确率走势曲线如下图所示:

640?wx_fmt=jpeg

5 测试

对上述模型分别在测试集上进行测试,所获得的结果如下图所示,整体精度比训练集上约下降了一个百分点:

640?wx_fmt=png

关于代码,可以参考有三AI开源的12大深度学习开源框架使用的项目:

【完结】给新手的12大深度学习开源框架快速入门项目

640?wx_fmt=jpeg

总结

以上就是整个多类别图像分类实战的过程,由于时间限制,本次实战并没有对多个数据集进行训练,因此没有列出同一模型在不同数据集上的表现。

有三AI夏季划

640?wx_fmt=png

有三AI夏季划进行中,欢迎了解并加入,系统性成长为中级CV算法工程师。

转载文章请后台联系

侵权必究

640?wx_fmt=png

640?wx_fmt=png

640?wx_fmt=png

往期精选

<p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <span style="color:#E53333;"><strong>课程介绍</strong></span>  </p> <p style="text-align:left;">      Pytorch项目实战 垃圾分类 课程从实战角度出发,基于真实数据集与实际业务需求,结合当下最新话题-垃圾分类问题为实际业务出发点,介绍最前沿深度学习解决方案。 </p> <p style="text-align:left;">     从0到1讲解如何场景业务分析、进行数据处理,模型训练与调优,最后进行测试与结果展示分析。全程实战操作,以最接地气方式详解每一步流程与解决方案。 </p> <p style="text-align:left;">     课程结合当下深度学习热门领域,尤其是基于facebook 开源分类神器ResNext101网络架构,对网络架构进行调整,以计算机视觉为核心讲解各大网络应用于实战方法,适合快速入门与进阶提升。 </p> <p style="text-align:left;"> <strong><span style="color:#E53333;">课程要求</span></strong> </p> <p style="text-align:left;"> (1)开发环境:python版本:Python3.7+;<span style="color:#E53333;"> torch 版本:1.2.0+; torchvision版本:0.4.0+</span> </p> <p style="text-align:left;"> (2)开发工具:Pycharm; </p> <p style="text-align:left;"> (3)学员基础:需要一定Python基础,及深度学习基础; </p> <p style="text-align:left;"> (4)学员收货:掌握最新科技图像分类关键技术; </p> <p style="text-align:left;"> (5)学员资料:内含完整程序源码和数据集; </p> <p style="text-align:left;"> (6)课程亮点:专题技术,完整案例,全程实战操作,徒手撸代码 </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <span style="color:#E53333;"><strong>课程特色</strong></span> </p> 阵容强大 <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> 讲师一直从事与一线项目开发,高级算法专家,一直从事于图像、NLP、个性化推荐系统热门技术领域。 </p> <p style="text-align:left;"> 仅跟前沿 </p> <p style="text-align:left;"> 基于当前热门讨论话题:垃圾分类,课程采用学术届和工业届最新前沿技术知识要点。 </p> <p style="text-align:left;"> 实战为先 </p> <p style="text-align:left;"> 根据实际深度学习工业场景-垃圾分类,从产品需求、产品设计和方案设计、产品技术功能实现、模型上线部署。精心设计工业实战项目 </p> <p style="text-align:left;"> 保障效果 </p> <p style="text-align:left;"> 项目实战方向包含了学术届和工业届最前沿技术要点 </p> <p style="text-align:left;"> 项目包装简历优化 </p> <p style="text-align:left;"> 课程内垃圾分类图像实战项目完成后可以直接优化到简历中 </p> <p style="text-align:left;"> <strong><span style="color:#E53333;">课程思维导图</span></strong> </p> <p style="text-align:left;"> <img src="https://img-bss.csdn.net/201912081323318969.png" alt="" /> </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <strong><span style="color:#E53333;">课程实战案例</span></strong> </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <img src="https://img-bss.csdn.net/201912081326184463.png" alt="" /> </p> <p style="text-align:left;"> <br /> </p> <p style="text-align:left;"> <br /> </p>
相关推荐
<p> <span><b><span style="background-color:#FFE500;">超实用课程内容</span></b><br /> </span> </p> <p> <span>本课程从</span>pytorch安装开始讲起,从基本计算结构到深度学习各大神经网络,全程案例代码实战,一步步带大家入门如何使用深度学习框架pytorch,玩转pytorch模型训练等所有知识点。最后通过 kaggle 项目:猫狗分类,实战pytorch深度学习工具。 </p> <p> <br /> </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> <b><span style="background-color:#FFE500;">课程如何观看?</span></b> </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> PC端:<a href="https://edu.csdn.net/course/detail/26277"></a><a href="https://edu.csdn.net/course/detail/26150"></a><a href="https://edu.csdn.net/course/detail/26150"></a><a href="https://edu.csdn.net/course/detail/27286">https://edu.csdn.net/course/detail/27286</a> </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> 移动端:CSDN 学院APP(注意不是CSDN APP哦) </p> <p style="font-family:Helvetica;color:#3A4151;font-size:14px;background-color:#FFFFFF;"> 本课程为录播课,课程永久观看,大家可以抓紧时间学习后一起讨论哦~ </p> <p class="ql-long-24357476" style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> <strong><span style="background-color:#FFE500;">学员专享增值服务</span></strong> </p> <p class="ql-long-24357476" style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> <b>源码开放</b> </p> <p class="ql-long-24357476" style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> 课件、课程案例代码完全开放给你,你可以根据所学知识,自行修改、优化 </p> <p class="ql-long-24357476" style="font-family:"color:#3A4151;font-size:14px;background-color:#FFFFFF;"> 下载方式:电脑登录<a href="https://edu.csdn.net/course/detail/26277"></a><a href="https://edu.csdn.net/course/detail/26150"></a><a href="https://edu.csdn.net/course/detail/27286">https://edu.csdn.net/course/detail/27286</a>,点击右下方课程资料、代码、课件等打包下载 </p> <p> <br /> </p>
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页