__init__.py 1.9 KB
Newer Older
D
dongshuilong 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
W
WuHaobo 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

L
littletomatodonkey 已提交
15 16 17 18 19
import copy
import importlib

import paddle.nn as nn

D
dongshuilong 已提交
20
from . import backbone, gears
W
weishengyu 已提交
21
from .backbone import *
D
dongshuilong 已提交
22
from .gears import build_gear
W
WuHaobo 已提交
23
from .utils import *
L
littletomatodonkey 已提交
24

B
Bin Lu 已提交
25 26
# Public API exported by `from <package> import *`. Backbone classes brought in
# by `from .backbone import *` are still module attributes, which is what
# build_model / RecModel resolve configured names against.
__all__ = ["build_model", "RecModel"]

L
littletomatodonkey 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40

def build_model(config):
    """Instantiate the architecture described by ``config``.

    ``config["name"]`` names a callable exposed as an attribute of this module
    (e.g. ``RecModel`` or anything pulled in by the star imports). Every other
    key is forwarded to it as a keyword argument. The mapping is deep-copied
    first, so the caller's config is never modified.
    """
    arch_kwargs = copy.deepcopy(config)
    arch_name = arch_kwargs.pop("name")
    package = importlib.import_module(__name__)
    return getattr(package, arch_name)(**arch_kwargs)


class RecModel(nn.Layer):
    """Generic model: backbone -> optional neck -> optional head.

    Keyword config consumed here:
        Backbone (dict): ``"name"`` names the backbone class, resolved in this
            module's namespace (a dotted path such as ``"backbone.X"`` is also
            accepted); the remaining keys are its constructor kwargs.
        BackboneStopLayer (dict, optional): its ``"name"`` is passed to
            ``self.backbone.stop_after``.
        Neck (optional): config handed to ``build_gear``.
        Head (optional): config handed to ``build_gear``.
    Any other keys are ignored by this class.
    """

    def __init__(self, **config):
        super().__init__()
        # Shallow-copy before popping "name": ``**config`` only copies the
        # top-level mapping, so popping from config["Backbone"] directly
        # mutated the caller's dict and made a second RecModel(**cfg) fail
        # with KeyError("name").
        backbone_config = dict(config["Backbone"])
        backbone_name = backbone_config.pop("name")
        self.backbone = self._resolve_backbone(backbone_name)(**backbone_config)

        if "BackboneStopLayer" in config:
            backbone_stop_layer = config["BackboneStopLayer"]["name"]
            self.backbone.stop_after(backbone_stop_layer)

        self.neck = build_gear(config["Neck"]) if "Neck" in config else None
        self.head = build_gear(config["Head"]) if "Head" in config else None

    @staticmethod
    def _resolve_backbone(name):
        """Look up a (possibly dotted) backbone name in this module's namespace.

        Replaces the former ``eval(name)``. Only a global-name lookup plus
        attribute access is performed, so a config string cannot execute
        arbitrary code. An unknown top-level name still raises NameError, as
        ``eval`` did.
        """
        head, *attrs = name.split(".")
        namespace = globals()
        if head not in namespace:
            raise NameError(f"backbone {name!r} is not defined in {__name__}")
        obj = namespace[head]
        for attr in attrs:
            obj = getattr(obj, attr)
        return obj

    def forward(self, x, label=None):
        """Run the backbone, then the neck and head when they are configured.

        Args:
            x: model input, passed straight to the backbone.
            label: optional target forwarded to the head as its second
                argument; unused when there is no head.

        Returns:
            dict: ``"features"`` is the backbone output (after the neck, if
            any); ``"logits"`` is the head output, or None without a head.
        """
        x = self.backbone(x)
        if self.neck is not None:
            x = self.neck(x)
        y = self.head(x, label) if self.head is not None else None
        return {"features": x, "logits": y}