From d4b30ff20beff95355dae5364ff85640b2994128 Mon Sep 17 00:00:00 2001
From: Feng Ni
Date: Mon, 18 Jan 2021 16:32:40 +0800
Subject: [PATCH] [Dygraph] add SSD/SSDLite mbv1v3 (#2070)

* fix ssd ssdlite scheduler, add cosdecay
* fix mbv1v3 BatchNorm mean variance
* update ssd mbv1 voc modelzoo
* fix ssd mbv1 warmup, update modelzoo
* fix cosdecay
---
 dygraph/configs/ssd/README.md                 |  5 +-
 dygraph/configs/ssd/_base_/optimizer_120e.yml | 10 +--
 ...ptimizer_1000e.yml => optimizer_1700e.yml} |  9 +--
 .../ssd/ssdlite_mobilenet_v1_300_coco.yml     |  2 +-
 .../ssdlite_mobilenet_v3_large_320_coco.yml   |  2 +-
 .../ssdlite_mobilenet_v3_small_320_coco.yml   |  2 +-
 .../ppdet/modeling/backbones/mobilenet_v1.py  | 22 ++++---
 .../ppdet/modeling/backbones/mobilenet_v3.py  | 37 +++++++----
 dygraph/ppdet/optimizer.py                    | 62 ++++++++++++++++++-
 9 files changed, 110 insertions(+), 41 deletions(-)
 rename dygraph/configs/ssd/_base_/{optimizer_1000e.yml => optimizer_1700e.yml} (73%)

diff --git a/dygraph/configs/ssd/README.md b/dygraph/configs/ssd/README.md
index e8593eac1..b62eb9dcc 100644
--- a/dygraph/configs/ssd/README.md
+++ b/dygraph/configs/ssd/README.md
@@ -6,9 +6,10 @@
 
 | Backbone | Network type | Images/GPU | LR schedule | Inference time (fps) | Box AP | Download | Config |
 | :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
-| VGG | SSD | 8 | 240e | ---- | 78.2 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_vgg16_300_240e_voc.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/ssd_vgg16_300_240e_voc.yml) |
+| VGG | SSD | 8 | 240e | ---- | 78.2 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_vgg16_300_240e_voc.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ssd_vgg16_300_240e_voc.yml) |
+| MobileNet v1 | SSD | 32 | 120e | ---- | 73.1 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_mobilenet_v1_300_120e_voc.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ssd_mobilenet_v1_300_120e_voc.yml) |
 
-**Note:** SSD is trained on 4 GPUs for 240 epochs.
+**Note:** SSD-VGG is trained on 4 GPUs with a total batch size of 32 for 240 epochs. SSD-MobileNet v1 is trained on 2 GPUs with a total batch size of 64 for 120 epochs.
 
 ## Citations
 ```
diff --git a/dygraph/configs/ssd/_base_/optimizer_120e.yml b/dygraph/configs/ssd/_base_/optimizer_120e.yml
index 1835fc433..908745c6c 100644
--- a/dygraph/configs/ssd/_base_/optimizer_120e.yml
+++ b/dygraph/configs/ssd/_base_/optimizer_120e.yml
@@ -4,13 +4,9 @@ LearningRate:
   base_lr: 0.001
   schedulers:
   - !PiecewiseDecay
-    gamma: 0.1
-    milestones:
-    - 80
-    - 100
-  - !LinearWarmup
-    start_factor: 0.3333333333333333
-    steps: 500
+    milestones: [40, 60, 80, 100]
+    values: [0.001, 0.0005, 0.00025, 0.0001, 0.00001]
+    use_warmup: false
 
 OptimizerBuilder:
   optimizer:
diff --git a/dygraph/configs/ssd/_base_/optimizer_1000e.yml b/dygraph/configs/ssd/_base_/optimizer_1700e.yml
similarity index 73%
rename from dygraph/configs/ssd/_base_/optimizer_1000e.yml
rename to dygraph/configs/ssd/_base_/optimizer_1700e.yml
index 3c46b9f4c..fe5fedc7c 100644
--- a/dygraph/configs/ssd/_base_/optimizer_1000e.yml
+++ b/dygraph/configs/ssd/_base_/optimizer_1700e.yml
@@ -1,13 +1,10 @@
-epoch: 1746
+epoch: 1700
 
 LearningRate:
   base_lr: 0.4
   schedulers:
-  - !PiecewiseDecay
-    gamma: 0.1
-    milestones:
-    - 160
-    - 200
+  - !CosineDecay
+    max_epochs: 1700
   - !LinearWarmup
     start_factor: 0.3333333333333333
     steps: 2000
diff --git a/dygraph/configs/ssd/ssdlite_mobilenet_v1_300_coco.yml b/dygraph/configs/ssd/ssdlite_mobilenet_v1_300_coco.yml
index d3a649adf..75cb8a8a2 100644
--- a/dygraph/configs/ssd/ssdlite_mobilenet_v1_300_coco.yml
+++ b/dygraph/configs/ssd/ssdlite_mobilenet_v1_300_coco.yml
@@ -1,7 +1,7 @@
 _BASE_: [
   '../datasets/coco_detection.yml',
   '../runtime.yml',
-  '_base_/optimizer_1000e.yml',
+  '_base_/optimizer_1700e.yml',
   '_base_/ssdlite_mobilenet_v1_300.yml',
   '_base_/ssdlite300_reader.yml',
 ]
diff --git a/dygraph/configs/ssd/ssdlite_mobilenet_v3_large_320_coco.yml b/dygraph/configs/ssd/ssdlite_mobilenet_v3_large_320_coco.yml
index 54fb553a0..78d561aad 100644
--- a/dygraph/configs/ssd/ssdlite_mobilenet_v3_large_320_coco.yml
+++ b/dygraph/configs/ssd/ssdlite_mobilenet_v3_large_320_coco.yml
@@ -1,7 +1,7 @@
 _BASE_: [
   '../datasets/coco_detection.yml',
   '../runtime.yml',
-  '_base_/optimizer_1000e.yml',
+  '_base_/optimizer_1700e.yml',
   '_base_/ssdlite_mobilenet_v3_large_320.yml',
   '_base_/ssdlite320_reader.yml',
 ]
diff --git a/dygraph/configs/ssd/ssdlite_mobilenet_v3_small_320_coco.yml b/dygraph/configs/ssd/ssdlite_mobilenet_v3_small_320_coco.yml
index 1b696e7b9..fa0ce5346 100644
--- a/dygraph/configs/ssd/ssdlite_mobilenet_v3_small_320_coco.yml
+++ b/dygraph/configs/ssd/ssdlite_mobilenet_v3_small_320_coco.yml
@@ -1,7 +1,7 @@
 _BASE_: [
   '../datasets/coco_detection.yml',
   '../runtime.yml',
-  '_base_/optimizer_1000e.yml',
+  '_base_/optimizer_1700e.yml',
   '_base_/ssdlite_mobilenet_v3_small_320.yml',
   '_base_/ssdlite320_reader.yml',
 ]
diff --git a/dygraph/ppdet/modeling/backbones/mobilenet_v1.py b/dygraph/ppdet/modeling/backbones/mobilenet_v1.py
index 42fa5a2aa..198773b5e 100644
--- a/dygraph/ppdet/modeling/backbones/mobilenet_v1.py
+++ b/dygraph/ppdet/modeling/backbones/mobilenet_v1.py
@@ -58,16 +58,22 @@ class ConvBNLayer(nn.Layer):
                 name=name + "_weights"),
             bias_attr=False)
 
+        param_attr = ParamAttr(
+            name=name + "_bn_scale", regularizer=L2Decay(norm_decay))
+        bias_attr = ParamAttr(
+            name=name + "_bn_offset", regularizer=L2Decay(norm_decay))
         if norm_type == 'sync_bn':
-            batch_norm = nn.SyncBatchNorm
+            self._batch_norm = nn.SyncBatchNorm(
+                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
         else:
-            batch_norm = nn.BatchNorm2D
-        self._batch_norm = batch_norm(
-            out_channels,
-            weight_attr=ParamAttr(
-                name=name + "_bn_scale", regularizer=L2Decay(norm_decay)),
-            bias_attr=ParamAttr(
-                name=name + "_bn_offset", regularizer=L2Decay(norm_decay)))
+            self._batch_norm = nn.BatchNorm(
+                out_channels,
+                act=None,
+                param_attr=param_attr,
+                bias_attr=bias_attr,
+                use_global_stats=False,
+                moving_mean_name=name + '_bn_mean',
+                moving_variance_name=name + '_bn_variance')
 
     def forward(self, x):
         x = self._conv(x)
diff --git a/dygraph/ppdet/modeling/backbones/mobilenet_v3.py b/dygraph/ppdet/modeling/backbones/mobilenet_v3.py
index 7d00a8317..40ee4e86f 100644
--- a/dygraph/ppdet/modeling/backbones/mobilenet_v3.py
+++ b/dygraph/ppdet/modeling/backbones/mobilenet_v3.py
@@ -67,20 +67,33 @@ class ConvBNLayer(nn.Layer):
             bias_attr=False)
 
         norm_lr = 0. if freeze_norm else lr_mult
+        param_attr = ParamAttr(
+            learning_rate=norm_lr,
+            regularizer=L2Decay(norm_decay),
+            name=name + "_bn_scale",
+            trainable=False if freeze_norm else True)
+        bias_attr = ParamAttr(
+            learning_rate=norm_lr,
+            regularizer=L2Decay(norm_decay),
+            name=name + "_bn_offset",
+            trainable=False if freeze_norm else True)
+        global_stats = True if freeze_norm else False
         if norm_type == 'sync_bn':
-            batch_norm = nn.SyncBatchNorm
+            self.bn = nn.SyncBatchNorm(
+                out_c, weight_attr=param_attr, bias_attr=bias_attr)
         else:
-            batch_norm = nn.BatchNorm2D
-        self.bn = batch_norm(
-            out_c,
-            weight_attr=ParamAttr(
-                learning_rate=norm_lr,
-                name=name + "_bn_scale",
-                regularizer=L2Decay(norm_decay)),
-            bias_attr=ParamAttr(
-                learning_rate=norm_lr,
-                name=name + "_bn_offset",
-                regularizer=L2Decay(norm_decay)))
+            self.bn = nn.BatchNorm(
+                out_c,
+                act=None,
+                param_attr=param_attr,
+                bias_attr=bias_attr,
+                use_global_stats=global_stats,
+                moving_mean_name=name + '_bn_mean',
+                moving_variance_name=name + '_bn_variance')
+        norm_params = self.bn.parameters()
+        if freeze_norm:
+            for param in norm_params:
+                param.stop_gradient = True
 
     def forward(self, x):
         x = self.conv(x)
diff --git a/dygraph/ppdet/optimizer.py b/dygraph/ppdet/optimizer.py
index a11fde40d..3c1a17be3 100644
--- a/dygraph/ppdet/optimizer.py
+++ b/dygraph/ppdet/optimizer.py
@@ -21,6 +21,7 @@
 import paddle
 import paddle.nn as nn
 import paddle.optimizer as optimizer
+from paddle.optimizer.lr import CosineAnnealingDecay
 import paddle.regularizer as regularizer
 from paddle import cos
@@ -32,6 +33,42 @@ from ppdet.utils.logger import setup_logger
 logger = setup_logger(__name__)
 
 
+@serializable
+class CosineDecay(object):
+    """
+    Cosine learning rate decay
+
+    Args:
+        max_epochs (int): max epochs for the training process.
+            If you combine cosine decay with warmup, it is recommended
+            that max_iters be much larger than the warmup steps.
+    """
+
+    def __init__(self, max_epochs=1000, use_warmup=True):
+        self.max_epochs = max_epochs
+        self.use_warmup = use_warmup
+
+    def __call__(self,
+                 base_lr=None,
+                 boundary=None,
+                 value=None,
+                 step_per_epoch=None):
+        assert base_lr is not None, "either base LR or values should be provided"
+
+        max_iters = self.max_epochs * int(step_per_epoch)
+
+        if boundary is not None and value is not None and self.use_warmup:
+            for i in range(int(boundary[-1]), max_iters):
+                boundary.append(i)
+
+                decayed_lr = base_lr * 0.5 * (
+                    math.cos(i * math.pi / max_iters) + 1)
+                value.append(decayed_lr)
+            return optimizer.lr.PiecewiseDecay(boundary, value)
+
+        return optimizer.lr.CosineAnnealingDecay(base_lr, T_max=max_iters)
+
+
 @serializable
 class PiecewiseDecay(object):
     """
@@ -42,7 +79,11 @@ class PiecewiseDecay(object):
         milestones (list): steps at which to decay learning rate
     """
 
-    def __init__(self, gamma=[0.1, 0.01], milestones=[8, 11]):
+    def __init__(self,
+                 gamma=[0.1, 0.01],
+                 milestones=[8, 11],
+                 values=None,
+                 use_warmup=True):
         super(PiecewiseDecay, self).__init__()
         if type(gamma) is not list:
             self.gamma = []
@@ -51,15 +92,26 @@ class PiecewiseDecay(object):
         else:
             self.gamma = gamma
         self.milestones = milestones
+        self.values = values
+        self.use_warmup = use_warmup
 
     def __call__(self,
                  base_lr=None,
                  boundary=None,
                  value=None,
                  step_per_epoch=None):
-        if boundary is not None:
+        if boundary is not None and self.use_warmup:
             boundary.extend([int(step_per_epoch) * i for i in self.milestones])
+        else:
+            # do not use LinearWarmup
+            boundary = [int(step_per_epoch) * i for i in self.milestones]
 
+        # self.values is set directly in the config
+        if self.values is not None:
+            assert len(self.milestones) + 1 == len(self.values)
+            return optimizer.lr.PiecewiseDecay(boundary, self.values)
+
+        # value is computed from self.gamma
         if value is not None:
             for i in self.gamma:
                 value.append(base_lr * i)
@@ -114,6 +166,11 @@ class LearningRate(object):
         self.schedulers = schedulers
 
     def __call__(self, step_per_epoch):
+        assert len(self.schedulers) >= 1
+        if not self.schedulers[0].use_warmup:
+            return self.schedulers[0](base_lr=self.base_lr,
+                                      step_per_epoch=step_per_epoch)
+
         # TODO: split warmup & decay
         # warmup
         boundary, value = self.schedulers[1](self.base_lr)
@@ -127,7 +184,6 @@ class LearningRate(object):
 class OptimizerBuilder():
     """
     Build optimizer handles
-
     Args:
         regularizer (object): a `Regularizer` instance
         optimizer (object): an `Optimizer` instance
--
GitLab
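
Two of the scheduler changes above are easier to review with a worked example. First, the `use_warmup=True` branch of the new `CosineDecay` does not hand Paddle a cosine scheduler directly: it extends the `boundary`/`value` tables produced by `LinearWarmup` with one cosine-decayed LR per remaining iteration, then wraps the combined table in `optimizer.lr.PiecewiseDecay`. Below is a minimal, dependency-free sketch of that construction; `base_lr`, `start_factor`, and `steps` mirror `optimizer_1700e.yml`, while `step_per_epoch` is a hypothetical placeholder (in training it comes from the dataset size and batch size).

```python
import math

# Warmup and cosine settings mirroring optimizer_1700e.yml;
# step_per_epoch is a hypothetical placeholder.
base_lr = 0.4
start_factor = 1.0 / 3.0   # LinearWarmup start_factor
warmup_steps = 2000        # LinearWarmup steps
step_per_epoch = 100       # hypothetical
max_iters = 1700 * step_per_epoch

# 1) LinearWarmup fills the (boundary, value) tables with a linear ramp
#    from start_factor * base_lr up to base_lr.
boundary, value = [], [base_lr * start_factor]
for i in range(1, warmup_steps + 1):
    alpha = i / warmup_steps
    value.append(base_lr * (start_factor * (1 - alpha) + alpha))
    boundary.append(i)

# 2) CosineDecay (use_warmup=True) then appends one decayed LR per
#    iteration from the end of warmup to max_iters, as in __call__ above.
for i in range(int(boundary[-1]), max_iters):
    boundary.append(i)
    value.append(base_lr * 0.5 * (math.cos(i * math.pi / max_iters) + 1))

# The combined tables drive optimizer.lr.PiecewiseDecay(boundary, value):
# one LR per interval, so len(value) == len(boundary) + 1.
assert len(value) == len(boundary) + 1
print(value[0], value[warmup_steps], value[-1])  # ~0.1333, 0.4, ~0.0
```

Because the cosine table starts at `boundary[-1]`, the LR at the end of warmup is `base_lr * 0.5 * (cos(warmup_steps * pi / max_iters) + 1)`, which stays close to `base_lr` only when `max_iters` is much larger than the warmup steps; that is what the docstring's recommendation guards against.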
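
Second, the reworked `PiecewiseDecay` now accepts an explicit `values` list plus a `use_warmup` flag, which is what `optimizer_120e.yml` relies on: with `use_warmup: false` the boundaries are built straight from the milestones and the listed values are used verbatim. A small sketch of the resulting step-to-LR mapping, again with a hypothetical `step_per_epoch`; the `lr_at` helper is illustrative and simply mirrors the piecewise-constant semantics of `optimizer.lr.PiecewiseDecay`.

```python
# Milestones (in epochs) and per-interval LRs from optimizer_120e.yml.
milestones = [40, 60, 80, 100]
values = [0.001, 0.0005, 0.00025, 0.0001, 0.00001]
step_per_epoch = 500  # hypothetical

# use_warmup: false, so boundaries come straight from the milestones.
boundary = [step_per_epoch * m for m in milestones]
assert len(milestones) + 1 == len(values)  # the same check as in __call__

def lr_at(step):
    """Piecewise-constant lookup: values[i] applies while step < boundary[i]."""
    for b, v in zip(boundary, values):
        if step < b:
            return v
    return values[-1]

print(lr_at(0))      # 0.001  (epochs 0-39)
print(lr_at(25000))  # 0.0005 (epoch 50, between milestones 40 and 60)
print(lr_at(60000))  # 1e-05  (past the last milestone)
```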