# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import copy

import paddle.nn as nn
from paddle.jit import to_static
from paddle.static import InputSpec

from . import backbone as backbone_zoo
from .gears import build_gear
from .utils import *
from .backbone.base.theseus_layer import TheseusLayer
from ..utils import logger
from ..utils.save_load import load_dygraph_pretrain
from .slim import prune_model, quantize_model
from .distill.afd_attention import LinearTransformStudent, LinearTransformTeacher

__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]


def build_model(config, mode="train"):
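    """Build a model from the parsed `config` dict.

    A minimal sketch of the expected structure (the concrete values below
    are illustrative, not taken from this file):

        config = {
            "Global": {"device": "gpu"},
            "Arch": {"name": "ResNet50", "class_num": 1000},
        }
        model = build_model(config)

    `Arch.name` either names a backbone in `backbone_zoo` or one of the
    wrapper classes defined in this module (RecModel, DistillationModel,
    AttentionModel).
    """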
    arch_config = copy.deepcopy(config["Arch"])
    model_type = arch_config.pop("name")
    use_sync_bn = arch_config.pop("use_sync_bn", False)

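    # `model_type` either names a backbone in `backbone_zoo` (wrapped in
    # ClassModel below) or a wrapper class defined in this module.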
    if hasattr(backbone_zoo, model_type):
        model = ClassModel(model_type, **arch_config)
    else:
        model = getattr(sys.modules[__name__], model_type)("ClassModel",
                                                           **arch_config)

    if use_sync_bn:
        if config["Global"]["device"] == "gpu":
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        else:
            msg = "SyncBatchNorm can only be used on GPU device. The releated setting has been ignored."
            logger.warning(msg)

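    # prune_model / quantize_model modify the model in place when the
    # corresponding slim settings are enabled in the config.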
    if isinstance(model, TheseusLayer):
        prune_model(config, model)
        quantize_model(config, model, mode)

    # Set @to_static for benchmarking; skipped by default.
    model = apply_to_static(config, model)

    return model


def apply_to_static(config, model):
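    """Optionally wrap `model` with paddle.jit.to_static.

    Controlled by `Global.to_static` (off by default). When
    `Global.image_shape` is given, a static `InputSpec` with a leading
    batch dimension of `None` is derived from it.
    """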
    support_to_static = config['Global'].get('to_static', False)

    if support_to_static:
        specs = None
        if 'image_shape' in config['Global']:
            specs = [InputSpec([None] + config['Global']['image_shape'])]
            specs[0].stop_gradient = True
        model = to_static(model, input_spec=specs)
        logger.info("Successfully to apply @to_static with specs: {}".format(
            specs))
    return model


# TODO(gaotingquan): export model
class ClassModel(TheseusLayer):
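    """Plain classification wrapper around a `backbone_zoo` backbone.

    `model_type` is either a backbone name, or the literal "ClassModel",
    in which case the backbone is described by the nested "Backbone"
    sub-config.
    """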
    def __init__(self, model_type, **config):
        super().__init__()
        if model_type == "ClassModel":
            backbone_config = config["Backbone"]
            backbone_name = backbone_config.pop("name")
        else:
            backbone_name = model_type
            backbone_config = config
        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)

    def forward(self, batch):
        x, label = batch[0], batch[1]  # label is unused in classification
        return self.backbone(x)


class RecModel(TheseusLayer):
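    """Recognition model: backbone, optional neck and optional head.

    "BackboneStopLayer" truncates the backbone via `stop_after`; "Neck"
    and "Head" are built by `build_gear`. `forward` returns a dict holding
    the backbone output, the neck output (if any) and the head logits
    (if any).
    """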
    def __init__(self, **config):
        super().__init__()
        backbone_config = config["Backbone"]
        backbone_name = backbone_config.pop("name")
        self.backbone = getattr(backbone_zoo, backbone_name)(**backbone_config)
        self.head_feature_from = config.get("head_feature_from", "neck")

        if "BackboneStopLayer" in config:
            backbone_stop_layer = config["BackboneStopLayer"]["name"]
            self.backbone.stop_after(backbone_stop_layer)

        if "Neck" in config:
            self.neck = build_gear(config["Neck"])
        else:
            self.neck = None

        if "Head" in config:
            self.head = build_gear(config["Head"])
        else:
            self.head = None

    def forward(self, batch):
        x, label = batch[0], batch[1]
        out = dict()
        x = self.backbone(x)
        out["backbone"] = x
        if self.neck is not None:
            feat = self.neck(x)
            out["neck"] = feat
        out["features"] = out['neck'] if self.neck else x
        if self.head is not None:
            if self.head_feature_from == "backbone":
                y = self.head(out["backbone"], label)
            elif self.head_feature_from == "neck":
                y = self.head(out["features"], label)
            else:
                raise ValueError(
                    "head_feature_from must be 'backbone' or 'neck', got "
                    "'{}'".format(self.head_feature_from))
            out["logits"] = y
        return out


class DistillationModel(nn.Layer):
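    """Container holding several sub-models, e.g. a teacher and a student.

    `models` is a list of single-key dicts mapping a model name to its
    arch config. `freeze_params_list` freezes the matching model's
    parameters; `pretrained_list` loads per-model pretrained weights.
    `forward` returns a dict of outputs keyed by model name.
    """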
    def __init__(self,
                 models=None,
                 pretrained_list=None,
                 freeze_params_list=None,
                 **kwargs):
        super().__init__()
        assert isinstance(models, list)
        self.model_list = []
        self.model_name_list = []
        if pretrained_list is not None:
            assert len(pretrained_list) == len(models)

        if freeze_params_list is None:
            freeze_params_list = [False] * len(models)
        assert len(freeze_params_list) == len(models)
        for idx, model_config in enumerate(models):
            assert len(model_config) == 1
            key = list(model_config.keys())[0]
            model_config = model_config[key]
            model_name = model_config.pop("name")
            model = eval(model_name)(**model_config)

            if freeze_params_list[idx]:
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

        if pretrained_list is not None:
            for idx, pretrained in enumerate(pretrained_list):
                if pretrained is not None:
                    load_dygraph_pretrain(
                        self.model_list[idx], path=pretrained)

    def forward(self, batch):
        x, label = batch[0], batch[1]
        result_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            if label is None:
                result_dict[model_name] = self.model_list[idx](x)
            else:
                result_dict[model_name] = self.model_list[idx](x, label)
        return result_dict


class AttentionModel(DistillationModel):
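    """DistillationModel variant that chains its sub-models: each model
    consumes the previous model's output, and all outputs are merged into
    a single result dict.
    """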
    def __init__(self,
                 models=None,
                 pretrained_list=None,
                 freeze_params_list=None,
                 **kwargs):
        super().__init__(models, pretrained_list, freeze_params_list, **kwargs)

    def forward(self, batch):
        x, label = batch[0], batch[1]
        result_dict = dict()
        out = x
        for idx, model_name in enumerate(self.model_name_list):
            if label is None:
                out = self.model_list[idx](out)
            else:
                out = self.model_list[idx](out, label)
            result_dict.update(out)
        return result_dict