未验证 提交 b3a14f46 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix pointrcnn (#4073)

上级 15388cb6
...@@ -232,7 +232,8 @@ python train.py --cfg=./cfgs/default.yml \ ...@@ -232,7 +232,8 @@ python train.py --cfg=./cfgs/default.yml \
--epoch=30 \ --epoch=30 \
--save_dir=checkpoints \ --save_dir=checkpoints \
--rcnn_training_roi_dir=output/detections/data \ --rcnn_training_roi_dir=output/detections/data \
--rcnn_training_feature_dir=output/features --rcnn_training_feature_dir=output/features \
--set TRAIN.SPLIT train_aug
``` ```
RCNN模型训练权重默认保存在`checkpoints/rcnn`目录下,可通过`--save_dir`参数指定。 RCNN模型训练权重默认保存在`checkpoints/rcnn`目录下,可通过`--save_dir`参数指定。
......
...@@ -20,6 +20,7 @@ from __future__ import division ...@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import signal
import logging import logging
import multiprocessing import multiprocessing
import numpy as np import numpy as np
...@@ -1182,3 +1183,11 @@ class KittiRCNNReader(KittiDataset): ...@@ -1182,3 +1183,11 @@ class KittiRCNNReader(KittiDataset):
return reader return reader
def _term_reader(signum, frame):
logger.info('pid {} terminated, terminate reader process '
'group {}...'.format(os.getpid(), os.getpgrp()))
os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)
signal.signal(signal.SIGINT, _term_reader)
...@@ -98,6 +98,11 @@ def parse_args(): ...@@ -98,6 +98,11 @@ def parse_args():
type=str, type=str,
default=None, default=None,
help='specify the saved features for rcnn training when using rcnn_offline mode') help='specify the saved features for rcnn training when using rcnn_offline mode')
parser.add_argument(
'--worker_num',
type=int,
default=16,
help='multiprocess reader process num, default 16')
parser.add_argument( parser.add_argument(
'--log_interval', '--log_interval',
type=int, type=int,
...@@ -206,7 +211,10 @@ def train(): ...@@ -206,7 +211,10 @@ def train():
fluid.io.save_persistables(exe, path, prog) fluid.io.save_persistables(exe, path, prog)
# get reader # get reader
train_reader = kitti_rcnn_reader.get_multiprocess_reader(args.batch_size, train_feeds, drop_last=True) train_reader = kitti_rcnn_reader.get_multiprocess_reader(args.batch_size,
train_feeds,
proc_num=args.worker_num,
drop_last=True)
train_pyreader.decorate_sample_list_generator(train_reader, place) train_pyreader.decorate_sample_list_generator(train_reader, place)
train_stat = Stat() train_stat = Stat()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册