From 97935164fe608435f00c317c141cea5ec4e34df2 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Sun, 5 Mar 2023 08:48:32 +0000 Subject: [PATCH] use decorator to parse batch --- ppcls/arch/__init__.py | 26 ++++---------------------- ppcls/arch/backbone/base/__init__.py | 6 ++++++ 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 1ec112de..77b66c42 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -14,12 +14,14 @@ import sys import copy +import importlib import paddle.nn as nn from paddle.jit import to_static from paddle.static import InputSpec from . import backbone +from .backbone import * from .gears import build_gear from .utils import * from .backbone.base.theseus_layer import TheseusLayer @@ -36,11 +38,8 @@ def build_model(config, mode="train"): model_type = arch_config.pop("name") use_sync_bn = arch_config.pop("use_sync_bn", False) - if hasattr(backbone, model_type): - model = ClassModel(model_type, **arch_config) - else: - model = getattr(sys.modules[__name__], model_type)("ClassModel", - **arch_config) + mod = importlib.import_module(__name__) + model = getattr(mod, model_type)(**arch_config) if use_sync_bn: if config["Global"]["device"] == "gpu": @@ -73,23 +72,6 @@ def apply_to_static(config, model): return model -# TODO(gaotingquan): export model -class ClassModel(TheseusLayer): - def __init__(self, model_type, **config): - super().__init__() - if model_type == "ClassModel": - backbone_config = config["Backbone"] - backbone_name = backbone_config.pop("name") - else: - backbone_name = model_type - backbone_config = config - self.backbone = getattr(backbone, backbone_name)(**backbone_config) - - def forward(self, batch): - x, label = batch[0], batch[1] - return self.backbone(x) - - class RecModel(TheseusLayer): def __init__(self, **config): super().__init__() diff --git a/ppcls/arch/backbone/base/__init__.py b/ppcls/arch/backbone/base/__init__.py index e69de29b..7a1fec9b 100644 --- a/ppcls/arch/backbone/base/__init__.py +++ b/ppcls/arch/backbone/base/__init__.py @@ -0,0 +1,6 @@ +def clas_forward_decorator(forward_func): + def parse_batch_wrapper(model, batch): + x, label = batch[0], batch[1] + return forward_func(model, x) + + return parse_batch_wrapper -- GitLab