diff --git a/fluid/PaddleCV/yolov3/config.py b/fluid/PaddleCV/yolov3/config.py index 095429b2098bb0dce098b5dbda23e87343b8029c..07adfdb71d72f5ffb46ad229d8fc4d345cea271d 100644 --- a/fluid/PaddleCV/yolov3/config.py +++ b/fluid/PaddleCV/yolov3/config.py @@ -24,10 +24,6 @@ cfg = _C # Training options # -# batch - -_C.batch = 8 - # Snapshot period _C.snapshot_iter = 2000 @@ -72,6 +68,9 @@ _C.pixel_stds = [0.229, 0.224, 0.225] # SOLVER options # +# batch size +_C.batch_size = 64 + # derived learning rate the to get the final learning rate. _C.learning_rate = 0.001 @@ -92,9 +91,7 @@ _C.weight_decay = 0.0005 # momentum with SGD _C.momentum = 0.9 -# decay -_C.decay = 0.0005 - +# # ENV options # diff --git a/fluid/PaddleCV/yolov3/config_parser.py b/fluid/PaddleCV/yolov3/config_parser.py deleted file mode 100644 index 4e96bfb9f12ad8c73036083df0c146073e1cd0d2..0000000000000000000000000000000000000000 --- a/fluid/PaddleCV/yolov3/config_parser.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -LAYER_TYPES = [ - "net", - "convolutional", - "shortcut", - "route", - "upsample", - "maxpool", - "yolo", - ] - -class ConfigPaser(object): - def __init__(self, config_path): - self.config_path = config_path - - def parse(self): - with open(self.config_path) as cfg_file: - model_defs = [] - for line in cfg_file.readlines(): - line = line.strip() - if len(line) == 0: - continue - if line.startswith('#'): - continue - if line.startswith('['): - layer_type = line[1:-1].strip() - if layer_type not in LAYER_TYPES: - print("Unknow config layer type: ", layer_type) - return None - model_defs.append({}) - model_defs[-1]['type'] = layer_type - else: - key, value = line.split('=') - model_defs[-1][key.strip()] = value.strip() - - return model_defs - - diff --git a/fluid/PaddleCV/yolov3/edict.py b/fluid/PaddleCV/yolov3/edict.py new file mode 100644 index 0000000000000000000000000000000000000000..415cc6f7d6514a2fa79fb2a75bb23d8b8fd2fe72 --- /dev/null +++ b/fluid/PaddleCV/yolov3/edict.py @@ -0,0 +1,37 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + elif name in self: + return self[name] + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + if name in self.__dict__: + self.__dict__[name] = value + else: + self[name] = value diff --git a/fluid/PaddleCV/yolov3/eval.py b/fluid/PaddleCV/yolov3/eval.py index d56a7454b6cfc997ccb9c0ddbe76d8776a72514b..a9e742ad0a83009a8db742a65a50bc814617ee7e 100644 --- a/fluid/PaddleCV/yolov3/eval.py +++ b/fluid/PaddleCV/yolov3/eval.py @@ -17,13 +17,13 @@ from __future__ import division from __future__ import print_function import os import time +import json import numpy as np import paddle import paddle.fluid as fluid import reader -import models.yolov3 as models +from models.yolov3 import YOLOv3 from utility import print_arguments, parse_args -import json from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval, Params from config import cfg @@ -39,11 +39,9 @@ def eval(): if not os.path.exists('output'): os.mkdir('output') - model = models.YOLOv3(cfg.model_cfg_path, is_train=False) + model = YOLOv3(cfg.model_cfg_path, is_train=False) model.build_model() outputs = model.get_pred() - yolo_anchors = model.get_yolo_anchors() - yolo_classes = model.get_yolo_classes() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) # yapf: disable @@ -52,7 +50,7 @@ def eval(): return os.path.exists(os.path.join(cfg.weights, var.name)) fluid.io.load_vars(exe, cfg.weights, predicate=if_exist) # yapf: enable - input_size = model.get_input_size() + input_size = cfg.input_size test_reader = reader.test(input_size, 1) label_names, label_ids = reader.get_label_infos() if cfg.debug: diff --git a/fluid/PaddleCV/yolov3/infer.py b/fluid/PaddleCV/yolov3/infer.py index 572dbd82d6988de3811384d7c57b8275f20f21a5..07650eb6d1ea4564a6deccb78d9b8020e0ea22a6 100644 --- a/fluid/PaddleCV/yolov3/infer.py +++ b/fluid/PaddleCV/yolov3/infer.py @@ -6,9 +6,7 @@ import paddle.fluid as fluid import box_utils import reader from utility import print_arguments, parse_args -import models.yolov3 as models -# from coco_reader import load_label_names -import json +from models.yolov3 import YOLOv3 from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval, Params from config import cfg @@ -19,12 +17,10 @@ def infer(): if not os.path.exists('output'): os.mkdir('output') - model = models.YOLOv3(cfg.model_cfg_path, is_train=False) + model = YOLOv3(cfg.model_cfg_path, is_train=False) model.build_model() outputs = model.get_pred() - input_size = model.get_input_size() - yolo_anchors = model.get_yolo_anchors() - yolo_classes = model.get_yolo_classes() + input_size = cfg.input_size place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) # yapf: disable diff --git a/fluid/PaddleCV/yolov3/learning_rate.py b/fluid/PaddleCV/yolov3/learning_rate.py index 6a44cfe8d8e0dcd9faf5d439f5501318a8937850..a4c9dfb6949054c86c8f307b0bfaa3eea61ee51c 100644 --- a/fluid/PaddleCV/yolov3/learning_rate.py +++ b/fluid/PaddleCV/yolov3/learning_rate.py @@ -22,7 +22,7 @@ from paddle.fluid.layers import control_flow def exponential_with_warmup_decay(learning_rate, boundaries, values, - warmup_iter, warmup_factor, start_step): + warmup_iter, warmup_factor): global_step = lr_scheduler._decay_step_counter() lr = fluid.layers.create_global_var( diff --git a/fluid/PaddleCV/yolov3/models.py b/fluid/PaddleCV/yolov3/models.py deleted file mode 100644 index 379ad6b68e3c209dcead3adba62adddd251913ea..0000000000000000000000000000000000000000 --- a/fluid/PaddleCV/yolov3/models.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. - -from __future__ import division -from __future__ import print_function - -import paddle.fluid as fluid -from paddle.fluid.param_attr import ParamAttr -from paddle.fluid.initializer import Constant -from paddle.fluid.initializer import Normal -from paddle.fluid.regularizer import L2Decay - -import box_utils -from config.config_parser import ConfigPaser -from config.config import cfg - - -def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - act=None, - bn=False, - name=None, - is_train=True): - if bn: - out = fluid.layers.conv2d( - input=input, - num_filters=ch_out, - filter_size=filter_size, - stride=stride, - padding=padding, - act=None, - param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02), - name=name + "_weights"), - bias_attr=False, - name=name + '.conv2d.output.1') - - bn_name = "bn" + name[4:] - - out = fluid.layers.batch_norm(input=out, - act=None, - is_test=not is_train, - param_attr=ParamAttr( - initializer=fluid.initializer.Normal(0., 0.02), - regularizer=L2Decay(0.), - name=bn_name + '_scale'), - bias_attr=ParamAttr( - initializer=fluid.initializer.Constant(0.0), - regularizer=L2Decay(0.), - name=bn_name + '_offset'), - moving_mean_name=bn_name+'_mean', - moving_variance_name=bn_name+'_var', - name=bn_name+'.output') - else: - out = fluid.layers.conv2d( - input=input, - num_filters=ch_out, - filter_size=filter_size, - stride=stride, - padding=padding, - act=None, - param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02), - name=name + "_weights"), - bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0), - regularizer=L2Decay(0.), - name=name + "_bias"), - name=name + '.conv2d.output.1') - - if act == 'relu': - out = fluid.layers.relu(x=out) - if act == 'leaky': - out = fluid.layers.leaky_relu(x=out, alpha=0.1) - return out - - -class YOLOv3(object): - def __init__(self, - model_cfg_path, - is_train=True, - use_pyreader=True, - use_random=True): - self.model_cfg_path = model_cfg_path - self.config_parser = ConfigPaser(model_cfg_path) - self.is_train = is_train - self.use_pyreader = use_pyreader - self.use_random = use_random - self.outputs = [] - self.losses = [] - self.boxes = [] - self.scores = [] - self.downsample = 32 - - def build_model(self): - model_defs = self.config_parser.parse() - if model_defs is None: - return None - - self.hyperparams = model_defs.pop(0) - assert self.hyperparams['type'].lower() == "net", \ - "net config params should be given in the first segment named 'net'" - self.img_height = cfg.input_size - self.img_width = cfg.input_size - - self.build_input() - - out = self.image - layer_outputs = [] - self.yolo_layer_defs = [] - self.yolo_anchors = [] - self.yolo_classes = [] - self.outputs = [] - for i, layer_def in enumerate(model_defs): - if layer_def['type'] == 'convolutional': - bn = layer_def.get('batch_normalize', 0) - ch_out = int(layer_def['filters']) - filter_size = int(layer_def['size']) - stride = int(layer_def['stride']) - padding = (filter_size - 1) // 2 if int(layer_def['pad']) else 0 - act = layer_def['activation'] - out = conv_bn_layer( - input=out, - ch_out=ch_out, - filter_size=filter_size, - stride=stride, - padding=padding, - act=act, - bn=bool(bn), - name="conv"+str(i), - is_train=self.is_train) - - elif layer_def['type'] == 'shortcut': - layer_from = int(layer_def['from']) - out = fluid.layers.elementwise_add( - x=out, - y=layer_outputs[layer_from], - name="res"+str(i)) - - elif layer_def['type'] == 'route': - layers = map(int, layer_def['layers'].split(",")) - out = fluid.layers.concat( - input=[layer_outputs[i] for i in layers], - axis=1) - - elif layer_def['type'] == 'upsample': - scale = int(layer_def['stride']) - - # get dynamic upsample output shape - shape_nchw = fluid.layers.shape(out) - shape_hw = fluid.layers.slice(shape_nchw, axes=[0], \ - starts=[2], ends=[4]) - shape_hw.stop_gradient = True - in_shape = fluid.layers.cast(shape_hw, dtype='int32') - out_shape = in_shape * scale - out_shape.stop_gradient = True - - # reisze by actual_shape - out = fluid.layers.resize_nearest( - input=out, - scale=scale, - actual_shape=out_shape, - name="upsample"+str(i)) - - elif layer_def['type'] == 'maxpool': - pool_size = int(layer_def['size']) - pool_stride = int(layer_def['stride']) - pool_padding = 0 - if pool_stride == 1 and pool_size == 2: - pool_padding = 1 - out = fluid.layers.pool2d( - input=out, - pool_type='max', - pool_size=pool_size, - pool_stride=pool_stride, - pool_padding=pool_padding) - - elif layer_def['type'] == 'yolo': - self.yolo_layer_defs.append(layer_def) - self.outputs.append(out) - - anchor_mask = map(int, layer_def['mask'].split(',')) - anchors = map(int, layer_def['anchors'].split(',')) - mask_anchors = [] - for m in anchor_mask: - mask_anchors.append(anchors[2 * m]) - mask_anchors.append(anchors[2 * m + 1]) - self.yolo_anchors.append(mask_anchors) - class_num = int(layer_def['classes']) - self.yolo_classes.append(class_num) - - if self.is_train: - ignore_thresh = float(layer_def['ignore_thresh']) - loss = fluid.layers.yolov3_loss( - x=out, - gtbox=self.gtbox, - gtlabel=self.gtlabel, - gtscore=self.gtscore, - anchors=anchors, - anchor_mask=anchor_mask, - class_num=class_num, - ignore_thresh=ignore_thresh, - downsample_ratio=self.downsample, - use_label_smooth=cfg.label_smooth, - name="yolo_loss"+str(i)) - self.losses.append(fluid.layers.reduce_mean(loss)) - else: - boxes, scores = fluid.layers.yolo_box( - x=out, - img_size=self.im_shape, - anchors=mask_anchors, - class_num=class_num, - conf_thresh=cfg.valid_thresh, - downsample_ratio=self.downsample, - name="yolo_box"+str(i)) - self.boxes.append(boxes) - self.scores.append(fluid.layers.transpose(scores, perm=[0, 2, 1])) - - self.downsample //= 2 - - layer_outputs.append(out) - - def loss(self): - return sum(self.losses) - - def get_pred(self): - yolo_boxes = fluid.layers.concat(self.boxes, axis=1) - yolo_scores = fluid.layers.concat(self.scores, axis=2) - return fluid.layers.multiclass_nms( - bboxes=yolo_boxes, - scores=yolo_scores, - score_threshold=cfg.valid_thresh, - nms_top_k=cfg.nms_topk, - keep_top_k=cfg.nms_posk, - nms_threshold=cfg.nms_thresh, - background_label=-1, - name="multiclass_nms") - - def get_yolo_anchors(self): - return self.yolo_anchors - - def get_yolo_classes(self): - return self.yolo_classes - - def build_input(self): - self.image_shape = [3, self.img_height, self.img_width] - if self.use_pyreader and self.is_train: - self.py_reader = fluid.layers.py_reader( - capacity=64, - shapes = [[-1] + self.image_shape, [-1, cfg.max_box_num, 4], [-1, cfg.max_box_num], [-1, cfg.max_box_num]], - lod_levels=[0, 0, 0, 0], - dtypes=['float32'] * 2 + ['int32'] + ['float32'], - use_double_buffer=True) - self.image, self.gtbox, self.gtlabel, self.gtscore = fluid.layers.read_file(self.py_reader) - else: - self.image = fluid.layers.data( - name='image', shape=self.image_shape, dtype='float32' - ) - self.gtbox = fluid.layers.data( - name='gtbox', shape=[cfg.max_box_num, 4], dtype='float32' - ) - self.gtlabel = fluid.layers.data( - name='gtlabel', shape=[cfg.max_box_num], dtype='int32' - ) - self.gtscore = fluid.layers.data( - name='gtscore', shape=[cfg.max_box_num], dtype='float32' - ) - self.im_shape = fluid.layers.data( - name="im_shape", shape=[2], dtype='int32') - self.im_id = fluid.layers.data( - name="im_id", shape=[1], dtype='int32') - - def feeds(self): - if not self.is_train: - return [self.image, self.im_id, self.im_shape] - return [self.image, self.gtbox, self.gtlabel, self.gtscore] - - def get_hyperparams(self): - return self.hyperparams - - def get_input_size(self): - return cfg.input_size - - diff --git a/fluid/PaddleCV/yolov3/models/__init__.py b/fluid/PaddleCV/yolov3/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fluid/PaddleCV/yolov3/models/darknet.py b/fluid/PaddleCV/yolov3/models/darknet.py index f57c4a73d1616c336970854f16af6aa034e3d7f7..f61a9ae227ba213c2e5f67d634c0969174c4b3e2 100644 --- a/fluid/PaddleCV/yolov3/models/darknet.py +++ b/fluid/PaddleCV/yolov3/models/darknet.py @@ -1,102 +1,99 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. - -import paddle.fluid as fluid -from paddle.fluid.param_attr import ParamAttr -from paddle.fluid.initializer import Constant -from paddle.fluid.regularizer import L2Decay -from config import cfg - -def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - act='leaky', - i=0): - conv1 = fluid.layers.conv2d( - input=input, - num_filters=ch_out, - filter_size=filter_size, - stride=stride, - padding=padding, - act=None, - param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02), - name="conv" + str(i)+"_weights"), - bias_attr=False) - - bn_name = "bn" + str(i) - - out = fluid.layers.batch_norm( - input=conv1, - act=None, - is_test=True, - param_attr=ParamAttr( - initializer=fluid.initializer.Normal(0., 0.02), - regularizer=L2Decay(0.), - name=bn_name + '_scale'), - bias_attr=ParamAttr( - initializer=fluid.initializer.Constant(0.0), - regularizer=L2Decay(0.), - name=bn_name + '_offset'), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_var') - if act == 'leaky': - out = fluid.layers.leaky_relu(x=out, alpha=0.1) - return out - -def basicblock(input, ch_out, stride,i): - """ - channel: convolution channels for 1x1 conv - """ - conv1 = conv_bn_layer(input, ch_out, 1, 1, 0, i=i) - conv2 = conv_bn_layer(conv1, ch_out*2, 3, 1, 1, i=i+1) - out = fluid.layers.elementwise_add(x=input, y=conv2, act=None,name="res"+str(i+2)) - return out - -def layer_warp(block_func, input, ch_out, count, stride,i): - res_out = block_func(input, ch_out, stride, i=i) - for j in range(1, count): - res_out = block_func(res_out, ch_out, 1 ,i=i+j*3) - return res_out - -DarkNet_cfg = { - 53: ([1,2,8,8,4],basicblock) -} - -# num_filters = [32, 64, 128, 256, 512, 1024] - -def add_DarkNet53_conv_body(body_input): - - stages, block_func = DarkNet_cfg[53] - stages = stages[0:5] - conv1 = conv_bn_layer( - body_input, ch_out=32, filter_size=3, stride=1, padding=1, act="leaky",i=0) - conv2 = conv_bn_layer( - conv1, ch_out=64, filter_size=3, stride=2, padding=1, act="leaky", i=1) - block3 = layer_warp(block_func, conv2, 32, stages[0], 1, i=2) - downsample3 = conv_bn_layer( - block3, ch_out=128, filter_size=3, stride=2, padding=1, i=5) - block4 = layer_warp(block_func, downsample3, 64, stages[1], 1, i=6) - downsample4 = conv_bn_layer( - block4, ch_out=256, filter_size=3, stride=2, padding=1, i=12) - block5 = layer_warp(block_func, downsample4, 128, stages[2], 1,i=13) - downsample5 = conv_bn_layer( - block5, ch_out=512, filter_size=3, stride=2, padding=1, i=37) - block6 = layer_warp(block_func, downsample5, 256, stages[3], 1, i=38) - downsample6 = conv_bn_layer( - block6, ch_out=1024, filter_size=3, stride=2, padding=1, i=62) - block7 = layer_warp(block_func, downsample6, 512, stages[4], 1,i=63) - return block7,block6,block5 - +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Constant +from paddle.fluid.regularizer import L2Decay + +def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + act='leaky', + is_test=True, + name=None): + conv1 = fluid.layers.conv2d( + input=input, + num_filters=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + act=None, + param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02), + name=name+".conv.weights"), + bias_attr=False) + + bn_name = name + ".bn" + out = fluid.layers.batch_norm( + input=conv1, + act=None, + is_test=is_test, + param_attr=ParamAttr( + initializer=fluid.initializer.Normal(0., 0.02), + regularizer=L2Decay(0.), + name=bn_name + '.scale'), + bias_attr=ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=L2Decay(0.), + name=bn_name + '.offset'), + moving_mean_name=bn_name + '.mean', + moving_variance_name=bn_name + '.var') + if act == 'leaky': + out = fluid.layers.leaky_relu(x=out, alpha=0.1) + return out + +def downsample(input, ch_out, filter_size=3, stride=2, padding=1, is_test=True, name=None): + return conv_bn_layer(input, + ch_out=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + is_test=is_test, + name=name) + +def basicblock(input, ch_out, is_test=True, name=None): + conv1 = conv_bn_layer(input, ch_out, 1, 1, 0, is_test=is_test, name=name+".0") + conv2 = conv_bn_layer(conv1, ch_out*2, 3, 1, 1, is_test=is_test, name=name+".1") + out = fluid.layers.elementwise_add(x=input, y=conv2, act=None) + return out + +def layer_warp(block_func, input, ch_out, count, is_test=True, name=None): + res_out = block_func(input, ch_out, is_test=is_test, name='{}.0'.format(name)) + for j in range(1, count): + res_out = block_func(res_out, ch_out, is_test=is_test, name='{}.{}'.format(name, j)) + return res_out + +DarkNet_cfg = { + 53: ([1,2,8,8,4],basicblock) +} + +def add_DarkNet53_conv_body(body_input, is_test=True): + stages, block_func = DarkNet_cfg[53] + stages = stages[0:5] + conv1 = conv_bn_layer( + body_input, ch_out=32, filter_size=3, stride=1, padding=1, is_test=is_test, name="yolo_input") + downsample_ = downsample(conv1, ch_out=conv1.shape[1]*2, is_test=is_test, name="yolo_input.downsample") + index = 2 + blocks = [] + for i, stage in enumerate(stages): + block = layer_warp(block_func, downsample_, 32 *(2**i), stage, is_test=is_test, name="stage.{}".format(i)) + blocks.append(block) + index += 3 * stage + if i < len(stages) - 1: # do not downsaple in the last stage + downsample_ = downsample(block, ch_out=block.shape[1]*2, is_test=is_test, name="stage.{}.downsample".format(i)) + index += 1 + return blocks[-1:-4:-1] + diff --git a/fluid/PaddleCV/yolov3/models/yolov3.py b/fluid/PaddleCV/yolov3/models/yolov3.py index b10be47c247ec530b6e27335d00c9e757e977d9b..d117294a9b4e39303db867522693a7c7e14caad5 100644 --- a/fluid/PaddleCV/yolov3/models/yolov3.py +++ b/fluid/PaddleCV/yolov3/models/yolov3.py @@ -1,279 +1,209 @@ - -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. - -from __future__ import division -from __future__ import print_function - -import paddle.fluid as fluid -from paddle.fluid.param_attr import ParamAttr -from paddle.fluid.initializer import Constant -from paddle.fluid.initializer import Normal -from paddle.fluid.regularizer import L2Decay - -from config_parser import ConfigPaser -from config import cfg - -from darknet import add_DarkNet53_conv_body -from darknet import conv_bn_layer - -def yolo_detection_block(input, channel,i): - assert channel % 2 == 0, "channel {} cannot be divided by 2".format(channel) - conv1 = input - for j in range(2): - conv1 = conv_bn_layer(conv1, channel, filter_size=1, stride=1, padding=0,i=i+j*2) - conv1 = conv_bn_layer(conv1, channel*2, filter_size=3, stride=1, padding=1,i=i+j*2+1) - route = conv_bn_layer(conv1, channel, filter_size=1, stride=1, padding=0,i=i+4) - tip = conv_bn_layer(route,channel*2, filter_size=3, stride=1, padding=1,i=i+5) - return route, tip - -def upsample(out, stride=2,name=None): - out = out - scale = stride - # get dynamic upsample output shape - shape_nchw = fluid.layers.shape(out) - shape_hw = fluid.layers.slice(shape_nchw, axes=[0], starts=[2], ends=[4]) - shape_hw.stop_gradient = True - in_shape = fluid.layers.cast(shape_hw, dtype='int32') - out_shape = in_shape * scale - out_shape.stop_gradient = True - - # reisze by actual_shape - out = fluid.layers.resize_nearest( - input=out, - scale=scale, - actual_shape=out_shape, - name=name) - return out - -class YOLOv3(object): - def __init__(self, - model_cfg_path, - is_train=True, - use_pyreader=True, - use_random=True): - self.model_cfg_path = model_cfg_path - self.config_parser = ConfigPaser(model_cfg_path) - self.is_train = is_train - self.use_pyreader = use_pyreader - self.use_random = use_random - self.outputs = [] - self.losses = [] - self.downsample = 32 - self.ignore_thresh = .7 - self.class_num = 80 - - def build_model(self): - - self.img_height = cfg.input_size - self.img_width = cfg.input_size - - self.build_input() - - out = self.image - - self.yolo_anchors = [] - self.yolo_classes = [] - self.outputs = [] - self.boxes = [] - self.scores = [] - - - scale1,scale2,scale3 = add_DarkNet53_conv_body(out) - - # 13*13 scale output - route1, tip1 = yolo_detection_block(scale1, channel=512,i=75) - # scale1 output - scale1_out = fluid.layers.conv2d( - input=tip1, - num_filters=255, - filter_size=1, - stride=1, - padding=0, - act=None, - param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02), - name="conv81_weights"), - bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0), - regularizer=L2Decay(0.), - name="conv81_bias")) - - self.outputs.append(scale1_out) - - route1 = conv_bn_layer( - input=route1, - ch_out=256, - filter_size=1, - stride=1, - padding=0, - i=84) - # upsample - route1 = upsample(route1) - - # concat - route1 = fluid.layers.concat( - input=[route1,scale2], - axis=1) - - # 26*26 scale output - route2, tip2 = yolo_detection_block(route1, channel=256,i=87) - - # scale2 output - scale2_out = fluid.layers.conv2d( - input=tip2, - num_filters=255, - filter_size=1, - stride=1, - padding=0, - act=None, - param_attr=ParamAttr(name="conv93_weights"), - bias_attr=ParamAttr(name="conv93_bias")) - - self.outputs.append(scale2_out) - - route2 = conv_bn_layer( - input=route2, - ch_out=128, - filter_size=1, - stride=1, - padding=0, - i=96) - # upsample - route2 = upsample(route2) - - # concat - route2 = fluid.layers.concat( - input=[route2,scale3], - axis=1) - - # 52*52 scale output - route3, tip3 = yolo_detection_block(route2, channel=128, i=99) - - # scale3 output - scale3_out = fluid.layers.conv2d( - input=tip3, - num_filters=255, - filter_size=1, - stride=1, - padding=0, - act=None, - param_attr=ParamAttr(name="conv105_weights"), - bias_attr=ParamAttr(name="conv105_bias")) - - - self.outputs.append(scale3_out) - # yolo - - anchor_mask = [6,7,8,3,4,5,0,1,2] - anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] - for i,out in enumerate(self.outputs): - mask = anchor_mask[i*3 : (i+1)*3] - mask_anchors=[] - - for m in mask: - mask_anchors.append(anchors[2 * m]) - mask_anchors.append(anchors[2 * m + 1]) - self.yolo_anchors.append(mask_anchors) - class_num = int(self.class_num) - self.yolo_classes.append(class_num) - - if self.is_train: - ignore_thresh = float(self.ignore_thresh) - loss = fluid.layers.yolov3_loss( - x=out, - gtbox=self.gtbox, - gtlabel=self.gtlabel, - # gtscore=self.gtscore, - anchors=anchors, - anchor_mask=mask, - class_num=class_num, - ignore_thresh=ignore_thresh, - downsample_ratio=self.downsample, - # use_label_smooth=False, - name="yolo_loss"+str(i)) - self.losses.append(fluid.layers.reduce_mean(loss)) - else: - boxes, scores = fluid.layers.yolo_box( - x=out, - img_size=self.im_shape, - anchors=mask_anchors, - class_num=class_num, - conf_thresh=cfg.valid_thresh, - downsample_ratio=self.downsample, - name="yolo_box"+str(i)) - self.boxes.append(boxes) - self.scores.append(fluid.layers.transpose(scores, perm=[0, 2, 1])) - - self.downsample //= 2 - - - def loss(self): - return sum(self.losses) - - def get_pred(self): - # return self.outputs - yolo_boxes = fluid.layers.concat(self.boxes, axis=1) - yolo_scores = fluid.layers.concat(self.scores, axis=2) - return fluid.layers.multiclass_nms( - bboxes=yolo_boxes, - scores=yolo_scores, - score_threshold=cfg.valid_thresh, - nms_top_k=cfg.nms_topk, - keep_top_k=cfg.nms_posk, - nms_threshold=cfg.nms_thresh, - background_label=-1, - name="multiclass_nms") - - def get_yolo_anchors(self): - return self.yolo_anchors - - def get_yolo_classes(self): - return self.yolo_classes - - def build_input(self): - self.image_shape = [3, self.img_height, self.img_width] - if self.use_pyreader and self.is_train: - self.py_reader = fluid.layers.py_reader( - capacity=64, - shapes = [[-1] + self.image_shape, [-1, cfg.max_box_num, 4], [-1, cfg.max_box_num], [-1, cfg.max_box_num]], - lod_levels=[0, 0, 0, 0], - dtypes=['float32'] * 2 + ['int32'] + ['float32'], - use_double_buffer=True) - self.image, self.gtbox, self.gtlabel, self.gtscore = fluid.layers.read_file(self.py_reader) - else: - self.image = fluid.layers.data( - name='image', shape=self.image_shape, dtype='float32' - ) - self.gtbox = fluid.layers.data( - name='gtbox', shape=[cfg.max_box_num, 4], dtype='float32' - ) - self.gtlabel = fluid.layers.data( - name='gtlabel', shape=[cfg.max_box_num], dtype='int32' - ) - self.gtscore = fluid.layers.data( - name='gtscore', shape=[cfg.max_box_num], dtype='float32' - ) - self.im_shape = fluid.layers.data( - name="im_shape", shape=[2], dtype='int32') - self.im_id = fluid.layers.data( - name="im_id", shape=[1], dtype='int32') - - def feeds(self): - if not self.is_train: - return [self.image, self.im_id, self.im_shape] - return [self.image, self.gtbox, self.gtlabel, self.gtscore] - - def get_input_size(self): - return cfg.input_size - - + +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import division +from __future__ import print_function + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.initializer import Constant +from paddle.fluid.initializer import Normal +from paddle.fluid.regularizer import L2Decay + +from config import cfg + +from .darknet import add_DarkNet53_conv_body +from .darknet import conv_bn_layer + +def yolo_detection_block(input, channel, is_test=True, name=None): + assert channel % 2 == 0, "channel {} cannot be divided by 2".format(channel) + conv = input + for j in range(2): + conv = conv_bn_layer(conv, channel, filter_size=1, stride=1, padding=0, is_test=is_test, name='{}.{}.0'.format(name, j)) + conv = conv_bn_layer(conv, channel*2, filter_size=3, stride=1, padding=1, is_test=is_test, name='{}.{}.1'.format(name, j)) + route = conv_bn_layer(conv, channel, filter_size=1, stride=1, padding=0, is_test=is_test, name='{}.2'.format(name)) + tip = conv_bn_layer(route,channel*2, filter_size=3, stride=1, padding=1, is_test=is_test, name='{}.tip'.format(name)) + return route, tip + +def upsample(input, scale=2,name=None): + # get dynamic upsample output shape + shape_nchw = fluid.layers.shape(input) + shape_hw = fluid.layers.slice(shape_nchw, axes=[0], starts=[2], ends=[4]) + shape_hw.stop_gradient = True + in_shape = fluid.layers.cast(shape_hw, dtype='int32') + out_shape = in_shape * scale + out_shape.stop_gradient = True + + # reisze by actual_shape + out = fluid.layers.resize_nearest( + input=input, + scale=scale, + actual_shape=out_shape, + name=name) + return out + +class YOLOv3(object): + def __init__(self, + model_cfg_path, + is_train=True, + use_pyreader=True, + use_random=True): + self.model_cfg_path = model_cfg_path + self.is_train = is_train + self.use_pyreader = use_pyreader + self.use_random = use_random + self.outputs = [] + self.losses = [] + self.downsample = 32 + self.ignore_thresh = .7 + self.class_num = 80 + + def build_model(self): + + self.img_height = cfg.input_size + self.img_width = cfg.input_size + + self.build_input() + + self.outputs = [] + self.boxes = [] + self.scores = [] + + blocks = add_DarkNet53_conv_body(self.image, not self.is_train) + for i, block in enumerate(blocks): + if i > 0: + block = fluid.layers.concat( + input=[route, block], + axis=1) + route, tip = yolo_detection_block(block, channel=512//(2**i), + is_test=(not self.is_train), + name="yolo_block.{}".format(i)) + block_out = fluid.layers.conv2d( + input=tip, + num_filters=255, + filter_size=1, + stride=1, + padding=0, + act=None, + param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02), + name="yolo_output.{}.conv.weights".format(i)), + bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0), + regularizer=L2Decay(0.), + name="yolo_output.{}.conv.bias".format(i))) + self.outputs.append(block_out) + + if i < len(blocks) - 1: + route = conv_bn_layer( + input=route, + ch_out=256//(2**i), + filter_size=1, + stride=1, + padding=0, + is_test=(not self.is_train), + name="yolo_transition.{}".format(i)) + # upsample + route = upsample(route) + + + anchor_mask = [6,7,8,3,4,5,0,1,2] + anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] + for i,out in enumerate(self.outputs): + mask = anchor_mask[i*3 : (i+1)*3] + mask_anchors=[] + + for m in mask: + mask_anchors.append(anchors[2 * m]) + mask_anchors.append(anchors[2 * m + 1]) + class_num = int(self.class_num) + + if self.is_train: + ignore_thresh = float(self.ignore_thresh) + loss = fluid.layers.yolov3_loss( + x=out, + gtbox=self.gtbox, + gtlabel=self.gtlabel, + gtscore=self.gtscore, + anchors=anchors, + anchor_mask=mask, + class_num=class_num, + ignore_thresh=ignore_thresh, + downsample_ratio=self.downsample, + use_label_smooth=cfg.label_smooth, + name="yolo_loss"+str(i)) + self.losses.append(fluid.layers.reduce_mean(loss)) + else: + boxes, scores = fluid.layers.yolo_box( + x=out, + img_size=self.im_shape, + anchors=mask_anchors, + class_num=class_num, + conf_thresh=cfg.valid_thresh, + downsample_ratio=self.downsample, + name="yolo_box"+str(i)) + self.boxes.append(boxes) + self.scores.append(fluid.layers.transpose(scores, perm=[0, 2, 1])) + + self.downsample //= 2 + + + def loss(self): + return sum(self.losses) + + def get_pred(self): + yolo_boxes = fluid.layers.concat(self.boxes, axis=1) + yolo_scores = fluid.layers.concat(self.scores, axis=2) + return fluid.layers.multiclass_nms( + bboxes=yolo_boxes, + scores=yolo_scores, + score_threshold=cfg.valid_thresh, + nms_top_k=cfg.nms_topk, + keep_top_k=cfg.nms_posk, + nms_threshold=cfg.nms_thresh, + background_label=-1, + name="multiclass_nms") + + def build_input(self): + self.image_shape = [3, self.img_height, self.img_width] + if self.use_pyreader and self.is_train: + self.py_reader = fluid.layers.py_reader( + capacity=64, + shapes = [[-1] + self.image_shape, [-1, cfg.max_box_num, 4], [-1, cfg.max_box_num], [-1, cfg.max_box_num]], + lod_levels=[0, 0, 0, 0], + dtypes=['float32'] * 2 + ['int32'] + ['float32'], + use_double_buffer=True) + self.image, self.gtbox, self.gtlabel, self.gtscore = fluid.layers.read_file(self.py_reader) + else: + self.image = fluid.layers.data( + name='image', shape=self.image_shape, dtype='float32' + ) + self.gtbox = fluid.layers.data( + name='gtbox', shape=[cfg.max_box_num, 4], dtype='float32' + ) + self.gtlabel = fluid.layers.data( + name='gtlabel', shape=[cfg.max_box_num], dtype='int32' + ) + self.gtscore = fluid.layers.data( + name='gtscore', shape=[cfg.max_box_num], dtype='float32' + ) + self.im_shape = fluid.layers.data( + name="im_shape", shape=[2], dtype='int32') + self.im_id = fluid.layers.data( + name="im_id", shape=[1], dtype='int32') + + def feeds(self): + if not self.is_train: + return [self.image, self.im_id, self.im_shape] + return [self.image, self.gtbox, self.gtlabel, self.gtscore] + diff --git a/fluid/PaddleCV/yolov3/reader.py b/fluid/PaddleCV/yolov3/reader.py index ab519dff9882ebb8f2ccadda6013fc7a73e33b57..bbd540119ec2b40dfd721af43e6588d9ac4beba7 100644 --- a/fluid/PaddleCV/yolov3/reader.py +++ b/fluid/PaddleCV/yolov3/reader.py @@ -255,8 +255,8 @@ def train(size=416, random_sizes=[], interval=10, pyreader_num=1, - num_workers=16, - max_queue=32, + num_workers=2, + max_queue=4, use_multiprocessing=True): generator = dsr.get_reader('train', size, batch_size, shuffle, int(mixup_iter/pyreader_num), random_sizes) diff --git a/fluid/PaddleCV/yolov3/train.py b/fluid/PaddleCV/yolov3/train.py index 21432ade6947151933141b1a3f8579f6a0fbc6c5..932e7797e0c8f44a2c7728bbbc4fcd380e416434 100644 --- a/fluid/PaddleCV/yolov3/train.py +++ b/fluid/PaddleCV/yolov3/train.py @@ -26,7 +26,7 @@ from utility import parse_args, print_arguments, SmoothedValue import paddle import paddle.fluid as fluid import reader -import models.yolov3 as models +from models.yolov3 import YOLOv3 from learning_rate import exponential_with_warmup_decay from config import cfg @@ -42,27 +42,21 @@ def train(): if not os.path.exists(cfg.model_save_dir): os.makedirs(cfg.model_save_dir) - model = models.YOLOv3(cfg.model_cfg_path, use_pyreader=cfg.use_pyreader) + model = YOLOv3(cfg.model_cfg_path, use_pyreader=cfg.use_pyreader) model.build_model() - input_size = model.get_input_size() + input_size = cfg.input_size loss = model.loss() loss.persistable = True - print("cfg.learning",cfg.learning_rate) - print("cfg.decay",cfg.decay) - devices = os.getenv("CUDA_VISIBLE_DEVICES") or "" devices_num = len(devices.split(",")) print("Found {} CUDA devices.".format(devices_num)) - learning_rate = float(cfg.learning_rate) + learning_rate = cfg.learning_rate boundaries = cfg.lr_steps gamma = cfg.lr_gamma step_num = len(cfg.lr_steps) - if isinstance(gamma, list): - values = [learning_rate * g for g in gamma] - else: - values = [learning_rate * (gamma**i) for i in range(step_num + 1)] + values = [learning_rate * (gamma**i) for i in range(step_num + 1)] optimizer = fluid.optimizer.Momentum( learning_rate=exponential_with_warmup_decay( @@ -70,10 +64,9 @@ def train(): boundaries=boundaries, values=values, warmup_iter=cfg.warm_up_iter, - warmup_factor=cfg.warm_up_factor, - start_step=cfg.start_iter), - regularization=fluid.regularizer.L2Decay(float(cfg.decay)), - momentum=float(cfg.momentum)) + warmup_factor=cfg.warm_up_factor), + regularization=fluid.regularizer.L2Decay(cfg.weight_decay), + momentum=cfg.momentum) optimizer.minimize(loss) fluid.memory_optimize(fluid.default_main_program()) @@ -98,11 +91,11 @@ def train(): mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter if cfg.use_pyreader: - train_reader = reader.train(input_size, batch_size=int(cfg.batch)/devices_num, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, interval=10, pyreader_num=devices_num, use_multiprocessing=cfg.use_multiprocess) + train_reader = reader.train(input_size, batch_size=cfg.batch_size/devices_num, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, interval=10, pyreader_num=devices_num, use_multiprocessing=cfg.use_multiprocess) py_reader = model.py_reader py_reader.decorate_paddle_reader(train_reader) else: - train_reader = reader.train(input_size, batch_size=int(cfg.batch), shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) + train_reader = reader.train(input_size, batch_size=cfg.batch_size, shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) def save_model(postfix): diff --git a/fluid/PaddleCV/yolov3/utility.py b/fluid/PaddleCV/yolov3/utility.py index 5c5ec77dc6d21ea73f492e3ae01a33ada730f594..41f03512228f81f569483ef424a463444fd6a5ba 100644 --- a/fluid/PaddleCV/yolov3/utility.py +++ b/fluid/PaddleCV/yolov3/utility.py @@ -108,14 +108,15 @@ def parse_args(): add_arg('start_iter', int, 0, "Start iteration.") add_arg('use_multiprocess', bool, True, "add multiprocess.") #SOLVER + add_arg('batch_size', int, 64, "Learning rate.") add_arg('learning_rate', float, 0.001, "Learning rate.") add_arg('max_iter', int, 500200, "Iter number.") add_arg('snapshot_iter', int, 2000, "Save model every snapshot stride.") + add_arg('label_smooth', bool, True, "Use label smooth in class label.") + add_arg('no_mixup_iter', int, 40000, "Disable mixup in last N iter.") # TRAIN TEST INFER add_arg('input_size', int, 608, "Image input size of YOLOv3.") add_arg('random_shape', bool, True, "Resize to random shape for train reader.") - add_arg('label_smooth', bool, True, "Use label smooth in class label.") - add_arg('no_mixup_iter', int, 40000, "Disable mixup in last N iter.") add_arg('valid_thresh', float, 0.005, "Valid confidence score for NMS.") add_arg('nms_thresh', float, 0.45, "NMS threshold.") add_arg('nms_topk', int, 400, "The number of boxes to perform NMS.")