From a9b843259790c0dd0911fbf09d95e1ab6b493d89 Mon Sep 17 00:00:00 2001 From: tianyi1997 <93087391+tianyi1997@users.noreply.github.com> Date: Fri, 10 Feb 2023 21:59:05 +0800 Subject: [PATCH] fix: wrong base class & simplify train func --- ppcls/arch/gears/metabnneck.py | 2 +- ppcls/engine/train/train_metabin.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ppcls/arch/gears/metabnneck.py b/ppcls/arch/gears/metabnneck.py index 7dd78cd8..d2f743da 100644 --- a/ppcls/arch/gears/metabnneck.py +++ b/ppcls/arch/gears/metabnneck.py @@ -22,7 +22,7 @@ import paddle.nn.functional as F from ..utils import get_param_attr_dict -class MetaBN1D(nn.BatchNorm2D): +class MetaBN1D(nn.BatchNorm1D): def forward(self, inputs, opt={}): mode = opt.get("bn_mode", "general") if self.training else "eval" if mode == "general": # update, but not apply running_mean/var diff --git a/ppcls/engine/train/train_metabin.py b/ppcls/engine/train/train_metabin.py index 21a4a786..b4994b69 100644 --- a/ppcls/engine/train/train_metabin.py +++ b/ppcls/engine/train/train_metabin.py @@ -141,15 +141,15 @@ def setup_opt(engine, stage): opt["bn_mode"] = "hold" opt["enable_inside_update"] = True opt["lr_gate"] = norm_lr * cyclic_lr - for name, layer in engine.model.backbone.named_sublayers(): - if "bn" == name.split('.')[-1]: + for layer in engine.model.backbone.sublayers(): + if type_name(layer) == "MetaBIN": layer.setup_opt(opt) engine.model.neck.setup_opt(opt) def reset_opt(model): - for name, layer in model.backbone.named_sublayers(): - if "bn" == name.split('.')[-1]: + for layer in model.backbone.sublayers(): + if type_name(layer) == "MetaBIN": layer.reset_opt() model.neck.reset_opt() -- GitLab