Unverified · Commit ea2f81d8 · Authored by: wangxinxin08 · Committed by: GitHub

refine sync bn (#4361)

* refine sync bn

* fix bugs of batch_norm

* fix bugs while deploying and modify BatchNorm to BatchNorm2D

* param_attr -> weight_attr in BatchNorm2D

* modify BatchNorm to BatchNorm2D
Parent 3cf6e926
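Note on the overall pattern: the diffs below always construct plain `nn.BatchNorm2D` inside layer definitions, and the trainer converts the whole model with `paddle.nn.SyncBatchNorm.convert_sync_batchnorm` only when multi-GPU `sync_bn` training is actually requested, so exported/deploy graphs never contain `SyncBatchNorm`. A minimal sketch of that pattern (toy layer and flag names, not PaddleDetection code):

```python
import paddle.nn as nn


class TinyConvBN(nn.Layer):
    """Toy block: layers are always defined with plain BatchNorm2D."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Conv2D(ch_in, ch_out, 3, padding=1, bias_attr=False)
        self.bn = nn.BatchNorm2D(ch_out)

    def forward(self, x):
        return self.bn(self.conv(x))


model = TinyConvBN(3, 8)
# Stand-ins for cfg.norm_type / cfg.use_gpu / self._nranks in the trainer.
norm_type, use_gpu, nranks = 'sync_bn', True, 2
if norm_type == 'sync_bn' and use_gpu and nranks > 1:
    # Replaces every BatchNorm*D sublayer with nn.SyncBatchNorm.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
```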
@@ -336,6 +336,12 @@ class Trainer(object):
         assert self.mode == 'train', "Model not in 'train' mode"
         Init_mark = False
 
+        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
+                   self.cfg.use_gpu and self._nranks > 1)
+        if sync_bn:
+            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                self.model)
+
         model = self.model
         if self.cfg.get('fleet', False):
             model = fleet.distributed_model(model)
......
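Note: the conversion runs once at the start of `train()` and only when all three conditions hold (norm_type is `sync_bn`, GPU enabled, more than one rank), so single-card runs fall back to plain `BatchNorm2D` even with `norm_type: sync_bn` in the config. A quick sketch of the effect on a toy model (expected output shown as a comment):

```python
import paddle.nn as nn

m = nn.Sequential(nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8), nn.ReLU())
m = nn.SyncBatchNorm.convert_sync_batchnorm(m)
# The BatchNorm2D sublayer should now be a SyncBatchNorm instance:
print([type(layer).__name__ for layer in m.sublayers()])
# expected: ['Conv2D', 'SyncBatchNorm', 'ReLU']
```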
@@ -58,11 +58,8 @@ class ConvBNLayer(nn.Layer):
                 learning_rate=conv_lr, initializer=KaimingNormal()),
             bias_attr=False)
 
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(out_channels)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels, act=None, use_global_stats=False)
+        if norm_type in ['bn', 'sync_bn']:
+            self._batch_norm = nn.BatchNorm2D(out_channels)
 
     def forward(self, x):
         x = self._conv(x)
......
@@ -62,11 +62,11 @@ class ConvNormLayer(nn.Layer):
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
         bias_attr = ParamAttr(
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
-        global_stats = True if freeze_norm else False
+        global_stats = True if freeze_norm else None
         if norm_type in ['bn', 'sync_bn']:
-            self.norm = nn.BatchNorm(
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         elif norm_type == 'gn':
......
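Note: the legacy `nn.BatchNorm` needed an explicit boolean here, hence `False` when the norm was not frozen; `nn.BatchNorm2D` accepts `None`, meaning batch statistics in `train()` and running statistics in `eval()`, while `True` still pins the layer to its running statistics for `freeze_norm`. A minimal sketch of that branch, assuming a Paddle version whose `BatchNorm2D` exposes `use_global_stats`:

```python
import paddle.nn as nn


def make_bn(ch_out, freeze_norm=False):
    # True -> always use stored running mean/var (frozen statistics).
    # None -> batch statistics in train(), running statistics in eval().
    global_stats = True if freeze_norm else None
    bn = nn.BatchNorm2D(ch_out, use_global_stats=global_stats)
    if freeze_norm:
        # Also freeze the affine parameters, mirroring freeze_norm handling.
        for p in bn.parameters():
            p.stop_gradient = True
    return bn
```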
@@ -81,9 +81,9 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
 
-        self.bn = BatchNorm(
+        self.bn = BatchNorm2D(
             num_filters,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
             bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
 
         self.hardswish = nn.Hardswish()
......
@@ -56,11 +56,11 @@ class ConvNormLayer(nn.Layer):
             regularizer=L2Decay(norm_decay), )
         bias_attr = ParamAttr(
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
-        global_stats = True if freeze_norm else False
+        global_stats = True if freeze_norm else None
         if norm_type in ['bn', 'sync_bn']:
-            self.norm = nn.BatchNorm(
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats, )
         elif norm_type == 'gn':
@@ -582,7 +582,7 @@ class LiteHRNetModule(nn.Layer):
                                 stride=1,
                                 padding=0,
                                 bias=False, ),
-                            nn.BatchNorm(self.in_channels[i]),
+                            nn.BatchNorm2D(self.in_channels[i]),
                             nn.Upsample(
                                 scale_factor=2**(j - i), mode='nearest')))
                 elif j == i:
@@ -601,7 +601,7 @@ class LiteHRNetModule(nn.Layer):
                                         padding=1,
                                         groups=self.in_channels[j],
                                         bias=False, ),
-                                    nn.BatchNorm(self.in_channels[j]),
+                                    nn.BatchNorm2D(self.in_channels[j]),
                                     L.Conv2d(
                                         self.in_channels[j],
                                         self.in_channels[i],
@@ -609,7 +609,7 @@ class LiteHRNetModule(nn.Layer):
                                         stride=1,
                                         padding=0,
                                         bias=False, ),
-                                    nn.BatchNorm(self.in_channels[i])))
+                                    nn.BatchNorm2D(self.in_channels[i])))
                         else:
                             conv_downsamples.append(
                                 nn.Sequential(
@@ -621,7 +621,7 @@ class LiteHRNetModule(nn.Layer):
                                         padding=1,
                                         groups=self.in_channels[j],
                                         bias=False, ),
-                                    nn.BatchNorm(self.in_channels[j]),
+                                    nn.BatchNorm2D(self.in_channels[j]),
                                     L.Conv2d(
                                         self.in_channels[j],
                                         self.in_channels[j],
@@ -629,7 +629,7 @@ class LiteHRNetModule(nn.Layer):
                                         stride=1,
                                         padding=0,
                                         bias=False, ),
-                                    nn.BatchNorm(self.in_channels[j]),
+                                    nn.BatchNorm2D(self.in_channels[j]),
                                     nn.ReLU()))
 
                 fuse_layer.append(nn.Sequential(*conv_downsamples))
@@ -777,7 +777,7 @@ class LiteHRNet(nn.Layer):
                                 padding=1,
                                 groups=num_channels_pre_layer[i],
                                 bias=False),
-                            nn.BatchNorm(num_channels_pre_layer[i]),
+                            nn.BatchNorm2D(num_channels_pre_layer[i]),
                             L.Conv2d(
                                 num_channels_pre_layer[i],
                                 num_channels_cur_layer[i],
@@ -785,7 +785,7 @@ class LiteHRNet(nn.Layer):
                                 stride=1,
                                 padding=0,
                                 bias=False, ),
-                            nn.BatchNorm(num_channels_cur_layer[i]),
+                            nn.BatchNorm2D(num_channels_cur_layer[i]),
                             nn.ReLU()))
                 else:
                     transition_layers.append(None)
@@ -802,7 +802,7 @@ class LiteHRNet(nn.Layer):
                                 stride=2,
                                 padding=1,
                                 bias=False, ),
-                            nn.BatchNorm(num_channels_pre_layer[-1]),
+                            nn.BatchNorm2D(num_channels_pre_layer[-1]),
                             L.Conv2d(
                                 num_channels_pre_layer[-1],
                                 num_channels_cur_layer[i]
@@ -812,9 +812,9 @@ class LiteHRNet(nn.Layer):
                                 stride=1,
                                 padding=0,
                                 bias=False, ),
-                            nn.BatchNorm(num_channels_cur_layer[i]
-                                         if j == i - num_branches_pre else
-                                         num_channels_pre_layer[-1]),
+                            nn.BatchNorm2D(num_channels_cur_layer[i]
+                                           if j == i - num_branches_pre else
+                                           num_channels_pre_layer[-1]),
                             nn.ReLU()))
                 transition_layers.append(nn.Sequential(*conv_downsamples))
         return nn.LayerList(transition_layers)
......
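Note: the lite_hrnet.py hunks above are mechanical swaps; because these layers are built inside `nn.Sequential` with only a channel count, `nn.BatchNorm2D(ch)` is a drop-in replacement for the legacy `nn.BatchNorm(ch)`. A reduced sketch of one such block (toy channel count, plain `nn.Conv2D` instead of the file's `L.Conv2d` wrapper):

```python
import paddle.nn as nn

ch = 40  # example channel count
fuse = nn.Sequential(
    nn.Conv2D(ch, ch, kernel_size=1, bias_attr=False),
    nn.BatchNorm2D(ch),  # was nn.BatchNorm(ch); only the class name changes
    nn.ReLU())
```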
@@ -59,16 +59,9 @@ class ConvBNLayer(nn.Layer):
         param_attr = ParamAttr(regularizer=L2Decay(norm_decay))
         bias_attr = ParamAttr(regularizer=L2Decay(norm_decay))
 
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(
-                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels,
-                act=None,
-                param_attr=param_attr,
-                bias_attr=bias_attr,
-                use_global_stats=False)
+        if norm_type in ['sync_bn', 'bn']:
+            self._batch_norm = nn.BatchNorm2D(
+                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
 
     def forward(self, x):
         x = self._conv(x)
......
@@ -74,15 +74,11 @@ class ConvBNLayer(nn.Layer):
             learning_rate=norm_lr,
             regularizer=L2Decay(norm_decay),
             trainable=False if freeze_norm else True)
-        global_stats = True if freeze_norm else False
-        if norm_type == 'sync_bn':
-            self.bn = nn.SyncBatchNorm(
-                out_c, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self.bn = nn.BatchNorm(
+        global_stats = True if freeze_norm else None
+        if norm_type in ['sync_bn', 'bn']:
+            self.bn = nn.BatchNorm2D(
                 out_c,
-                act=None,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         norm_params = self.bn.parameters()
......
@@ -100,15 +100,11 @@ class ConvNormLayer(nn.Layer):
             regularizer=L2Decay(norm_decay),
             trainable=False if freeze_norm else True)
-        global_stats = True if freeze_norm else False
-        if norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(
-                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self.norm = nn.BatchNorm(
+        global_stats = True if freeze_norm else None
+        if norm_type in ['sync_bn', 'bn']:
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                act=None,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         norm_params = self.norm.parameters()
......
@@ -51,15 +51,17 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
 
-        self._batch_norm = BatchNorm(
+        self._batch_norm = BatchNorm2D(
             out_channels,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
-            bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
-            act=act)
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+        self.act = act
 
     def forward(self, inputs):
         y = self._conv(inputs)
         y = self._batch_norm(y)
+        if self.act:
+            y = getattr(F, self.act)(y)
         return y
......
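Note: `nn.BatchNorm2D` has no `act` argument, so the fused activation of the old `BatchNorm(..., act=act)` is reproduced by storing the activation name and applying the matching `paddle.nn.functional` function in `forward`. A minimal sketch of the pattern, assuming `act` names a function in `paddle.nn.functional` (e.g. 'relu' or 'hardswish'):

```python
import paddle.nn as nn
import paddle.nn.functional as F


class ConvBNAct(nn.Layer):
    def __init__(self, ch_in, ch_out, act=None):
        super().__init__()
        self.conv = nn.Conv2D(ch_in, ch_out, 3, padding=1, bias_attr=False)
        self.bn = nn.BatchNorm2D(ch_out)
        self.act = act  # e.g. 'relu'; must name a paddle.nn.functional function

    def forward(self, x):
        y = self.bn(self.conv(x))
        if self.act:
            y = getattr(F, self.act)(y)  # same lookup as in the hunk above
        return y
```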
@@ -174,12 +174,9 @@ class ConvNormLayer(nn.Layer):
         bias_attr = ParamAttr(
             learning_rate=norm_lr,
             regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
-        if norm_type == 'bn':
+        if norm_type in ['bn', 'sync_bn']:
             self.norm = nn.BatchNorm2D(
                 ch_out, weight_attr=param_attr, bias_attr=bias_attr)
-        elif norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(
-                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
         elif norm_type == 'gn':
             self.norm = nn.GroupNorm(
                 num_groups=norm_groups,
......
@@ -52,10 +52,8 @@ class SeparableConvLayer(nn.Layer):
         self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)
 
         # norm type
-        if self.norm_type == 'bn':
+        if self.norm_type in ['bn', 'sync_bn']:
             self.norm = nn.BatchNorm2D(self.out_channels)
-        elif self.norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(self.out_channels)
         elif self.norm_type == 'gn':
             self.norm = nn.GroupNorm(
                 num_groups=self.norm_groups, num_channels=self.out_channels)
......
@@ -54,11 +54,8 @@ class ConvBNLayer(nn.Layer):
                 learning_rate=conv_lr, initializer=KaimingNormal()),
             bias_attr=False)
 
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(out_channels)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels, act=None, use_global_stats=False)
+        if norm_type in ['sync_bn', 'bn']:
+            self._batch_norm = nn.BatchNorm2D(out_channels)
 
     def forward(self, x):
         x = self._conv(x)
......
@@ -50,10 +50,6 @@ def batch_norm(ch,
                freeze_norm=False,
                initializer=None,
                data_format='NCHW'):
-    if norm_type == 'sync_bn':
-        batch_norm = nn.SyncBatchNorm
-    else:
-        batch_norm = nn.BatchNorm2D
 
     norm_lr = 0. if freeze_norm else 1.
     weight_attr = ParamAttr(
@@ -66,11 +62,12 @@ def batch_norm(ch,
         regularizer=L2Decay(norm_decay),
         trainable=False if freeze_norm else True)
 
-    norm_layer = batch_norm(
-        ch,
-        weight_attr=weight_attr,
-        bias_attr=bias_attr,
-        data_format=data_format)
+    if norm_type in ['sync_bn', 'bn']:
+        norm_layer = nn.BatchNorm2D(
+            ch,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
+            data_format=data_format)
 
     norm_params = norm_layer.parameters()
     if freeze_norm:
......
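Note: with the layer-class selection removed, the `batch_norm` helper builds `nn.BatchNorm2D` directly. A condensed sketch of the resulting helper, assuming `norm_type` is 'bn' or 'sync_bn', and treating the truncated `if freeze_norm:` tail as the usual stop-gradient handling (that tail is an assumption, its body is cut off above):

```python
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay


def batch_norm(ch, norm_type='bn', norm_decay=0., freeze_norm=False,
               data_format='NCHW'):
    norm_lr = 0. if freeze_norm else 1.
    weight_attr = ParamAttr(
        learning_rate=norm_lr,
        regularizer=L2Decay(norm_decay),
        trainable=not freeze_norm)
    bias_attr = ParamAttr(
        learning_rate=norm_lr,
        regularizer=L2Decay(norm_decay),
        trainable=not freeze_norm)

    if norm_type in ['sync_bn', 'bn']:
        norm_layer = nn.BatchNorm2D(
            ch, weight_attr=weight_attr, bias_attr=bias_attr,
            data_format=data_format)

    if freeze_norm:
        # assumed tail of the truncated block: freeze the norm parameters
        for param in norm_layer.parameters():
            param.stop_gradient = True
    return norm_layer
```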
@@ -21,7 +21,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.nn.initializer import Normal, Constant
 from paddle import ParamAttr
-from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Linear
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Linear
 from paddle.regularizer import L2Decay
 from paddle.nn.initializer import KaimingNormal, XavierNormal
 from ppdet.core.workspace import register
@@ -77,9 +77,9 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
 
-        self.bn = BatchNorm(
+        self.bn = BatchNorm2D(
             num_filters,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
             bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
 
         self.hardswish = nn.Hardswish()
......
@@ -55,12 +55,14 @@ class ConvBNLayer(nn.Layer):
             bias_attr=False,
             data_format=data_format)
 
-        self._batch_norm = nn.BatchNorm(
-            num_filters, act=act, data_layout=data_format)
+        self._batch_norm = nn.BatchNorm2D(num_filters, data_format=data_format)
+        self.act = act
 
     def forward(self, inputs):
         y = self._conv(inputs)
         y = self._batch_norm(y)
+        if self.act:
+            y = getattr(F, self.act)(y)
         return y
......
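Note: the legacy `nn.BatchNorm` keyword is `data_layout`, while `nn.BatchNorm2D` calls the same option `data_format` (as used in the hunk above); the `getattr(F, self.act)` call also relies on `import paddle.nn.functional as F` being present in that file. A small sketch of the NHWC case:

```python
import paddle
import paddle.nn as nn

# BatchNorm2D normalizes over the channel axis selected by data_format.
bn_nhwc = nn.BatchNorm2D(8, data_format='NHWC')
x = paddle.randn([2, 32, 32, 8])  # N, H, W, C
y = bn_nhwc(x)
print(y.shape)  # expected: [2, 32, 32, 8]
```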