Unverified commit d4b30ff2 authored by Feng Ni, committed by GitHub

[Dygraph] add SSD/SSDLite mbv1v3 (#2070)

* fix ssd ssdlite scheduler, add cosdecay

* fix mbv1v3 BatchNorm mean variance

* update ssd mbv1 voc modelzoo

* fix ssd mbv1 warmup, update modelzoo

* fix cosdecay
Parent 99ad5a84
......@@ -6,9 +6,10 @@
| Backbone | Network type | Images/GPU | LR schedule | Inference time (fps) | Box AP | Download | Config |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| VGG | SSD | 8 | 240e | ---- | 78.2 | [download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_vgg16_300_240e_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/ssd_vgg16_300_240e_voc.yml) |
| VGG | SSD | 8 | 240e | ---- | 78.2 | [download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_vgg16_300_240e_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ssd_vgg16_300_240e_voc.yml) |
| MobileNet v1 | SSD | 32 | 120e | ---- | 73.1 | [download](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_mobilenet_v1_300_120e_voc.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ssd_mobilenet_v1_300_120e_voc.yml) |
**Note:** SSD is trained on 4 GPUs for 240 epochs.
**Note:** SSD-VGG is trained on 4 GPUs with a total batch size of 32 for 240 epochs. SSD-MobileNetv1 is trained on 2 GPUs with a total batch size of 64 for 120 epochs.
## Citations
```
......
......@@ -4,13 +4,9 @@ LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 80
- 100
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
milestones: [40, 60, 80, 100]
values: [0.001, 0.0005, 0.00025, 0.0001, 0.00001]
use_warmup: false
OptimizerBuilder:
optimizer:
......
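The new `milestones` / `values` / `use_warmup: false` combination maps almost directly onto Paddle's built-in scheduler. A minimal sketch of what this config resolves to, assuming an illustrative `step_per_epoch` (in ppdet it is derived from the dataloader):

```python
import paddle

# Milestones are epochs; they are converted to iteration boundaries,
# and the LR values are taken verbatim from the config above.
step_per_epoch = 250  # illustrative; depends on dataset size and batch size
milestones = [40, 60, 80, 100]
values = [0.001, 0.0005, 0.00025, 0.0001, 0.00001]

boundaries = [step_per_epoch * m for m in milestones]
lr = paddle.optimizer.lr.PiecewiseDecay(boundaries=boundaries, values=values)
```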
epoch: 1746
epoch: 1700
LearningRate:
base_lr: 0.4
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 160
- 200
- !CosineDecay
max_epochs: 1700
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 2000
......
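For SSDLite the schedule is now a 2000-step linear warmup followed by a half-cosine decay stretched over the full 1700 epochs. A sketch of the per-iteration LR, assuming LinearWarmup ramps linearly from `start_factor * base_lr` up to `base_lr` (the cosine term mirrors the `CosineDecay` code further down in this commit):

```python
import math

def ssdlite_lr(i, base_lr=0.4, warmup_steps=2000, start_factor=1 / 3.,
               max_epochs=1700, step_per_epoch=100):
    """Illustrative per-iteration LR; step_per_epoch depends on the dataset."""
    max_iters = max_epochs * step_per_epoch
    if i < warmup_steps:
        # linear warmup from start_factor * base_lr up to base_lr
        alpha = i / float(warmup_steps)
        return base_lr * (start_factor + (1.0 - start_factor) * alpha)
    # half-cosine decay from base_lr towards 0 over max_iters
    return base_lr * 0.5 * (math.cos(i * math.pi / max_iters) + 1)
```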
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1000e.yml',
'_base_/optimizer_1700e.yml',
'_base_/ssdlite_mobilenet_v1_300.yml',
'_base_/ssdlite300_reader.yml',
]
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1000e.yml',
'_base_/optimizer_1700e.yml',
'_base_/ssdlite_mobilenet_v3_large_320.yml',
'_base_/ssdlite320_reader.yml',
]
......
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1000e.yml',
'_base_/optimizer_1700e.yml',
'_base_/ssdlite_mobilenet_v3_small_320.yml',
'_base_/ssdlite320_reader.yml',
]
......
......@@ -58,16 +58,22 @@ class ConvBNLayer(nn.Layer):
name=name + "_weights"),
bias_attr=False)
param_attr = ParamAttr(
name=name + "_bn_scale", regularizer=L2Decay(norm_decay))
bias_attr = ParamAttr(
name=name + "_bn_offset", regularizer=L2Decay(norm_decay))
if norm_type == 'sync_bn':
batch_norm = nn.SyncBatchNorm
self._batch_norm = nn.SyncBatchNorm(
out_channels, weight_attr=param_attr, bias_attr=bias_attr)
else:
batch_norm = nn.BatchNorm2D
self._batch_norm = batch_norm(
out_channels,
weight_attr=ParamAttr(
name=name + "_bn_scale", regularizer=L2Decay(norm_decay)),
bias_attr=ParamAttr(
name=name + "_bn_offset", regularizer=L2Decay(norm_decay)))
self._batch_norm = nn.BatchNorm(
out_channels,
act=None,
param_attr=param_attr,
bias_attr=bias_attr,
use_global_stats=False,
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance')
def forward(self, x):
x = self._conv(x)
......
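The backbone change swaps `nn.BatchNorm2D` for the legacy-style `paddle.nn.BatchNorm`, which accepts explicit names for the moving mean and variance (presumably so they line up with the pretrained checkpoint). A standalone sketch of the non-sync branch, with placeholder channel count and name prefix:

```python
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay

out_channels, name, norm_decay = 64, "conv1", 0.0  # placeholders

# nn.BatchNorm lets '<name>_bn_mean' / '<name>_bn_variance' be named
# explicitly instead of being auto-generated.
batch_norm = nn.BatchNorm(
    out_channels,
    act=None,
    param_attr=ParamAttr(name=name + "_bn_scale",
                         regularizer=L2Decay(norm_decay)),
    bias_attr=ParamAttr(name=name + "_bn_offset",
                        regularizer=L2Decay(norm_decay)),
    use_global_stats=False,
    moving_mean_name=name + '_bn_mean',
    moving_variance_name=name + '_bn_variance')
```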
......@@ -67,20 +67,33 @@ class ConvBNLayer(nn.Layer):
bias_attr=False)
norm_lr = 0. if freeze_norm else lr_mult
param_attr = ParamAttr(
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay),
name=name + "_bn_scale",
trainable=False if freeze_norm else True)
bias_attr = ParamAttr(
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay),
name=name + "_bn_offset",
trainable=False if freeze_norm else True)
global_stats = True if freeze_norm else False
if norm_type == 'sync_bn':
batch_norm = nn.SyncBatchNorm
self.bn = nn.SyncBatchNorm(
out_c, weight_attr=param_attr, bias_attr=bias_attr)
else:
batch_norm = nn.BatchNorm2D
self.bn = batch_norm(
out_c,
weight_attr=ParamAttr(
learning_rate=norm_lr,
name=name + "_bn_scale",
regularizer=L2Decay(norm_decay)),
bias_attr=ParamAttr(
learning_rate=norm_lr,
name=name + "_bn_offset",
regularizer=L2Decay(norm_decay)))
self.bn = nn.BatchNorm(
out_c,
act=None,
param_attr=param_attr,
bias_attr=bias_attr,
use_global_stats=global_stats,
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance')
norm_params = self.bn.parameters()
if freeze_norm:
for param in norm_params:
param.stop_gradient = True
def forward(self, x):
x = self.conv(x)
......
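The MobileNetV3 variant additionally honours `freeze_norm`: scale and offset get a zero learning rate, the layer runs on its global (pretrained) statistics, and all BN parameters are frozen. A sketch of that path with placeholder values:

```python
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay

out_c, name, norm_decay, lr_mult = 64, "conv1", 0.0, 1.0  # placeholders
freeze_norm = True

norm_lr = 0. if freeze_norm else lr_mult
bn = nn.BatchNorm(
    out_c,
    act=None,
    param_attr=ParamAttr(name=name + "_bn_scale", learning_rate=norm_lr,
                         regularizer=L2Decay(norm_decay)),
    bias_attr=ParamAttr(name=name + "_bn_offset", learning_rate=norm_lr,
                        regularizer=L2Decay(norm_decay)),
    # frozen BN keeps using the global statistics during training
    use_global_stats=True if freeze_norm else False,
    moving_mean_name=name + '_bn_mean',
    moving_variance_name=name + '_bn_variance')

if freeze_norm:
    # stop gradients so scale/offset stay at their pretrained values
    for param in bn.parameters():
        param.stop_gradient = True
```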
......@@ -21,6 +21,7 @@ import paddle
import paddle.nn as nn
import paddle.optimizer as optimizer
from paddle.optimizer.lr import CosineAnnealingDecay
import paddle.regularizer as regularizer
from paddle import cos
......@@ -32,6 +33,42 @@ from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
@serializable
class CosineDecay(object):
"""
Cosine learning rate decay
Args:
max_epochs (int): max epochs for the training process.
if you combine cosine decay with warmup, it is recommended that
max_iters be much larger than the number of warmup iterations
"""
def __init__(self, max_epochs=1000, use_warmup=True):
self.max_epochs = max_epochs
self.use_warmup = use_warmup
def __call__(self,
base_lr=None,
boundary=None,
value=None,
step_per_epoch=None):
assert base_lr is not None, "either base LR or values should be provided"
max_iters = self.max_epochs * int(step_per_epoch)
if boundary is not None and value is not None and self.use_warmup:
for i in range(int(boundary[-1]), max_iters):
boundary.append(i)
decayed_lr = base_lr * 0.5 * (
math.cos(i * math.pi / max_iters) + 1)
value.append(decayed_lr)
return optimizer.lr.PiecewiseDecay(boundary, value)
return optimizer.lr.CosineAnnealingDecay(base_lr, T_max=max_iters)
@serializable
class PiecewiseDecay(object):
"""
......@@ -42,7 +79,11 @@ class PiecewiseDecay(object):
milestones (list): steps at which to decay learning rate
"""
def __init__(self, gamma=[0.1, 0.01], milestones=[8, 11]):
def __init__(self,
gamma=[0.1, 0.01],
milestones=[8, 11],
values=None,
use_warmup=True):
super(PiecewiseDecay, self).__init__()
if type(gamma) is not list:
self.gamma = []
......@@ -51,15 +92,26 @@ class PiecewiseDecay(object):
else:
self.gamma = gamma
self.milestones = milestones
self.values = values
self.use_warmup = use_warmup
def __call__(self,
base_lr=None,
boundary=None,
value=None,
step_per_epoch=None):
if boundary is not None:
if boundary is not None and self.use_warmup:
boundary.extend([int(step_per_epoch) * i for i in self.milestones])
else:
# do not use LinearWarmup
boundary = [int(step_per_epoch) * i for i in self.milestones]
# self.values is set directly in the config
if self.values is not None:
assert len(self.milestones) + 1 == len(self.values)
return optimizer.lr.PiecewiseDecay(boundary, self.values)
# value is computed by self.gamma
if value is not None:
for i in self.gamma:
value.append(base_lr * i)
......@@ -114,6 +166,11 @@ class LearningRate(object):
self.schedulers = schedulers
def __call__(self, step_per_epoch):
assert len(self.schedulers) >= 1
if not self.schedulers[0].use_warmup:
return self.schedulers[0](base_lr=self.base_lr,
step_per_epoch=step_per_epoch)
# TODO: split warmup & decay
# warmup
boundary, value = self.schedulers[1](self.base_lr)
......@@ -127,7 +184,6 @@ class LearningRate(object):
class OptimizerBuilder():
"""
Build optimizer handles
Args:
regularizer (object): an `Regularizer` instance
optimizer (object): an `Optimizer` instance
......
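Putting the scheduler pieces together, the new `use_warmup` switch lets `LearningRate.__call__` hand control straight to the first scheduler. A minimal usage sketch with the SSD-MobileNetv1 VOC values (import path and `step_per_epoch` are assumptions):

```python
from ppdet.optimizer import LearningRate, PiecewiseDecay  # import path assumed

lr_cfg = LearningRate(
    base_lr=0.001,
    schedulers=[PiecewiseDecay(milestones=[40, 60, 80, 100],
                               values=[0.001, 0.0005, 0.00025, 0.0001, 0.00001],
                               use_warmup=False)])
# With use_warmup disabled no LinearWarmup boundaries are stitched in;
# the result is a plain paddle.optimizer.lr.PiecewiseDecay scheduler.
scheduler = lr_cfg(step_per_epoch=250)  # illustrative value
```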