博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
使用Faster RCNN训练自己的数据集
阅读量:4127 次
发布时间:2019-05-25

本文共 3331 字,大约阅读时间需要 11 分钟。

使用Faster RCNN训练自己的数据集,过程不太顺利,踩坑数次,所以把流程记录一下。

所使用的代码版本:

1.源码及环境配置

原Github版本使用的Pytorch==0.4.0,但是看了网上的博客记录这个版本有较多错误无法解决,建议使用Pytorch==1.0.0及以上版本;

  • 源码

Pytorch0.4.0版源码:

Pytorch1.0.0版源码:

  • 环境配置

Ubuntu16.04

Python==3.6 + Pytorch==1.2.0

由于CUDA版本向下兼容,所以这里不作特殊说明.

  • 使用Anaconda安装虚拟环境
conda create -n faster-rcnn python=3.6

在中找到对应的Pytorch与torchvision版本:

# CUDA 10.0conda install pytorch==1.2.0 torchvision==0.4.0

如图所示,此时cudatoolkit=10.0也会自动安装。

  • 安装其他环境依赖
pip install -r requirements.txt

2.预训练模型编译

  • 新建文件夹

(注:本文将原文件夹重命名为faster-rcnn)在文件夹中新建data文件夹

cd faster-rcnn && mkdir data

data文件夹中新建pretrained_model文件夹

mkdir pretrained_model
  • 下载预训练模型VGG16与ResNet-101

预训练模型VGG16:

预训练模型ResNet-101:

将下载好的预训练模型放到pretrained_model文件夹中

  • 执行编译
cd libpython setup.py build developcd ..

编译完成,如图所示

如果执行编译后,训练自己的数据集仍然报错:

ImportError: cannot import name '_mask'

则是缺少COCO API,需要执行以下指令

cd datagit clone https://github.com/pdollar/coco.git cd coco/PythonAPImakecd ../../..

如图所示

可以看到'_mask.o'已经编译成功

  • Scipy降版本

使用pip查看已经安装的Python库

pip list

可以看到其中Scipy与Pillow版本分别问scipy==1.5.4与Pillow==8.2.0,由于Scipy版本自身的变动原因,需要对Scipy进行降版本,否则在训练中会报错

ImportError: cannot import name 'imread'

首先卸载以上两个版本

pip uninstall scipypip uninstall pillow

然后安装指定版本即可

pip install scipy == 1.2.1pip install pillow == 6.1.0

3.数据集准备

本文训练的数据集是VOC格式

  • 新建文件夹
cd data && mkdir VOCdevkit2007

将之前标注得到的数据集放到VOCdevkit2007文件夹中,并重命名为VOC2007

VOC文件夹格式

---VOC2007    ------Annotations    ------ImagesSet          -------Main                 ----trainval.txt                 ----train.txt                 ----val.txt                 ----test.txt    ------JPEGImages

本文训练集/验证集/测试集比例为6:2:2

  • 修改数据集类别

首先进入到文件夹

cd lib/datasets/vim pascal_voc.py

将第48行'self._classes'改成自己的类别

# beforeself._classes = ('__background__',  # always index 0                         'aeroplane', 'bicycle', 'bird', 'boat',                         'bottle', 'bus', 'car', 'cat', 'chair',                         'cow', 'diningtable', 'dog', 'horse',                         'motorbike', 'person', 'pottedplant',                         'sheep', 'sofa', 'train', 'tvmonitor')# afterself._classes = ('__background__',  # always index 0                         'xxx', 'yyy', 'zzz')

将第243行'cls'中的.lower()去掉,这是由于有些标注数据集类别中存在大写,在训练过程中会报错(或者全部使用小写标注)

# beforecls = self._class_to_ind[obj.find('name').text.lower().strip()]# aftercls = self._class_to_ind[obj.find('name').text.strip()]

以上就是我遇到的问题,都解决之后就可以开始训练了

4.训练

训练指令

# train$ CUDA_VISIBLE_DEVICES=1 python trainval_net.py --dataset pascal_voc --net vgg16 --bs 16 --nw 4 --cuda --epochs 100

参数解释

CUDA_VISIBLE_DEVICES    # GPU ID,即使用哪块GPU进行训练-dataset    # 数据集类型,就以pascal-voc为例-net    # 所使用的backbone网络,以vgg16为例–bs    # 指的batch size,以16为例,显存不够就调小bs–nw    # 指的是worker number,取决于你的Gpu能力,以4为例,稍微差一些的gpu可以选小一点的值–cuda    # 指的是使用GPU训练-epochs    # 此处设为100,估计需要跑很久

训好的model会存到models文件夹中

等待训练完成ing

5.测试

训练完成后,首先进入文件夹

cd model/vgg16/pascal_voc# orcd model/res101/pascal_voclscd ../../../

查看训练完成之后保存的模型

在文件夹主目录中输入指令进行测试

python test_net.py --dataset pascal_voc --net vgg16 --checksession $SESSION --checkepoch $EPOCH --checkpoint $CHECKPOINT --cuda# 批量测试$  python test_net.py --dataset pascal_voc --net vgg16 --checksession 1 --checkepoch 50 --checkpoint 460 --cuda

其中,“checksession”、“checkepoch”与“checkpoint”这三个参数的含义为:如果选择训练得到的模型“faster_rcnn_1_50_460.pth”,则checksession=1,checkepoch=50,checkpoint=460;这样可以加载到想要的模型。以此类推。可以看到模型已经加载完成:

输出结果:

参考博客

转载地址:http://vgrpi.baihongyu.com/

你可能感兴趣的文章
ideas about sharing software
查看>>
different aspects for software
查看>>
To do list
查看>>
Study of Source code
查看>>
如何使用BBC英语学习频道
查看>>
spring事务探索
查看>>
浅谈Spring声明式事务管理ThreadLocal和JDKProxy
查看>>
初识xsd
查看>>
java 设计模式-职责型模式
查看>>
构造型模式
查看>>
svn out of date 无法更新到最新版本
查看>>
java杂记
查看>>
RunTime.getRuntime().exec()
查看>>
Oracle 分组排序函数
查看>>
删除weblogic 域
查看>>
VMware Workstation 14中文破解版下载(附密钥)(笔记)
查看>>
日志框架学习
查看>>
日志框架学习2
查看>>
SVN-无法查看log,提示Want to go offline,时间显示1970问题,error主要是 url中 有一层的中文进行了2次encode
查看>>
NGINX
查看>>