From fbf981d1e48d1c6879abc5377044f6f8e7dc51d9 Mon Sep 17 00:00:00 2001
From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com>
Date: Fri, 7 Jan 2022 18:00:54 +0800
Subject: [PATCH] Revert "fix solov2 infer (#4972)" (#5074)

This reverts commit 3e0e35853de01c316223099a255e8317b9ea1fb1.
---
 ppdet/engine/trainer.py                   |  6 +++---
 ppdet/modeling/architectures/meta_arch.py | 25 ++++---------------------
 ppdet/modeling/backbones/esnet.py         |  2 +-
 3 files changed, 8 insertions(+), 25 deletions(-)

diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index b8d3b2aca..360a71efc 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -34,7 +34,6 @@ from paddle.static import InputSpec
 from ppdet.optimizer import ModelEMA
 
 from ppdet.core.workspace import create
-from ppdet.modeling.architectures.meta_arch import BaseArch
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
 from ppdet.utils.visualizer import visualize_results, save_result
 from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
@@ -346,10 +345,11 @@ class Trainer(object):
         assert self.mode == 'train', "Model not in 'train' mode"
         Init_mark = False
 
-        sync_bn = (getattr(self.cfg, 'norm_type', None) in [None, 'sync_bn'] and
+        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
                    self.cfg.use_gpu and self._nranks > 1)
         if sync_bn:
-            self.model = BaseArch.convert_sync_batchnorm(self.model)
+            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                self.model)
 
         model = self.model
         if self.cfg.get('fleet', False):
diff --git a/ppdet/modeling/architectures/meta_arch.py b/ppdet/modeling/architectures/meta_arch.py
index 770f1fd19..d01c34735 100644
--- a/ppdet/modeling/architectures/meta_arch.py
+++ b/ppdet/modeling/architectures/meta_arch.py
@@ -70,7 +70,7 @@ class BaseArch(nn.Layer):
                 outs.append(self.get_pred())
 
             # multi-scale test
-            if len(outs) > 1:
+            if len(outs)>1:
                 out = self.merge_multi_scale_predictions(outs)
             else:
                 out = outs[0]
@@ -87,9 +87,7 @@ class BaseArch(nn.Layer):
             keep_top_k = self.bbox_post_process.nms.keep_top_k
             nms_threshold = self.bbox_post_process.nms.nms_threshold
         else:
-            raise Exception(
-                "Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now"
-            )
+            raise Exception("Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now")
 
         final_boxes = []
         all_scale_outs = paddle.concat([o['bbox'] for o in outs]).numpy()
@@ -98,11 +96,9 @@ class BaseArch(nn.Layer):
             if np.count_nonzero(idxs) == 0:
                 continue
             r = nms(all_scale_outs[idxs, 1:], nms_threshold)
-            final_boxes.append(
-                np.concatenate([np.full((r.shape[0], 1), c), r], 1))
+            final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1))
         out = np.concatenate(final_boxes)
-        out = np.concatenate(sorted(
-            out, key=lambda e: e[1])[-keep_top_k:]).reshape((-1, 6))
+        out = np.concatenate(sorted(out, key=lambda e: e[1])[-keep_top_k:]).reshape((-1, 6))
         out = {
             'bbox': paddle.to_tensor(out),
             'bbox_num': paddle.to_tensor(np.array([out.shape[0], ]))
@@ -124,16 +120,3 @@ class BaseArch(nn.Layer):
 
     def get_pred(self, ):
         raise NotImplementedError("Should implement get_pred method!")
-
-    @classmethod
-    def convert_sync_batchnorm(cls, layer):
-        layer_output = layer
-        if getattr(layer, 'norm_type', None) == 'sync_bn':
-            layer_output = nn.SyncBatchNorm.convert_sync_batchnorm(layer)
-        else:
-            for name, sublayer in layer.named_children():
-                layer_output.add_sublayer(name,
-                                          cls.convert_sync_batchnorm(sublayer))
-
-        del layer
-        return layer_output
diff --git a/ppdet/modeling/backbones/esnet.py b/ppdet/modeling/backbones/esnet.py
index 86c28655d..2b3f3c54a 100644
--- a/ppdet/modeling/backbones/esnet.py
+++ b/ppdet/modeling/backbones/esnet.py
@@ -20,7 +20,7 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import ParamAttr
-from paddle.nn import Conv2D, MaxPool2D, AdaptiveAvgPool2D
+from paddle.nn import Conv2D, MaxPool2D, AdaptiveAvgPool2D, BatchNorm
 from paddle.nn.initializer import KaimingNormal
 from paddle.regularizer import L2Decay
 
--
GitLab
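
Functionally, this revert drops the per-layer BaseArch.convert_sync_batchnorm helper (which only converted sublayers whose norm_type attribute was 'sync_bn') in favor of Paddle's built-in blanket conversion, and narrows the trainer's sync_bn condition so it no longer fires when norm_type is unset. A minimal sketch of the restored behavior follows; the ToyModel and ToyHead layers are hypothetical stand-ins, not part of this patch:

    import paddle.nn as nn

    class ToyHead(nn.Layer):
        # Hypothetical sublayer flagged for sync_bn via a norm_type
        # attribute, the marker the removed helper keyed on.
        def __init__(self):
            super().__init__()
            self.norm_type = 'sync_bn'
            self.bn = nn.BatchNorm2D(8)

    class ToyModel(nn.Layer):
        # Hypothetical top-level model; its own bn carries no norm_type flag.
        def __init__(self):
            super().__init__()
            self.head = ToyHead()
            self.bn = nn.BatchNorm2D(8)

    model = ToyModel()
    # Restored behavior: Paddle's built-in conversion replaces every
    # BatchNorm layer in the tree, regardless of any per-layer norm_type.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    print(type(model.bn).__name__)       # SyncBatchNorm
    print(type(model.head.bn).__name__)  # SyncBatchNorm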