diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py
index 1ec112de5107a7d20dc2db54c15a1c680099a89a..77b66c421ab3ce1712c83e92fdbfb55abb0d2302 100644
--- a/ppcls/arch/__init__.py
+++ b/ppcls/arch/__init__.py
@@ -14,12 +14,14 @@
 import sys
 import copy
+import importlib

 import paddle.nn as nn
 from paddle.jit import to_static
 from paddle.static import InputSpec

 from . import backbone
+from .backbone import *
 from .gears import build_gear
 from .utils import *
 from .backbone.base.theseus_layer import TheseusLayer
@@ -36,11 +38,8 @@ def build_model(config, mode="train"):
     model_type = arch_config.pop("name")
     use_sync_bn = arch_config.pop("use_sync_bn", False)

-    if hasattr(backbone, model_type):
-        model = ClassModel(model_type, **arch_config)
-    else:
-        model = getattr(sys.modules[__name__], model_type)("ClassModel",
-                                                           **arch_config)
+    mod = importlib.import_module(__name__)
+    model = getattr(mod, model_type)(**arch_config)

     if use_sync_bn:
         if config["Global"]["device"] == "gpu":
@@ -73,23 +72,6 @@ def apply_to_static(config, model):
     return model


-# TODO(gaotingquan): export model
-class ClassModel(TheseusLayer):
-    def __init__(self, model_type, **config):
-        super().__init__()
-        if model_type == "ClassModel":
-            backbone_config = config["Backbone"]
-            backbone_name = backbone_config.pop("name")
-        else:
-            backbone_name = model_type
-            backbone_config = config
-        self.backbone = getattr(backbone, backbone_name)(**backbone_config)
-
-    def forward(self, batch):
-        x, label = batch[0], batch[1]
-        return self.backbone(x)
-
-
 class RecModel(TheseusLayer):
     def __init__(self, **config):
         super().__init__()
diff --git a/ppcls/arch/backbone/base/__init__.py b/ppcls/arch/backbone/base/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..7a1fec9b8b2b69786aabb7ac8000fc2adadff1ee 100644
--- a/ppcls/arch/backbone/base/__init__.py
+++ b/ppcls/arch/backbone/base/__init__.py
@@ -0,0 +1,6 @@
+def clas_forward_decorator(forward_func):
+    def parse_batch_wrapper(model, batch):
+        x, label = batch[0], batch[1]
+        return forward_func(model, x)
+
+    return parse_batch_wrapper
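The second hunk adds `clas_forward_decorator` to `ppcls/arch/backbone/base/__init__.py`, but nothing in this diff applies it yet. Below is a minimal sketch of how it is presumably meant to be used: wrapping a backbone's `forward` so the model itself can consume `(x, label)` batches, which is the per-batch unpacking the removed `ClassModel.forward` used to perform. `ToyBackbone`, its layer sizes, and the sample batch are hypothetical illustrations, not part of the diff.

```python
import paddle
import paddle.nn as nn

from ppcls.arch.backbone.base import clas_forward_decorator


class ToyBackbone(nn.Layer):
    """Hypothetical stand-in for a real ppcls backbone."""

    def __init__(self, class_num=10):
        super().__init__()
        self.fc = nn.Linear(8, class_num)

    # The decorator replaces forward with a wrapper that takes the whole
    # batch, keeps batch[0] (images), discards batch[1] (labels), and calls
    # the original forward with only the images.
    @clas_forward_decorator
    def forward(self, x):
        return self.fc(x)


model = ToyBackbone()
batch = [paddle.rand([4, 8]), paddle.zeros([4, 1], dtype="int64")]
logits = model(batch)   # wrapper forwards batch[0] to the original forward
print(logits.shape)     # [4, 10]
```

With this pattern, `build_model` no longer needs a dedicated `ClassModel` wrapper: it can resolve `model_type` directly via `importlib.import_module(__name__)` and instantiate the backbone, while the decorator handles batch unpacking at the model boundary.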