From 686616ff1f777e04adf5b212cea4f4e01a422aac Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 18 Mar 2019 15:09:33 +0000 Subject: [PATCH] use Executor and CompiledProgram. --- fluid/PaddleCV/yolov3/README.md | 7 ++- fluid/PaddleCV/yolov3/README_cn.md | 7 ++- fluid/PaddleCV/yolov3/config.py | 10 +--- fluid/PaddleCV/yolov3/models/yolov3.py | 13 +---- fluid/PaddleCV/yolov3/train.py | 69 +++++--------------------- fluid/PaddleCV/yolov3/utility.py | 9 ++-- 6 files changed, 25 insertions(+), 90 deletions(-) diff --git a/fluid/PaddleCV/yolov3/README.md b/fluid/PaddleCV/yolov3/README.md index d52adf34..be2d8ea1 100644 --- a/fluid/PaddleCV/yolov3/README.md +++ b/fluid/PaddleCV/yolov3/README.md @@ -32,7 +32,7 @@ YOLOv3 structure YOLOv3 networks are composed of base feature extraction network, multi-scale feature fusion layers, and output layers. -1. Feature extraction network: YOLOv3 uses [DarkNet53](https://arxiv.org/abs/1612.08242) for feature extracion,Darknet53 uses a full convolution structure, replacing the pooling layer with a convolution operation of step size 2, and adding Residual-block to avoid gradient dispersion when the number of network layers is too deep. +1. Feature extraction network: YOLOv3 uses [DarkNet53](https://arxiv.org/abs/1612.08242) for feature extracion. Darknet53 uses a full convolution structure, replacing the pooling layer with a convolution operation of step size 2, and adding Residual-block to avoid gradient dispersion when the number of network layers is too deep. 2. Feature fusion layer. In order to solve the problem that the previous YOLO version is not sensitive to small objects, YOLOv3 uses three different scale feature maps for target detection, which are 13\*13, 26\*26, 52\*52, respectively, for detecting large, medium and small objects. The feature fusion layer selects the three scale feature maps produced by DarkNet as input, and draws on the idea of FPN (feature pyramid networks) to fuse the feature maps of each scale through a series of convolutional layers and upsampling. @@ -71,9 +71,8 @@ Please make sure that pre-trained model is downloaded and loaded correctly, othe To train the model, [cocoapi](https://github.com/cocodataset/cocoapi) is needed. Install the cocoapi: - # COCOAPI=/path/to/clone/cocoapi - git clone https://github.com/cocodataset/cocoapi.git $COCOAPI - cd $COCOAPI/PythonAPI + git clone https://github.com/cocodataset/cocoapi.git + cd PythonAPI # if cython is not installed pip install Cython # Install into global site-packages diff --git a/fluid/PaddleCV/yolov3/README_cn.md b/fluid/PaddleCV/yolov3/README_cn.md index 60a4a3cc..5be4382f 100644 --- a/fluid/PaddleCV/yolov3/README_cn.md +++ b/fluid/PaddleCV/yolov3/README_cn.md @@ -33,7 +33,7 @@ YOLOv3网络结构 YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层和输出层组成。 -1. 特征提取网络。YOLOv3使用 [DarkNet53](https://arxiv.org/abs/1612.08242)作为特征提取网络,DarkNet53 基本采用了全卷积网络,用步长为2的卷积操作替代了池化层,同时添加了 Residual 单元,避免在网络层数过深时发生梯度弥散。 +1. 特征提取网络。YOLOv3使用 [DarkNet53](https://arxiv.org/abs/1612.08242)作为特征提取网络:DarkNet53 基本采用了全卷积网络,用步长为2的卷积操作替代了池化层,同时添加了 Residual 单元,避免在网络层数过深时发生梯度弥散。 2. 特征融合层。为了解决之前YOLO版本对小目标不敏感的问题,YOLOv3采用了3个不同尺度的特征图来进行目标检测,分别为13\*13,26\*26,52\*52,用来检测大、中、小三种目标。特征融合层选取 DarkNet 产出的三种尺度特征图作为输入,借鉴了FPN(feature pyramid networks)的思想,通过一系列的卷积层和上采样对各尺度的特征图进行融合。 @@ -73,9 +73,8 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层 训练前需要首先下载[cocoapi](https://github.com/cocodataset/cocoapi): - # COCOAPI=/path/to/clone/cocoapi - git clone https://github.com/cocodataset/cocoapi.git $COCOAPI - cd $COCOAPI/PythonAPI + git clone https://github.com/cocodataset/cocoapi.git + cd PythonAPI # if cython is not installed pip install Cython # Install into global site-packages diff --git a/fluid/PaddleCV/yolov3/config.py b/fluid/PaddleCV/yolov3/config.py index 1f082c90..14e8ef42 100644 --- a/fluid/PaddleCV/yolov3/config.py +++ b/fluid/PaddleCV/yolov3/config.py @@ -39,7 +39,7 @@ _C.max_box_num = 50 # # valid score threshold to include boxes -_C.valid_thresh = 0.01 +_C.valid_thresh = 0.005 # threshold vale for box non-max suppression _C.nms_thresh = 0.45 @@ -69,7 +69,7 @@ _C.pixel_stds = [0.229, 0.224, 0.225] # # batch size -_C.batch_size = 64 +_C.batch_size = 8 # derived learning rate the to get the final learning rate. _C.learning_rate = 0.001 @@ -98,15 +98,9 @@ _C.momentum = 0.9 # support both CPU and GPU _C.use_gpu = True -# Whether use parallel -_C.parallel = True - # Class number _C.class_num = 80 -# support pyreader -_C.use_pyreader = True - # dataset path _C.train_file_list = 'annotations/instances_train2017.json' _C.train_data_dir = 'train2017' diff --git a/fluid/PaddleCV/yolov3/models/yolov3.py b/fluid/PaddleCV/yolov3/models/yolov3.py index 99ce2769..54848124 100644 --- a/fluid/PaddleCV/yolov3/models/yolov3.py +++ b/fluid/PaddleCV/yolov3/models/yolov3.py @@ -56,10 +56,8 @@ def upsample(input, scale=2,name=None): class YOLOv3(object): def __init__(self, is_train=True, - use_pyreader=True, use_random=True): self.is_train = is_train - self.use_pyreader = use_pyreader self.use_random = use_random self.outputs = [] self.losses = [] @@ -69,7 +67,7 @@ class YOLOv3(object): def build_input(self): self.image_shape = [3, cfg.input_size, cfg.input_size] - if self.use_pyreader and self.is_train: + if self.is_train: self.py_reader = fluid.layers.py_reader( capacity=64, shapes = [[-1] + self.image_shape, [-1, cfg.max_box_num, 4], [-1, cfg.max_box_num], [-1, cfg.max_box_num]], @@ -81,15 +79,6 @@ class YOLOv3(object): self.image = fluid.layers.data( name='image', shape=self.image_shape, dtype='float32' ) - self.gtbox = fluid.layers.data( - name='gtbox', shape=[cfg.max_box_num, 4], dtype='float32' - ) - self.gtlabel = fluid.layers.data( - name='gtlabel', shape=[cfg.max_box_num], dtype='int32' - ) - self.gtscore = fluid.layers.data( - name='gtscore', shape=[cfg.max_box_num], dtype='float32' - ) self.im_shape = fluid.layers.data( name="im_shape", shape=[2], dtype='int32') self.im_id = fluid.layers.data( diff --git a/fluid/PaddleCV/yolov3/train.py b/fluid/PaddleCV/yolov3/train.py index 2ebc536a..b4bbcfe1 100644 --- a/fluid/PaddleCV/yolov3/train.py +++ b/fluid/PaddleCV/yolov3/train.py @@ -42,7 +42,7 @@ def train(): if not os.path.exists(cfg.model_save_dir): os.makedirs(cfg.model_save_dir) - model = YOLOv3(use_pyreader=cfg.use_pyreader) + model = YOLOv3() model.build_model() input_size = cfg.input_size loss = model.loss() @@ -69,44 +69,37 @@ def train(): momentum=cfg.momentum) optimizer.minimize(loss) - fluid.memory_optimize(fluid.default_main_program()) - place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() - base_exe = fluid.Executor(place) - base_exe.run(fluid.default_startup_program()) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) if cfg.pretrain: def if_exist(var): return os.path.exists(os.path.join(cfg.pretrain, var.name)) - fluid.io.load_vars(base_exe, cfg.pretrain, predicate=if_exist) + fluid.io.load_vars(exe, cfg.pretrain, predicate=if_exist) - if cfg.parallel: - exe = fluid.ParallelExecutor( use_cuda=bool(cfg.use_gpu), loss_name=loss.name) - else: - exe = base_exe + compile_program = fluid.compiler.CompiledProgram( + fluid.default_main_program()).with_data_parallel( + loss_name=loss.name) random_sizes = [cfg.input_size] if cfg.random_shape: random_sizes = [32 * i for i in range(10, 20)] mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter - if cfg.use_pyreader: - train_reader = reader.train(input_size, batch_size=cfg.batch_size/devices_num, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) - py_reader = model.py_reader - py_reader.decorate_paddle_reader(train_reader) - else: - train_reader = reader.train(input_size, batch_size=cfg.batch_size, shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) - feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) + train_reader = reader.train(input_size, batch_size=cfg.batch_size, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) + py_reader = model.py_reader + py_reader.decorate_paddle_reader(train_reader) def save_model(postfix): model_path = os.path.join(cfg.model_save_dir, postfix) if os.path.isdir(model_path): shutil.rmtree(model_path) - fluid.io.save_persistables(base_exe, model_path) + fluid.io.save_persistables(exe, model_path) fetch_list = [loss] - def train_loop_pyreader(): + def train_loop(): py_reader.start() smoothed_loss = SmoothedValue() try: @@ -137,43 +130,7 @@ def train(): except fluid.core.EOFException: py_reader.reset() - def train_loop(): - start_time = time.time() - prev_start_time = start_time - start = start_time - smoothed_loss = SmoothedValue() - snapshot_loss = 0 - snapshot_time = 0 - for iter_id, data in enumerate(train_reader()): - iter_id += cfg.start_iter - prev_start_time = start_time - start_time = time.time() - losses = exe.run(fetch_list=[v.name for v in fetch_list], - feed=feeder.feed(data)) - smoothed_loss.add_value(losses[0]) - snapshot_loss += losses[0] - snapshot_time += start_time - prev_start_time - lr = np.array(fluid.global_scope().find_var('learning_rate') - .get_tensor()) - print("Iter {:d}, lr: {:.6f}, loss: {:.4f}, time {:.5f}".format( - iter_id, lr[0], smoothed_loss.get_mean_value(), start_time - prev_start_time)) - sys.stdout.flush() - - if (iter_id + 1) % cfg.snapshot_iter == 0: - save_model("model_iter{}".format(iter_id)) - print("Snapshot {} saved, average loss: {}, average time: {}".format( - iter_id + 1, snapshot_loss / float(cfg.snapshot_iter), - snapshot_time / float(cfg.snapshot_iter))) - snapshot_loss = 0 - snapshot_time = 0 - if (iter_id + 1) == cfg.max_iter: - print("Finish iter {}".format(iter_id)) - break - - if cfg.use_pyreader: - train_loop_pyreader() - else: - train_loop() + train_loop() save_model('model_final') diff --git a/fluid/PaddleCV/yolov3/utility.py b/fluid/PaddleCV/yolov3/utility.py index 436b2a98..d28f6a86 100644 --- a/fluid/PaddleCV/yolov3/utility.py +++ b/fluid/PaddleCV/yolov3/utility.py @@ -94,7 +94,6 @@ def parse_args(): add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable # ENV - add_arg('parallel', bool, True, "Whether use parallel.") add_arg('use_gpu', bool, True, "Whether use GPU.") add_arg('model_save_dir', str, 'checkpoints', "The path to save model.") add_arg('pretrain', str, 'weights/darknet53', "The pretrain model path.") @@ -102,12 +101,10 @@ def parse_args(): add_arg('dataset', str, 'coco2017', "Dataset: coco2014, coco2017.") add_arg('class_num', int, 80, "Class number.") add_arg('data_dir', str, 'dataset/coco', "The data root path.") - add_arg('use_pyreader', bool, True, "Use pyreader.") - add_arg('use_profile', bool, False, "Whether use profiler.") add_arg('start_iter', int, 0, "Start iteration.") add_arg('use_multiprocess', bool, True, "add multiprocess.") #SOLVER - add_arg('batch_size', int, 64, "Learning rate.") + add_arg('batch_size', int, 8, "Mini-batch size per device.") add_arg('learning_rate', float, 0.001, "Learning rate.") add_arg('max_iter', int, 500200, "Iter number.") add_arg('snapshot_iter', int, 2000, "Save model every snapshot stride.") @@ -122,8 +119,8 @@ def parse_args(): add_arg('nms_posk', int, 100, "The number of boxes of NMS output.") add_arg('debug', bool, False, "Debug mode") # SINGLE EVAL AND DRAW - add_arg('image_path', str, 'image', "The image path used to inference and visualize.") - add_arg('image_name', str, None, "The single image used to inference and visualize. None to inference all images in image_path") + add_arg('image_path', str, 'image', "The image path used to inference and visualize.") + add_arg('image_name', str, None, "The single image used to inference and visualize. None to inference all images in image_path") add_arg('draw_thresh', float, 0.5, "Confidence score threshold to draw prediction box in image in debug mode") # yapf: enable args = parser.parse_args() -- GitLab