diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 455e74474d40daebf05d59ed419ba536a77a14d5..9d3774cfcc338e6ee21d42b19572ae1e5e5e54a8 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -336,6 +336,12 @@ class Trainer(object):
         assert self.mode == 'train', "Model not in 'train' mode"
         Init_mark = False
 
+        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
+                   self.cfg.use_gpu and self._nranks > 1)
+        if sync_bn:
+            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                self.model)
+
         model = self.model
         if self.cfg.get('fleet', False):
             model = fleet.distributed_model(model)
diff --git a/ppdet/modeling/backbones/blazenet.py b/ppdet/modeling/backbones/blazenet.py
index 425f2a86eb4c1a87ecbb8c0c40a169ce95d39223..fbfdcec9de9f6caa7c2ad68c4c828ba48c66b8dd 100644
--- a/ppdet/modeling/backbones/blazenet.py
+++ b/ppdet/modeling/backbones/blazenet.py
@@ -58,11 +58,8 @@ class ConvBNLayer(nn.Layer):
                 learning_rate=conv_lr, initializer=KaimingNormal()),
             bias_attr=False)
 
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(out_channels)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels, act=None, use_global_stats=False)
+        if norm_type in ['bn', 'sync_bn']:
+            self._batch_norm = nn.BatchNorm2D(out_channels)
 
     def forward(self, x):
         x = self._conv(x)
diff --git a/ppdet/modeling/backbones/hrnet.py b/ppdet/modeling/backbones/hrnet.py
index d92aa95f539f12dfd9fb9d12237cf5b70d6f7c2e..0f09aedcaf7bc3552fd322ab670b25ebbd543dd4 100644
--- a/ppdet/modeling/backbones/hrnet.py
+++ b/ppdet/modeling/backbones/hrnet.py
@@ -62,11 +62,11 @@ class ConvNormLayer(nn.Layer):
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
         bias_attr = ParamAttr(
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
-        global_stats = True if freeze_norm else False
+        global_stats = True if freeze_norm else None
         if norm_type in ['bn', 'sync_bn']:
-            self.norm = nn.BatchNorm(
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         elif norm_type == 'gn':
diff --git a/ppdet/modeling/backbones/lcnet.py b/ppdet/modeling/backbones/lcnet.py
index fd8ad4e46f5623165cc7cad09ab612465331554a..d4e3a2c15c31e0e2795caa7b4578c88cd2ee081d 100644
--- a/ppdet/modeling/backbones/lcnet.py
+++ b/ppdet/modeling/backbones/lcnet.py
@@ -81,9 +81,9 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
 
-        self.bn = BatchNorm(
+        self.bn = BatchNorm2D(
             num_filters,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
             bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
 
         self.hardswish = nn.Hardswish()
diff --git a/ppdet/modeling/backbones/lite_hrnet.py b/ppdet/modeling/backbones/lite_hrnet.py
index 52bad3cbb423ef6ddedd1f1e66e75a2cc61134a9..d6832c509afb4dcd3311d276eab18dcd9679bf43 100644
--- a/ppdet/modeling/backbones/lite_hrnet.py
+++ b/ppdet/modeling/backbones/lite_hrnet.py
@@ -56,11 +56,11 @@ class ConvNormLayer(nn.Layer):
             regularizer=L2Decay(norm_decay), )
         bias_attr = ParamAttr(
             learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
-        global_stats = True if freeze_norm else False
+        global_stats = True if freeze_norm else None
         if norm_type in ['bn', 'sync_bn']:
-            self.norm = nn.BatchNorm(
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats, )
         elif norm_type == 'gn':
@@ -582,7 +582,7 @@ class LiteHRNetModule(nn.Layer):
                             stride=1,
                             padding=0,
                             bias=False, ),
-                        nn.BatchNorm(self.in_channels[i]),
+                        nn.BatchNorm2D(self.in_channels[i]),
                         nn.Upsample(
                             scale_factor=2**(j - i), mode='nearest')))
                 elif j == i:
@@ -601,7 +601,7 @@
                             padding=1,
                             groups=self.in_channels[j],
                             bias=False, ),
-                        nn.BatchNorm(self.in_channels[j]),
+                        nn.BatchNorm2D(self.in_channels[j]),
                         L.Conv2d(
                             self.in_channels[j],
                             self.in_channels[i],
@@ -609,7 +609,7 @@
                             stride=1,
                             padding=0,
                             bias=False, ),
-                        nn.BatchNorm(self.in_channels[i])))
+                        nn.BatchNorm2D(self.in_channels[i])))
                 else:
                     conv_downsamples.append(
                         nn.Sequential(
@@ -621,7 +621,7 @@
                                 padding=1,
                                 groups=self.in_channels[j],
                                 bias=False, ),
-                            nn.BatchNorm(self.in_channels[j]),
+                            nn.BatchNorm2D(self.in_channels[j]),
                             L.Conv2d(
                                 self.in_channels[j],
                                 self.in_channels[j],
@@ -629,7 +629,7 @@
                                 stride=1,
                                 padding=0,
                                 bias=False, ),
-                            nn.BatchNorm(self.in_channels[j]),
+                            nn.BatchNorm2D(self.in_channels[j]),
                             nn.ReLU()))
 
             fuse_layer.append(nn.Sequential(*conv_downsamples))
@@ -777,7 +777,7 @@ class LiteHRNet(nn.Layer):
                         padding=1,
                         groups=num_channels_pre_layer[i],
                         bias=False),
-                    nn.BatchNorm(num_channels_pre_layer[i]),
+                    nn.BatchNorm2D(num_channels_pre_layer[i]),
                     L.Conv2d(
                         num_channels_pre_layer[i],
                         num_channels_cur_layer[i],
@@ -785,7 +785,7 @@
                        stride=1,
                        padding=0,
                        bias=False, ),
-                    nn.BatchNorm(num_channels_cur_layer[i]),
+                    nn.BatchNorm2D(num_channels_cur_layer[i]),
                     nn.ReLU()))
             else:
                 transition_layers.append(None)
@@ -802,7 +802,7 @@
                         stride=2,
                         padding=1,
                         bias=False, ),
-                    nn.BatchNorm(num_channels_pre_layer[-1]),
+                    nn.BatchNorm2D(num_channels_pre_layer[-1]),
                     L.Conv2d(
                         num_channels_pre_layer[-1],
                         num_channels_cur_layer[i]
@@ -812,9 +812,9 @@
                         stride=1,
                         padding=0,
                         bias=False, ),
-                    nn.BatchNorm(num_channels_cur_layer[i]
-                                 if j == i - num_branches_pre else
-                                 num_channels_pre_layer[-1]),
+                    nn.BatchNorm2D(num_channels_cur_layer[i]
+                                   if j == i - num_branches_pre else
+                                   num_channels_pre_layer[-1]),
                     nn.ReLU()))
             transition_layers.append(nn.Sequential(*conv_downsamples))
         return nn.LayerList(transition_layers)
diff --git a/ppdet/modeling/backbones/mobilenet_v1.py b/ppdet/modeling/backbones/mobilenet_v1.py
index 7b9fa80eb894efafea1e42eb62a495744b7c792d..a39435be5289b47ef4ad8ac73580d9fe4cb21d10 100644
--- a/ppdet/modeling/backbones/mobilenet_v1.py
+++ b/ppdet/modeling/backbones/mobilenet_v1.py
@@ -59,16 +59,9 @@ class ConvBNLayer(nn.Layer):
         param_attr = ParamAttr(regularizer=L2Decay(norm_decay))
         bias_attr = ParamAttr(regularizer=L2Decay(norm_decay))
 
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(
+        if norm_type in ['sync_bn', 'bn']:
+            self._batch_norm = nn.BatchNorm2D(
                 out_channels, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels,
-                act=None,
-                param_attr=param_attr,
-                bias_attr=bias_attr,
-                use_global_stats=False)
 
     def forward(self, x):
         x = self._conv(x)
diff --git a/ppdet/modeling/backbones/mobilenet_v3.py b/ppdet/modeling/backbones/mobilenet_v3.py
index 02021e87c036fbc1bc20f8a27eed1f6d964c7738..2bd88567a1487437a067ec68497ee9f3b62b4d47 100644
--- a/ppdet/modeling/backbones/mobilenet_v3.py
+++ b/ppdet/modeling/backbones/mobilenet_v3.py
@@ -74,15 +74,11 @@ class ConvBNLayer(nn.Layer):
             learning_rate=norm_lr,
             regularizer=L2Decay(norm_decay),
             trainable=False if freeze_norm else True)
-        global_stats = True if freeze_norm else False
-        if norm_type == 'sync_bn':
-            self.bn = nn.SyncBatchNorm(
-                out_c, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self.bn = nn.BatchNorm(
+        global_stats = True if freeze_norm else None
+        if norm_type in ['sync_bn', 'bn']:
+            self.bn = nn.BatchNorm2D(
                 out_c,
-                act=None,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         norm_params = self.bn.parameters()
diff --git a/ppdet/modeling/backbones/resnet.py b/ppdet/modeling/backbones/resnet.py
index d4bc878eacd117653ab10b3cbffc93e3ec7a879e..6f8eb0b89ccd3cab1d0a08b7fd2a41e6332c3055 100755
--- a/ppdet/modeling/backbones/resnet.py
+++ b/ppdet/modeling/backbones/resnet.py
@@ -100,15 +100,11 @@ class ConvNormLayer(nn.Layer):
             regularizer=L2Decay(norm_decay),
             trainable=False if freeze_norm else True)
 
-        global_stats = True if freeze_norm else False
-        if norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(
-                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
-        else:
-            self.norm = nn.BatchNorm(
+        global_stats = True if freeze_norm else None
+        if norm_type in ['sync_bn', 'bn']:
+            self.norm = nn.BatchNorm2D(
                 ch_out,
-                act=None,
-                param_attr=param_attr,
+                weight_attr=param_attr,
                 bias_attr=bias_attr,
                 use_global_stats=global_stats)
         norm_params = self.norm.parameters()
diff --git a/ppdet/modeling/backbones/shufflenet_v2.py b/ppdet/modeling/backbones/shufflenet_v2.py
index 59b0502a1b5a27cdfe1b0ac37c36276a22a14466..059b15ed7d8515609f291bf3946d4f75656f835d 100644
--- a/ppdet/modeling/backbones/shufflenet_v2.py
+++ b/ppdet/modeling/backbones/shufflenet_v2.py
@@ -51,15 +51,17 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
 
-        self._batch_norm = BatchNorm(
+        self._batch_norm = BatchNorm2D(
             out_channels,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
-            bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
-            act=act)
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+        self.act = act
 
     def forward(self, inputs):
         y = self._conv(inputs)
         y = self._batch_norm(y)
+        if self.act:
+            y = getattr(F, self.act)(y)
         return y
diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py
index 73da16a147234dd441b92627a9781038488e68ad..894fa3c8f7991df8b89737623fead0ff726cc945 100644
--- a/ppdet/modeling/layers.py
+++ b/ppdet/modeling/layers.py
@@ -174,12 +174,9 @@ class ConvNormLayer(nn.Layer):
         bias_attr = ParamAttr(
             learning_rate=norm_lr,
             regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
-        if norm_type == 'bn':
+        if norm_type in ['bn', 'sync_bn']:
             self.norm = nn.BatchNorm2D(
                 ch_out, weight_attr=param_attr, bias_attr=bias_attr)
-        elif norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(
-                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
         elif norm_type == 'gn':
             self.norm = nn.GroupNorm(
                 num_groups=norm_groups,
diff --git a/ppdet/modeling/necks/bifpn.py b/ppdet/modeling/necks/bifpn.py
index c607608930c866d36815838408c63dcfee58e8b1..9e794b8f50b92de6a98cc15ebcc3bca6cfaccf41 100644
--- a/ppdet/modeling/necks/bifpn.py
+++ b/ppdet/modeling/necks/bifpn.py
@@ -52,10 +52,8 @@ class SeparableConvLayer(nn.Layer):
         self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)
 
         # norm type
-        if self.norm_type == 'bn':
+        if self.norm_type in ['bn', 'sync_bn']:
             self.norm = nn.BatchNorm2D(self.out_channels)
-        elif self.norm_type == 'sync_bn':
-            self.norm = nn.SyncBatchNorm(self.out_channels)
         elif self.norm_type == 'gn':
             self.norm = nn.GroupNorm(
                 num_groups=self.norm_groups, num_channels=self.out_channels)
diff --git a/ppdet/modeling/necks/blazeface_fpn.py b/ppdet/modeling/necks/blazeface_fpn.py
index 18d7f3cf19cb4eea5728698cd03362cec548a5d1..b903c97b290b82d6c528cce8a4b205decd42e28b 100644
--- a/ppdet/modeling/necks/blazeface_fpn.py
+++ b/ppdet/modeling/necks/blazeface_fpn.py
@@ -54,11 +54,8 @@ class ConvBNLayer(nn.Layer):
                 learning_rate=conv_lr, initializer=KaimingNormal()),
             bias_attr=False)
 
-        if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(out_channels)
-        else:
-            self._batch_norm = nn.BatchNorm(
-                out_channels, act=None, use_global_stats=False)
+        if norm_type in ['sync_bn', 'bn']:
+            self._batch_norm = nn.BatchNorm2D(out_channels)
 
     def forward(self, x):
         x = self._conv(x)
diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py
index 294bcf2a12d52970b567a11a4079ed5bafc40f55..b157da5a269bcc1860db7cb18f92ed75c169aaa0 100644
--- a/ppdet/modeling/ops.py
+++ b/ppdet/modeling/ops.py
@@ -50,10 +50,6 @@ def batch_norm(ch,
                freeze_norm=False,
                initializer=None,
                data_format='NCHW'):
-    if norm_type == 'sync_bn':
-        batch_norm = nn.SyncBatchNorm
-    else:
-        batch_norm = nn.BatchNorm2D
 
     norm_lr = 0. if freeze_norm else 1.
     weight_attr = ParamAttr(
@@ -66,11 +62,12 @@
         regularizer=L2Decay(norm_decay),
         trainable=False if freeze_norm else True)
 
-    norm_layer = batch_norm(
-        ch,
-        weight_attr=weight_attr,
-        bias_attr=bias_attr,
-        data_format=data_format)
+    if norm_type in ['sync_bn', 'bn']:
+        norm_layer = nn.BatchNorm2D(
+            ch,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
+            data_format=data_format)
 
     norm_params = norm_layer.parameters()
     if freeze_norm:
diff --git a/ppdet/modeling/reid/pplcnet_embedding.py b/ppdet/modeling/reid/pplcnet_embedding.py
index cad9f85beea4baddfab429bb11302d20c17bf429..d360f89149d807069345e6255d86190d517376b3 100644
--- a/ppdet/modeling/reid/pplcnet_embedding.py
+++ b/ppdet/modeling/reid/pplcnet_embedding.py
@@ -21,7 +21,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.nn.initializer import Normal, Constant
 from paddle import ParamAttr
-from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Linear
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Linear
 from paddle.regularizer import L2Decay
 from paddle.nn.initializer import KaimingNormal, XavierNormal
 from ppdet.core.workspace import register
@@ -77,9 +77,9 @@
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
 
-        self.bn = BatchNorm(
+        self.bn = BatchNorm2D(
             num_filters,
-            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
             bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
 
         self.hardswish = nn.Hardswish()
diff --git a/ppdet/modeling/reid/resnet.py b/ppdet/modeling/reid/resnet.py
index 968fe9774f116c846cd372c7086dc9671d135b7c..c2261e0d0776b2823633d52c384bf416314d1e0b 100644
--- a/ppdet/modeling/reid/resnet.py
+++ b/ppdet/modeling/reid/resnet.py
@@ -55,12 +55,14 @@ class ConvBNLayer(nn.Layer):
             bias_attr=False,
             data_format=data_format)
 
-        self._batch_norm = nn.BatchNorm(
-            num_filters, act=act, data_layout=data_format)
+        self._batch_norm = nn.BatchNorm2D(num_filters, data_format=data_format)
+        self.act = act
 
     def forward(self, inputs):
         y = self._conv(inputs)
         y = self._batch_norm(y)
+        if self.act:
+            y = getattr(F, self.act)(y)
         return y
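
For reference, the per-layer nn.SyncBatchNorm branches removed above are replaced by a single model-wide conversion in Trainer.train(). The following is a minimal sketch of that pattern, not part of the patch; TinyConvBN and the sync_bn flag are hypothetical stand-ins for the ConvBNLayer blocks and the cfg.norm_type / use_gpu / nranks check in the trainer.

import paddle
import paddle.nn as nn


class TinyConvBN(nn.Layer):
    # Stand-in for the ConvBNLayer blocks above: always build plain BatchNorm2D.
    def __init__(self, ch_in=3, ch_out=8):
        super().__init__()
        self._conv = nn.Conv2D(ch_in, ch_out, 3, padding=1, bias_attr=False)
        self._batch_norm = nn.BatchNorm2D(ch_out)

    def forward(self, x):
        return self._batch_norm(self._conv(x))


model = TinyConvBN()
# Mirrors the new Trainer.train() check: convert only when sync_bn is requested
# for multi-card GPU training; otherwise BatchNorm2D layers are used unchanged.
sync_bn = True  # stand-in for: cfg.norm_type == 'sync_bn' and cfg.use_gpu and nranks > 1
if sync_bn:
    model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(model._batch_norm))  # SyncBatchNorm after conversion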