From 145d155623abde6ad9dcc9357c05dfa437ef42a6 Mon Sep 17 00:00:00 2001
From: Wenyu
Date: Thu, 16 Jun 2022 13:52:20 +0800
Subject: [PATCH] add layer_norm for convnext outputs (#6201)

---
 ppdet/modeling/backbones/convnext.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/ppdet/modeling/backbones/convnext.py b/ppdet/modeling/backbones/convnext.py
index a16c45c34..5b6f80157 100644
--- a/ppdet/modeling/backbones/convnext.py
+++ b/ppdet/modeling/backbones/convnext.py
@@ -141,6 +141,7 @@ class ConvNeXt(nn.Layer):
             layer_scale_init_value=1e-6,
             head_init_scale=1.,
             return_idx=[1, 2, 3],
+            norm_output=True,
             pretrained=None, ):
         super().__init__()
 
@@ -178,6 +179,14 @@ class ConvNeXt(nn.Layer):
         self.return_idx = return_idx
         self.dims = [dims[i] for i in return_idx]  # [::-1]
 
+        self.norm_output = norm_output
+        if norm_output:
+            self.norms = nn.LayerList([
+                LayerNorm(
+                    c, eps=1e-6, data_format="channels_first")
+                for c in self.dims
+            ])
+
         self.apply(self._init_weights)
         # self.head.weight.set_value(self.head.weight.numpy() * head_init_scale)
         # self.head.bias.set_value(self.head.weight.numpy() * head_init_scale)
@@ -202,9 +211,11 @@ class ConvNeXt(nn.Layer):
             x = self.stages[i](x)
             output.append(x)
 
-        output = [output[i] for i in self.return_idx]
+        outputs = [output[i] for i in self.return_idx]
+        if self.norm_output:
+            outputs = [self.norms[i](out) for i, out in enumerate(outputs)]
 
-        return output
+        return outputs
 
     def forward(self, x):
         x = self.forward_features(x['image'])
-- 
GitLab
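
Note on the change above: the patch routes each backbone feature map selected
by return_idx through its own channels-first LayerNorm before the features
reach the neck, so every scale leaves the backbone with zero-mean,
unit-variance channel statistics. The sketch below is a minimal,
self-contained Paddle illustration of that pattern, not ppdet's code:
ChannelsFirstLayerNorm and the demo channel widths/shapes are assumptions for
illustration, whereas the real patch reuses the module's existing LayerNorm
helper with data_format="channels_first".

    import paddle
    import paddle.nn as nn


    class ChannelsFirstLayerNorm(nn.Layer):
        """LayerNorm over the channel axis (axis 1) of an NCHW tensor."""

        def __init__(self, num_channels, eps=1e-6):
            super().__init__()
            self.eps = eps
            # Learnable per-channel scale and shift, as in nn.LayerNorm.
            self.weight = self.create_parameter(
                shape=[num_channels],
                default_initializer=nn.initializer.Constant(1.0))
            self.bias = self.create_parameter(
                shape=[num_channels],
                default_initializer=nn.initializer.Constant(0.0))

        def forward(self, x):
            # Normalize each spatial position across its channel vector.
            u = x.mean(axis=1, keepdim=True)
            s = ((x - u) ** 2).mean(axis=1, keepdim=True)
            x = (x - u) / paddle.sqrt(s + self.eps)
            return (self.weight.reshape([-1, 1, 1]) * x
                    + self.bias.reshape([-1, 1, 1]))


    # Mirror of the patch: one norm per returned scale, applied in order.
    dims = [192, 384, 768]  # assumed widths for return_idx=[1, 2, 3]
    norms = nn.LayerList([ChannelsFirstLayerNorm(c) for c in dims])

    # Dummy multi-scale features; the norm leaves shapes unchanged.
    feats = [paddle.randn([2, c, 56 // 2**i, 56 // 2**i])
             for i, c in enumerate(dims)]
    feats = [norms[i](f) for i, f in enumerate(feats)]
    print([tuple(f.shape) for f in feats])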