__init__.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import sys
import copy
import importlib

import paddle
import paddle.nn as nn
from paddle.jit import to_static
from paddle.static import InputSpec

from . import backbone
from .backbone import *
from .gears import build_gear
from .utils import *
from .backbone.base.theseus_layer import TheseusLayer
from ..utils import logger
from ..utils.save_load import load_dygraph_pretrain
from .slim import prune_model, quantize_model
from .distill.afd_attention import LinearTransformStudent, LinearTransformTeacher
from ..utils.amp import AMPForwardDecorator

__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]


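# Build the architecture described by config["Arch"]: the class named by
# "Arch.name" is instantiated with the remaining options, then optional
# SyncBatchNorm conversion, slim (prune/quantize), @to_static conversion and
# AMP decoration are applied.
#
# Illustrative config layout (the keys mirror what this function reads; the
# concrete names/values are examples only):
#   Arch:
#     name: RecModel
#     use_sync_bn: False
#     Backbone:
#       name: ResNet50
#     Head:
#       name: ArcMargin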
def build_model(config, mode="train"):
    arch_config = copy.deepcopy(config["Arch"])
    model_type = arch_config.pop("name")
    use_sync_bn = arch_config.pop("use_sync_bn", False)

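    # Look up the architecture class by name in this module's namespace (which
    # includes every backbone exported via `from .backbone import *` as well as
    # RecModel / DistillationModel / AttentionModel defined below) and
    # instantiate it with the remaining "Arch" options.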
    mod = importlib.import_module(__name__)
    model = getattr(mod, model_type)(**arch_config)

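    # Convert BatchNorm layers to SyncBatchNorm so statistics are synchronized
    # across devices; this is only supported on GPU.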
    if use_sync_bn:
        if config["Global"]["device"] == "gpu":
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        else:
            msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored."
            logger.warning(msg)

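    # Optionally prune and/or quantize the model according to the slim-related
    # settings in the config; both helpers act only on TheseusLayer models here.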
    if isinstance(model, TheseusLayer):
        prune_model(config, model)
        quantize_model(config, model, mode)

    # Apply @to_static for benchmarking; this is skipped by default.
    model = apply_to_static(config, model)

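    # Decorate the model for automatic mixed precision when an AMP level is
    # configured; parameters are still saved in float32.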
    if AMPForwardDecorator.amp_level:
        model = paddle.amp.decorate(
            models=model,
            level=AMPForwardDecorator.amp_level,
            save_dtype='float32')

    return model


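# Convert the dygraph model to a static graph with paddle.jit.to_static when
# Global.to_static is enabled (mainly used for benchmarking). If
# Global.image_shape is given, a fixed InputSpec is built so the graph is
# traced with a known input shape.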
def apply_to_static(config, model):
    support_to_static = config['Global'].get('to_static', False)

    if support_to_static:
        specs = None
        if 'image_shape' in config['Global']:
            specs = [InputSpec([None] + config['Global']['image_shape'])]
            specs[0].stop_gradient = True
        model = to_static(model, input_spec=specs)
        logger.info("Successfully to apply @to_static with specs: {}".format(
            specs))
    return model


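# Recognition model: a backbone (optionally truncated at "BackboneStopLayer")
# followed by optional "Neck" and "Head" gears. forward() returns a dict with
# the intermediate "backbone"/"neck"/"features" tensors and, when a head is
# present, the final "logits".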
class RecModel(TheseusLayer):
    def __init__(self, **config):
        super().__init__()
        backbone_config = config["Backbone"]
        backbone_name = backbone_config.pop("name")
        self.backbone = getattr(backbone, backbone_name)(**backbone_config)
        self.head_feature_from = config.get('head_feature_from', 'neck')

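        # Optionally truncate the backbone at a named layer (via
        # TheseusLayer.stop_after) so features come from an intermediate stage.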
        if "BackboneStopLayer" in config:
            backbone_stop_layer = config["BackboneStopLayer"]["name"]
            self.backbone.stop_after(backbone_stop_layer)

        if "Neck" in config:
            self.neck = build_gear(config["Neck"])
        else:
            self.neck = None

        if "Head" in config:
            self.head = build_gear(config["Head"])
        else:
            self.head = None

    def forward(self, x, label=None):
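        # Collect intermediate outputs: "backbone", optional "neck", "features"
        # (the neck output if a neck exists, otherwise the backbone output) and,
        # if a head is present, the head "logits".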
        out = dict()
        x = self.backbone(x)
        out["backbone"] = x
        if self.neck is not None:
            feat = self.neck(x)
            out["neck"] = feat
        out["features"] = out['neck'] if self.neck else x
        if self.head is not None:
            if self.head_feature_from == 'backbone':
                y = self.head(out['backbone'], label)
            elif self.head_feature_from == 'neck':
                y = self.head(out['features'], label)
            out["logits"] = y
        return out


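# Wraps a list of sub-models (typically a teacher and a student) for knowledge
# distillation. Each entry may load pretrained weights and have its parameters
# frozen; forward() returns a dict keyed by sub-model name.
#
# Illustrative "models" layout (the keys and model names are examples only):
#   models:
#     - Teacher:
#         name: ResNet50_vd
#     - Student:
#         name: MobileNetV3_small_x1_0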
class DistillationModel(nn.Layer):
    def __init__(self,
                 models=None,
                 pretrained_list=None,
                 freeze_params_list=None,
                 **kargs):
        super().__init__()
        assert isinstance(models, list)
        self.model_list = []
        self.model_name_list = []
        if pretrained_list is not None:
            assert len(pretrained_list) == len(models)

        if freeze_params_list is None:
            freeze_params_list = [False] * len(models)
        assert len(freeze_params_list) == len(models)
        for idx, model_config in enumerate(models):
            assert len(model_config) == 1
            key = list(model_config.keys())[0]
            model_config = model_config[key]
            model_name = model_config.pop("name")
            model = eval(model_name)(**model_config)

            if freeze_params_list[idx]:
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

        if pretrained_list is not None:
            for idx, pretrained in enumerate(pretrained_list):
                if pretrained is not None:
                    # Load the pretrained weights into the corresponding sub-model.
                    load_dygraph_pretrain(
                        self.model_list[idx], path=pretrained)

    def forward(self, x, label=None):
        result_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            if label is None:
                result_dict[model_name] = self.model_list[idx](x)
            else:
                result_dict[model_name] = self.model_list[idx](x, label)
        return result_dict


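# Variant of DistillationModel in which the sub-models are chained: each model
# consumes the previous model's output dict and all outputs are merged into a
# single result dict.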
class AttentionModel(DistillationModel):
    def __init__(self,
                 models=None,
                 pretrained_list=None,
                 freeze_params_list=None,
                 **kargs):
        super().__init__(models, pretrained_list, freeze_params_list, **kargs)

    def forward(self, x, label=None):
        result_dict = dict()
        out = x
        for idx, model_name in enumerate(self.model_name_list):
            if label is None:
                out = self.model_list[idx](out)
                result_dict.update(out)
            else:
                out = self.model_list[idx](out, label)
                result_dict.update(out)
        return result_dict