From 3e0e35853de01c316223099a255e8317b9ea1fb1 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Tue, 21 Dec 2021 21:24:12 +0800 Subject: [PATCH] fix solov2 infer (#4972) * support different norm_type for different module when convert to sync batch norm * remove BatchNorm in esnet --- ppdet/engine/trainer.py | 6 +++--- ppdet/modeling/architectures/meta_arch.py | 25 +++++++++++++++++++---- ppdet/modeling/backbones/esnet.py | 2 +- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 2e6afc328..24f846b28 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -33,6 +33,7 @@ from paddle.static import InputSpec from ppdet.optimizer import ModelEMA from ppdet.core.workspace import create +from ppdet.modeling.architectures.meta_arch import BaseArch from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.visualizer import visualize_results, save_result from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval @@ -340,11 +341,10 @@ class Trainer(object): assert self.mode == 'train', "Model not in 'train' mode" Init_mark = False - sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and + sync_bn = (getattr(self.cfg, 'norm_type', None) in [None, 'sync_bn'] and self.cfg.use_gpu and self._nranks > 1) if sync_bn: - self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm( - self.model) + self.model = BaseArch.convert_sync_batchnorm(self.model) model = self.model if self.cfg.get('fleet', False): diff --git a/ppdet/modeling/architectures/meta_arch.py b/ppdet/modeling/architectures/meta_arch.py index d01c34735..770f1fd19 100644 --- a/ppdet/modeling/architectures/meta_arch.py +++ b/ppdet/modeling/architectures/meta_arch.py @@ -70,7 +70,7 @@ class BaseArch(nn.Layer): outs.append(self.get_pred()) # multi-scale test - if len(outs)>1: + if len(outs) > 1: out = self.merge_multi_scale_predictions(outs) else: out = outs[0] @@ -87,7 +87,9 @@ class BaseArch(nn.Layer): keep_top_k = self.bbox_post_process.nms.keep_top_k nms_threshold = self.bbox_post_process.nms.nms_threshold else: - raise Exception("Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now") + raise Exception( + "Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now" + ) final_boxes = [] all_scale_outs = paddle.concat([o['bbox'] for o in outs]).numpy() @@ -96,9 +98,11 @@ class BaseArch(nn.Layer): if np.count_nonzero(idxs) == 0: continue r = nms(all_scale_outs[idxs, 1:], nms_threshold) - final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1)) + final_boxes.append( + np.concatenate([np.full((r.shape[0], 1), c), r], 1)) out = np.concatenate(final_boxes) - out = np.concatenate(sorted(out, key=lambda e: e[1])[-keep_top_k:]).reshape((-1, 6)) + out = np.concatenate(sorted( + out, key=lambda e: e[1])[-keep_top_k:]).reshape((-1, 6)) out = { 'bbox': paddle.to_tensor(out), 'bbox_num': paddle.to_tensor(np.array([out.shape[0], ])) @@ -120,3 +124,16 @@ class BaseArch(nn.Layer): def get_pred(self, ): raise NotImplementedError("Should implement get_pred method!") + + @classmethod + def convert_sync_batchnorm(cls, layer): + layer_output = layer + if getattr(layer, 'norm_type', None) == 'sync_bn': + layer_output = nn.SyncBatchNorm.convert_sync_batchnorm(layer) + else: + for name, sublayer in layer.named_children(): + layer_output.add_sublayer(name, + cls.convert_sync_batchnorm(sublayer)) + + del layer + return layer_output diff --git a/ppdet/modeling/backbones/esnet.py b/ppdet/modeling/backbones/esnet.py index 2b3f3c54a..86c28655d 100644 --- a/ppdet/modeling/backbones/esnet.py +++ b/ppdet/modeling/backbones/esnet.py @@ -20,7 +20,7 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle import ParamAttr -from paddle.nn import Conv2D, MaxPool2D, AdaptiveAvgPool2D, BatchNorm +from paddle.nn import Conv2D, MaxPool2D, AdaptiveAvgPool2D from paddle.nn.initializer import KaimingNormal from paddle.regularizer import L2Decay -- GitLab