From b3a14f46815aa9eb8d637f2a7a9c4d7cdda7f3df Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Sat, 14 Dec 2019 13:36:26 +0800 Subject: [PATCH] fix pointrcnn (#4073) --- PaddleCV/Paddle3D/PointRCNN/README.md | 3 ++- PaddleCV/Paddle3D/PointRCNN/data/kitti_rcnn_reader.py | 9 +++++++++ PaddleCV/Paddle3D/PointRCNN/train.py | 10 +++++++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/PaddleCV/Paddle3D/PointRCNN/README.md b/PaddleCV/Paddle3D/PointRCNN/README.md index 4c85db54..842cb86d 100644 --- a/PaddleCV/Paddle3D/PointRCNN/README.md +++ b/PaddleCV/Paddle3D/PointRCNN/README.md @@ -232,7 +232,8 @@ python train.py --cfg=./cfgs/default.yml \ --epoch=30 \ --save_dir=checkpoints \ --rcnn_training_roi_dir=output/detections/data \ - --rcnn_training_feature_dir=output/features + --rcnn_training_feature_dir=output/features \ + --set TRAIN.SPLIT train_aug ``` RCNN模型训练权重默认保存在`checkpoints/rcnn`目录下,可通过`--save_dir`参数指定。 diff --git a/PaddleCV/Paddle3D/PointRCNN/data/kitti_rcnn_reader.py b/PaddleCV/Paddle3D/PointRCNN/data/kitti_rcnn_reader.py index 811a20b2..57367d2c 100644 --- a/PaddleCV/Paddle3D/PointRCNN/data/kitti_rcnn_reader.py +++ b/PaddleCV/Paddle3D/PointRCNN/data/kitti_rcnn_reader.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import os +import signal import logging import multiprocessing import numpy as np @@ -1182,3 +1183,11 @@ class KittiRCNNReader(KittiDataset): return reader + +def _term_reader(signum, frame): + logger.info('pid {} terminated, terminate reader process ' + 'group {}...'.format(os.getpid(), os.getpgrp())) + os.killpg(os.getpgid(os.getpid()), signal.SIGKILL) + +signal.signal(signal.SIGINT, _term_reader) + diff --git a/PaddleCV/Paddle3D/PointRCNN/train.py b/PaddleCV/Paddle3D/PointRCNN/train.py index b7a39ca4..41a6f098 100644 --- a/PaddleCV/Paddle3D/PointRCNN/train.py +++ b/PaddleCV/Paddle3D/PointRCNN/train.py @@ -98,6 +98,11 @@ def parse_args(): type=str, default=None, help='specify the saved features for rcnn training when using rcnn_offline mode') + parser.add_argument( + '--worker_num', + type=int, + default=16, + help='multiprocess reader process num, default 16') parser.add_argument( '--log_interval', type=int, @@ -206,7 +211,10 @@ def train(): fluid.io.save_persistables(exe, path, prog) # get reader - train_reader = kitti_rcnn_reader.get_multiprocess_reader(args.batch_size, train_feeds, drop_last=True) + train_reader = kitti_rcnn_reader.get_multiprocess_reader(args.batch_size, + train_feeds, + proc_num=args.worker_num, + drop_last=True) train_pyreader.decorate_sample_list_generator(train_reader, place) train_stat = Stat() -- GitLab