__init__.py 2.3 KB
Newer Older
W
WuHaobo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

L
littletomatodonkey 已提交
15 16 17 18 19
import copy
import importlib

import paddle.nn as nn

W
weishengyu 已提交
20
from . import backbone
W
weishengyu 已提交
21
from . import gears
W
WuHaobo 已提交
22

W
weishengyu 已提交
23
from .backbone import *
W
weishengyu 已提交
24
from .gears import *
W
WuHaobo 已提交
25
from .utils import *
L
littletomatodonkey 已提交
26

B
Bin Lu 已提交
27 28
# Public API of this package: only these names are exported by
# ``from <package> import *``, even though the star-imports above pull many
# more names into the module namespace.
__all__ = ["build_model", "RecModel"]

L
littletomatodonkey 已提交
29 30 31 32 33 34 35 36 37 38 39 40

def build_model(config):
    """Instantiate the architecture described by ``config``.

    The ``"name"`` entry selects an attribute of this module (looked up by
    name), and every remaining entry is forwarded to it as a keyword
    argument. The caller's ``config`` is deep-copied first, so it is never
    modified.

    Args:
        config: Mapping with a ``"name"`` key plus the keyword arguments of
            the selected architecture.

    Returns:
        The object produced by calling the selected attribute.

    Raises:
        KeyError: If ``config`` has no ``"name"`` entry.
        AttributeError: If this module has no attribute with that name.
    """
    params = copy.deepcopy(config)
    arch_name = params.pop("name")
    this_module = importlib.import_module(__name__)
    arch_factory = getattr(this_module, arch_name)
    return arch_factory(**params)


class RecModel(nn.Layer):
    """Model assembled from a backbone, an optional linear neck and a head.

    Keyword arguments (``**config``):
        Backbone: Dict with a ``"name"`` entry naming the backbone callable;
            the remaining entries are passed to it as keyword arguments.
        Stoplayer: Dict with ``"name"`` (passed to ``backbone.stop_after``),
            ``"output_dim"`` and an optional ``"embedding_size"``. When
            ``"embedding_size"`` is greater than 0, a linear neck mapping
            ``output_dim -> embedding_size`` is added.
        Head: Dict passed to ``build_head`` after ``"embedding_size"`` has
            been set to the feature size fed into the head.

    Raises:
        KeyError: If ``"Backbone"`` (or a required nested key) is missing.
        AssertionError: If ``"Stoplayer"`` or ``"Head"`` is missing.
    """

    def __init__(self, **config):
        super().__init__()

        # ``**config`` only copies the top level: the nested dicts are still
        # the caller's objects. They are mutated below (``pop("name")``,
        # ``embedding_size`` injection), so work on a private deep copy to
        # keep the caller's config reusable.
        config = copy.deepcopy(config)

        backbone_config = config["Backbone"]
        backbone_name = backbone_config.pop("name")
        # NOTE(review): eval() executes whatever expression the config
        # contains. This is only acceptable for trusted config files; consider
        # a name lookup restricted to the backbone package instead.
        self.backbone = eval(backbone_name)(**backbone_config)

        assert "Stoplayer" in config, (
            "Stoplayer should be specified in retrieval task, "
            "please specify a Stoplayer config")

        stop_layer_config = config["Stoplayer"]
        # presumably truncates the backbone at the named layer -- confirm
        # against the backbone's ``stop_after`` implementation.
        self.backbone.stop_after(stop_layer_config["name"])

        if stop_layer_config.get("embedding_size", 0) > 0:
            # Project backbone features to the requested embedding size.
            self.neck = nn.Linear(stop_layer_config["output_dim"],
                                  stop_layer_config["embedding_size"])
            embedding_size = stop_layer_config["embedding_size"]
        else:
            # No neck: the head consumes the backbone output directly.
            self.neck = None
            embedding_size = stop_layer_config["output_dim"]

        assert "Head" in config, (
            "Head should be specified in retrieval task, "
            "please specify a Head config")

        # The head's input size is whatever size the features end up with.
        config["Head"]["embedding_size"] = embedding_size
        self.head = build_head(config["Head"])

    def forward(self, x, label):
        """Run backbone, optional neck and head.

        Args:
            x: Input passed to the backbone.
            label: Passed to the head together with the features.

        Returns:
            Dict with ``"features"`` (backbone output, after the neck when
            one is configured) and ``"logits"`` (the head's output).
        """
        x = self.backbone(x)
        if self.neck is not None:
            x = self.neck(x)
        y = self.head(x, label)
        return {"features": x, "logits": y}