diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index c2826d6b1194cf38e52042c3fcd8a3ac0e5b8f2e..f241952d2f928aed9cf41e778dbc74741b4357c8 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -18,11 +18,14 @@ import importlib import paddle.nn as nn from . import backbone +from . import head from .backbone import * -from ppcls.arch.loss_metrics.loss import * +from .head import * from .utils import * +__all__ = ["build_model", "RecModel"] + def build_model(config): config = copy.deepcopy(config) @@ -35,31 +38,31 @@ def build_model(config): class RecModel(nn.Layer): def __init__(self, **config): super().__init__() + backbone_config = config["Backbone"] backbone_name = backbone_config.pop("name") - self.backbone = getattr(backbone_name)(**backbone_config) - if "backbone_stop_layer" in config: - backbone_stop_layer = config["backbone_stop_layer"] - self.backbone.stop_layer(backbone_stop_layer) + self.backbone = eval(backbone_name)(**backbone_config) - if "Neck" in config: - neck_config = config["Neck"] - neck_name = neck_config.pop("name") - self.neck = getattr(neck_name)(**neck_config) + assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \ + please specified a Stoplayer config" + stop_layer_config = config["Stoplayer"] + self.backbone.stop_after(stop_layer_config["name"]) + + if stop_layer_config.get("embedding_size", 0) > 0: + self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"]) + embedding_size = stop_layer_config["embedding_size"] else: self.neck = None + embedding_size = stop_layer_config["output_dim"] + + assert "Head" in config, "Head should be specified in retrieval task \ + please specify a Head config" + config["Head"]["embedding_size"] = embedding_size + self.head = build_head(config["Head"]) - if "Head" in config: - head_config = config["Head"] - head_name = head_config.pop("name") - self.head = getattr(head_name)(**head_config) - else: - self.head = None - - def forward(self, x): - y = self.backbone(x) + def forward(self, x, label): + x = self.backbone(x) if self.neck is not None: - y = self.neck(y) - if self.head is not None: - y = self.head(y) - return y + x = self.neck(x) + y = self.head(x, label) + return {"features":x, "logits":y}