From a4e1da6610000cb890e2d26008bc913353b12acc Mon Sep 17 00:00:00 2001 From: zhiboniu Date: Wed, 25 May 2022 08:13:38 +0000 Subject: [PATCH] modify attr export model --- ppcls/arch/backbone/legendary_models/resnet.py | 3 ++- ppcls/configs/Attr/StrongBaselineAttr.yaml | 3 +-- ppcls/engine/engine.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index ca75c2ea..40bc563b 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -287,7 +287,8 @@ class ResNet(TheseusLayer): data_format="NCHW", input_image_channel=3, return_patterns=None, - return_stages=None): + return_stages=None, + **kargs): super().__init__() self.cfg = config diff --git a/ppcls/configs/Attr/StrongBaselineAttr.yaml b/ppcls/configs/Attr/StrongBaselineAttr.yaml index 7501669b..2324015d 100644 --- a/ppcls/configs/Attr/StrongBaselineAttr.yaml +++ b/ppcls/configs/Attr/StrongBaselineAttr.yaml @@ -20,6 +20,7 @@ Arch: name: "ResNet50" pretrained: True class_num: 26 + infer_add_softmax: False # loss function config for traing/eval process Loss: @@ -110,5 +111,3 @@ DataLoader: Metric: Eval: - ATTRMetric: - - diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index ef24094c..d4924b26 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -457,7 +457,9 @@ class Engine(object): def export(self): assert self.mode == "export" - use_multilabel = self.config["Global"].get("use_multilabel", False) + use_multilabel = self.config["Global"].get( + "use_multilabel", + False) and not "ATTRMetric" in self.config["Metric"]["Eval"][0] model = ExportModel(self.config["Arch"], self.model, use_multilabel) if self.config["Global"]["pretrained_model"] is not None: load_dygraph_pretrain(model.base_model, -- GitLab