From 85ca658ea296fecdb0c8350a65f811dbaca41488 Mon Sep 17 00:00:00 2001 From: Bin Lu Date: Tue, 1 Jun 2021 11:30:26 +0800 Subject: [PATCH] Update __init__.py --- ppcls/arch/__init__.py | 47 ++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index c2826d6b..f241952d 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -18,11 +18,14 @@ import importlib import paddle.nn as nn from . import backbone +from . import head from .backbone import * -from ppcls.arch.loss_metrics.loss import * +from .head import * from .utils import * +__all__ = ["build_model", "RecModel"] + def build_model(config): config = copy.deepcopy(config) @@ -35,31 +38,31 @@ def build_model(config): class RecModel(nn.Layer): def __init__(self, **config): super().__init__() + backbone_config = config["Backbone"] backbone_name = backbone_config.pop("name") - self.backbone = getattr(backbone_name)(**backbone_config) - if "backbone_stop_layer" in config: - backbone_stop_layer = config["backbone_stop_layer"] - self.backbone.stop_layer(backbone_stop_layer) + self.backbone = eval(backbone_name)(**backbone_config) - if "Neck" in config: - neck_config = config["Neck"] - neck_name = neck_config.pop("name") - self.neck = getattr(neck_name)(**neck_config) + assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \ + please specified a Stoplayer config" + stop_layer_config = config["Stoplayer"] + self.backbone.stop_after(stop_layer_config["name"]) + + if stop_layer_config.get("embedding_size", 0) > 0: + self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"]) + embedding_size = stop_layer_config["embedding_size"] else: self.neck = None + embedding_size = stop_layer_config["output_dim"] + + assert "Head" in config, "Head should be specified in retrieval task \ + please specify a Head config" + config["Head"]["embedding_size"] = embedding_size + self.head = build_head(config["Head"]) - if "Head" in config: - head_config = config["Head"] - head_name = head_config.pop("name") - self.head = getattr(head_name)(**head_config) - else: - self.head = None - - def forward(self, x): - y = self.backbone(x) + def forward(self, x, label): + x = self.backbone(x) if self.neck is not None: - y = self.neck(y) - if self.head is not None: - y = self.head(y) - return y + x = self.neck(x) + y = self.head(x, label) + return {"features":x, "logits":y} -- GitLab