Unverified commit ea2f81d8, authored by W wangxinxin08, committed by GitHub

refine sync bn (#4361)

* refine sync bn

* fix bugs of batch_norm

* fix bugs while deploying and modify BatchNorm to BatchNorm2D

* param_attr -> weight_attr in BatchNorm2D

* modify BatchNorm to BatchNorm2D
Parent 3cf6e926
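The hunks below repeat one pattern across many backbone and neck layers: the legacy fluid-style nn.BatchNorm (param_attr, data_layout, a fused act argument, explicit use_global_stats=False) is replaced by nn.BatchNorm2D (weight_attr, data_format, no activation, use_global_stats defaulting to None), and per-layer nn.SyncBatchNorm branches are dropped in favor of a single model-level conversion in the trainer. A minimal sketch of the resulting layer shape; ConvBN is an illustrative module, not a class from this commit:

# Sketch only: ConvBN is illustrative, not part of PaddleDetection.
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay

class ConvBN(nn.Layer):
    def __init__(self, ch_in, ch_out, act='relu', freeze_norm=False):
        super().__init__()
        self.conv = nn.Conv2D(ch_in, ch_out, 3, padding=1, bias_attr=False)
        # BatchNorm2D takes weight_attr (the old BatchNorm used param_attr),
        # has no act argument, and with use_global_stats=None the train/eval
        # mode decides which statistics to use unless the norm is frozen.
        self.bn = nn.BatchNorm2D(
            ch_out,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
            use_global_stats=True if freeze_norm else None)
        self.act = act  # the activation is now applied explicitly in forward

    def forward(self, x):
        y = self.bn(self.conv(x))
        if self.act:
            y = getattr(F, self.act)(y)
        return y

Because every norm layer is now a plain BatchNorm2D, sync BN no longer needs to be built by hand inside each module; the trainer converts the whole model in one call, as the first hunk shows.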
@@ -336,6 +336,12 @@ class Trainer(object):
         assert self.mode == 'train', "Model not in 'train' mode"
         Init_mark = False
+        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
+                   self.cfg.use_gpu and self._nranks > 1)
+        if sync_bn:
+            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                self.model)
         model = self.model
         if self.cfg.get('fleet', False):
             model = fleet.distributed_model(model)
......
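For reference, a rough sketch of how the model-level conversion above behaves; the toy Sequential model and the sync_bn flag here are placeholders, not PaddleDetection's detector or config:

import paddle

# Toy model standing in for the detector that the config would build.
model = paddle.nn.Sequential(
    paddle.nn.Conv2D(3, 8, 3, padding=1),
    paddle.nn.BatchNorm2D(8),
    paddle.nn.ReLU())

# Mirrors the trainer check: convert only when sync_bn is requested and more
# than one rank is training; otherwise the plain BatchNorm2D layers are kept.
sync_bn = True  # stands in for cfg.norm_type == 'sync_bn' and cfg.use_gpu
if sync_bn and paddle.distributed.get_world_size() > 1:
    # Returns a new layer tree with every BatchNorm*D replaced by
    # nn.SyncBatchNorm, so the result must be assigned back, just as the
    # trainer reassigns self.model.
    model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)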
@@ -58,11 +58,8 @@ class ConvBNLayer(nn.Layer):
                 learning_rate=conv_lr, initializer=KaimingNormal()),
             bias_attr=False)
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(out_channels)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels, act=None, use_global_stats=False)
+        if norm_type in ['bn', 'sync_bn']:
+            self._batch_norm = nn.BatchNorm2D(out_channels)

     def forward(self, x):
         x = self._conv(x)
......
@@ -62,11 +62,11 @@ class ConvNormLayer(nn.Layer):
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
         bias_attr = ParamAttr(
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
-        global_stats = True if freeze_norm else False
+        global_stats = True if freeze_norm else None
         if norm_type in ['bn', 'sync_bn']:
-            self.norm = nn.BatchNorm(
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         elif norm_type == 'gn':
......
@@ -81,9 +81,9 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
-        self.bn = BatchNorm(
+        self.bn = BatchNorm2D(
             num_filters,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
             bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
         self.hardswish = nn.Hardswish()
......
@@ -56,11 +56,11 @@ class ConvNormLayer(nn.Layer):
             regularizer=L2Decay(norm_decay), )
         bias_attr = ParamAttr(
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
-        global_stats = True if freeze_norm else False
+        global_stats = True if freeze_norm else None
         if norm_type in ['bn', 'sync_bn']:
-            self.norm = nn.BatchNorm(
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats, )
         elif norm_type == 'gn':
@@ -582,7 +582,7 @@ class LiteHRNetModule(nn.Layer):
                                 stride=1,
                                 padding=0,
                                 bias=False, ),
-                            nn.BatchNorm(self.in_channels[i]),
+                            nn.BatchNorm2D(self.in_channels[i]),
                             nn.Upsample(
                                 scale_factor=2**(j - i), mode='nearest')))
                 elif j == i:
@@ -601,7 +601,7 @@ class LiteHRNetModule(nn.Layer):
                                 padding=1,
                                 groups=self.in_channels[j],
                                 bias=False, ),
-                            nn.BatchNorm(self.in_channels[j]),
+                            nn.BatchNorm2D(self.in_channels[j]),
                             L.Conv2d(
                                 self.in_channels[j],
                                 self.in_channels[i],
@@ -609,7 +609,7 @@ class LiteHRNetModule(nn.Layer):
                                 stride=1,
                                 padding=0,
                                 bias=False, ),
-                            nn.BatchNorm(self.in_channels[i])))
+                            nn.BatchNorm2D(self.in_channels[i])))
                 else:
                     conv_downsamples.append(
                         nn.Sequential(
@@ -621,7 +621,7 @@ class LiteHRNetModule(nn.Layer):
                                 padding=1,
                                 groups=self.in_channels[j],
                                 bias=False, ),
-                            nn.BatchNorm(self.in_channels[j]),
+                            nn.BatchNorm2D(self.in_channels[j]),
                             L.Conv2d(
                                 self.in_channels[j],
                                 self.in_channels[j],
@@ -629,7 +629,7 @@ class LiteHRNetModule(nn.Layer):
                                 stride=1,
                                 padding=0,
                                 bias=False, ),
-                            nn.BatchNorm(self.in_channels[j]),
+                            nn.BatchNorm2D(self.in_channels[j]),
                             nn.ReLU()))
                 fuse_layer.append(nn.Sequential(*conv_downsamples))
@@ -777,7 +777,7 @@ class LiteHRNet(nn.Layer):
                             padding=1,
                             groups=num_channels_pre_layer[i],
                             bias=False),
-                        nn.BatchNorm(num_channels_pre_layer[i]),
+                        nn.BatchNorm2D(num_channels_pre_layer[i]),
                         L.Conv2d(
                             num_channels_pre_layer[i],
                             num_channels_cur_layer[i],
@@ -785,7 +785,7 @@ class LiteHRNet(nn.Layer):
                             stride=1,
                             padding=0,
                             bias=False, ),
-                        nn.BatchNorm(num_channels_cur_layer[i]),
+                        nn.BatchNorm2D(num_channels_cur_layer[i]),
                         nn.ReLU()))
             else:
                 transition_layers.append(None)
@@ -802,7 +802,7 @@ class LiteHRNet(nn.Layer):
                             stride=2,
                             padding=1,
                             bias=False, ),
-                        nn.BatchNorm(num_channels_pre_layer[-1]),
+                        nn.BatchNorm2D(num_channels_pre_layer[-1]),
                         L.Conv2d(
                             num_channels_pre_layer[-1],
                             num_channels_cur_layer[i]
@@ -812,9 +812,9 @@ class LiteHRNet(nn.Layer):
                             stride=1,
                             padding=0,
                             bias=False, ),
-                        nn.BatchNorm(num_channels_cur_layer[i]
-                                     if j == i - num_branches_pre else
-                                     num_channels_pre_layer[-1]),
+                        nn.BatchNorm2D(num_channels_cur_layer[i]
+                                       if j == i - num_branches_pre else
+                                       num_channels_pre_layer[-1]),
                         nn.ReLU()))
             transition_layers.append(nn.Sequential(*conv_downsamples))
         return nn.LayerList(transition_layers)
......
@@ -59,16 +59,9 @@ class ConvBNLayer(nn.Layer):
         param_attr = ParamAttr(regularizer=L2Decay(norm_decay))
         bias_attr = ParamAttr(regularizer=L2Decay(norm_decay))
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(
+        if norm_type in ['sync_bn', 'bn']:
+            self._batch_norm = nn.BatchNorm2D(
                 out_channels, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels,
-                act=None,
-                param_attr=param_attr,
-                bias_attr=bias_attr,
-                use_global_stats=False)

     def forward(self, x):
         x = self._conv(x)
......
@@ -74,15 +74,11 @@ class ConvBNLayer(nn.Layer):
             learning_rate=norm_lr,
             regularizer=L2Decay(norm_decay),
             trainable=False if freeze_norm else True)
-        global_stats = True if freeze_norm else False
-        if norm_type == 'sync_bn':
-            self.bn = nn.SyncBatchNorm(
-                out_c, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self.bn = nn.BatchNorm(
+        global_stats = True if freeze_norm else None
+        if norm_type in ['sync_bn', 'bn']:
+            self.bn = nn.BatchNorm2D(
                 out_c,
-                act=None,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         norm_params = self.bn.parameters()
......
@@ -100,15 +100,11 @@ class ConvNormLayer(nn.Layer):
             regularizer=L2Decay(norm_decay),
             trainable=False if freeze_norm else True)
-        global_stats = True if freeze_norm else False
-        if norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(
-                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self.norm = nn.BatchNorm(
+        global_stats = True if freeze_norm else None
+        if norm_type in ['sync_bn', 'bn']:
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                act=None,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         norm_params = self.norm.parameters()
......
@@ -51,15 +51,17 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
-        self._batch_norm = BatchNorm(
+        self._batch_norm = BatchNorm2D(
             out_channels,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
-            bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
-            act=act)
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+        self.act = act

     def forward(self, inputs):
         y = self._conv(inputs)
         y = self._batch_norm(y)
+        if self.act:
+            y = getattr(F, self.act)(y)
         return y
......
@@ -174,12 +174,9 @@ class ConvNormLayer(nn.Layer):
         bias_attr = ParamAttr(
             learning_rate=norm_lr,
             regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
-        if norm_type == 'bn':
+        if norm_type in ['bn', 'sync_bn']:
             self.norm = nn.BatchNorm2D(
                 ch_out, weight_attr=param_attr, bias_attr=bias_attr)
-        elif norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(
-                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
         elif norm_type == 'gn':
             self.norm = nn.GroupNorm(
                 num_groups=norm_groups,
......
@@ -52,10 +52,8 @@ class SeparableConvLayer(nn.Layer):
         self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)

         # norm type
-        if self.norm_type == 'bn':
+        if self.norm_type in ['bn', 'sync_bn']:
             self.norm = nn.BatchNorm2D(self.out_channels)
-        elif self.norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(self.out_channels)
         elif self.norm_type == 'gn':
             self.norm = nn.GroupNorm(
                 num_groups=self.norm_groups, num_channels=self.out_channels)
......
@@ -54,11 +54,8 @@ class ConvBNLayer(nn.Layer):
                 learning_rate=conv_lr, initializer=KaimingNormal()),
             bias_attr=False)
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(out_channels)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels, act=None, use_global_stats=False)
+        if norm_type in ['sync_bn', 'bn']:
+            self._batch_norm = nn.BatchNorm2D(out_channels)

     def forward(self, x):
         x = self._conv(x)
......
@@ -50,10 +50,6 @@ def batch_norm(ch,
                freeze_norm=False,
                initializer=None,
                data_format='NCHW'):
-    if norm_type == 'sync_bn':
-        batch_norm = nn.SyncBatchNorm
-    else:
-        batch_norm = nn.BatchNorm2D

     norm_lr = 0. if freeze_norm else 1.
     weight_attr = ParamAttr(
@@ -66,11 +62,12 @@ def batch_norm(ch,
         regularizer=L2Decay(norm_decay),
         trainable=False if freeze_norm else True)

-    norm_layer = batch_norm(
-        ch,
-        weight_attr=weight_attr,
-        bias_attr=bias_attr,
-        data_format=data_format)
+    if norm_type in ['sync_bn', 'bn']:
+        norm_layer = nn.BatchNorm2D(
+            ch,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
+            data_format=data_format)

     norm_params = norm_layer.parameters()
     if freeze_norm:
......
@@ -21,7 +21,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.nn.initializer import Normal, Constant
 from paddle import ParamAttr
-from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Linear
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Linear
 from paddle.regularizer import L2Decay
 from paddle.nn.initializer import KaimingNormal, XavierNormal
 from ppdet.core.workspace import register
@@ -77,9 +77,9 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
-        self.bn = BatchNorm(
+        self.bn = BatchNorm2D(
             num_filters,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
             bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
         self.hardswish = nn.Hardswish()
......
@@ -55,12 +55,14 @@ class ConvBNLayer(nn.Layer):
             bias_attr=False,
             data_format=data_format)
-        self._batch_norm = nn.BatchNorm(
-            num_filters, act=act, data_layout=data_format)
+        self._batch_norm = nn.BatchNorm2D(num_filters, data_format=data_format)
+        self.act = act

     def forward(self, inputs):
         y = self._conv(inputs)
         y = self._batch_norm(y)
+        if self.act:
+            y = getattr(F, self.act)(y)
         return y
......