#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import copy
import importlib

import paddle.nn as nn

from . import backbone, gears
from .backbone import *
from .gears import build_gear
from .utils import *
from ppcls.utils.save_load import load_dygraph_pretrain

__all__ = ["build_model", "RecModel", "DistillationModel"]

def build_model(config):
    """Instantiate a model architecture from a config dict.

    Args:
        config (dict): architecture config. Must contain a "name" key
            naming a class reachable as an attribute of this module
            (populated by the wildcard imports above); all remaining
            keys are forwarded to that class's constructor.

    Returns:
        The constructed architecture instance.
    """
    # Deep-copy first so popping "name" never mutates the caller's dict.
    arch_config = copy.deepcopy(config)
    arch_name = arch_config.pop("name")
    # Resolve the class on this very module and build it.
    this_module = importlib.import_module(__name__)
    return getattr(this_module, arch_name)(**arch_config)


class RecModel(nn.Layer):
    """Recognition model: backbone -> optional neck -> optional head.

    Config keys (all passed as keyword arguments):
        Backbone (dict, required): backbone config; its "name" entry names
            a backbone class imported into this module.
        BackboneStopLayer (dict, optional): if given, the backbone is
            truncated after the layer named by its "name" entry.
        Neck (dict, optional): gear config for the neck; omitted -> None.
        Head (dict, optional): gear config for the head; omitted -> None.
    """

    def __init__(self, **config):
        super().__init__()
        bb_cfg = config["Backbone"]
        bb_name = bb_cfg.pop("name")
        # NOTE(review): eval() resolves the backbone class from this
        # module's globals; the config must come from trusted files only.
        self.backbone = eval(bb_name)(**bb_cfg)
        if "BackboneStopLayer" in config:
            # Cut the backbone short after the named layer.
            self.backbone.stop_after(config["BackboneStopLayer"]["name"])
        self.neck = build_gear(config["Neck"]) if "Neck" in config else None
        self.head = build_gear(config["Head"]) if "Head" in config else None

    def forward(self, x, label=None):
        """Run backbone (then neck/head when present).

        Returns:
            dict: {"features": neck-or-backbone output,
                   "logits": head output, or None when no head is set}.
        """
        feats = self.backbone(x)
        if self.neck is not None:
            feats = self.neck(feats)
        logits = self.head(feats, label) if self.head is not None else None
        return {"features": feats, "logits": logits}


class DistillationModel(nn.Layer):
    """Container of several sub-models (e.g. teacher/student) for distillation.

    Each sub-model is registered as a sublayer and can be individually
    frozen and/or initialized from pretrained weights.
    """

    def __init__(self,
                 models=None,
                 pretrained_list=None,
                 freeze_params_list=None,
                 **kargs):
        """
        Args:
            models (list): each element is a single-key dict mapping a
                model key (e.g. "Teacher") to that model's config; the
                config must contain a "name" entry naming a class
                importable from this module.
            pretrained_list (list|None): pretrained weight paths, one per
                model; None entries are skipped.
            freeze_params_list (list|None): booleans, one per model; True
                freezes that model's parameters. Defaults to all-False.
            **kargs: ignored; kept for config forward-compatibility.
        """
        super().__init__()
        assert isinstance(models, list)
        self.model_list = []
        self.model_name_list = []
        if pretrained_list is not None:
            assert len(pretrained_list) == len(models)

        if freeze_params_list is None:
            freeze_params_list = [False] * len(models)
        assert len(freeze_params_list) == len(models)
        for idx, model_config in enumerate(models):
            assert len(model_config) == 1
            key = list(model_config.keys())[0]
            model_config = model_config[key]
            model_name = model_config.pop("name")
            model = eval(model_name)(**model_config)

            if freeze_params_list[idx]:
                for param in model.parameters():
                    param.trainable = False
            # add_sublayer registers the model so its parameters are tracked.
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

        if pretrained_list is not None:
            for idx, pretrained in enumerate(pretrained_list):
                if pretrained is not None:
                    # Bug fix: load the weights into the model layer itself,
                    # not its name string (a str cannot receive a state dict).
                    load_dygraph_pretrain(
                        self.model_list[idx], path=pretrained)

    def forward(self, x, label=None):
        """Run every sub-model on x; returns {model_key: output}."""
        result_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            if label is None:
                result_dict[model_name] = self.model_list[idx](x)
            else:
                result_dict[model_name] = self.model_list[idx](x, label)
        return result_dict