#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import decorator
import logging
import numbers
import paddle
from ...common import get_logger
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
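# NOTE: get_paddle_version() is assumed here to return 185 for the 1.8.5-style
# paddle.fluid dygraph APIs; any other value selects the paddle.nn APIs below.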
if pd_ver == 185:
    import paddle.fluid.dygraph.nn as nn
    from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
    from paddle.fluid import ParamAttr
    from .layers_old import *
    from . import layers_old as layers
    Layer = paddle.fluid.dygraph.Layer
else:
    import paddle.nn as nn
    from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
    from paddle import ParamAttr
    from .layers import *
    from . import layers
    Layer = paddle.nn.Layer

_logger = get_logger(__name__, level=logging.INFO)

__all__ = ['supernet', 'Convert']

WEIGHT_LAYER = ['conv', 'linear', 'embedding']
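# A sublayer is treated as a weight layer when its lowercased class name
# contains one of these substrings (see the first/last weight-layer scan below).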


class Convert:
    """
    Convert network to the supernet according to the search space.
    Parameters:
        context(paddleslim.nas.ofa.supernet): search space defined by the user.
    Examples:
        .. code-block:: python
          from paddleslim.nas.ofa import supernet, Convert
          sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
          convert = Convert(sp_net_config)
    """

    def __init__(self, context):
        self.context = context

    def _change_name(self, layer, pd_ver, has_bias=True, conv=False):
        if conv:
            w_attr = layer._param_attr
        else:
            w_attr = layer._param_attr if pd_ver == 185 else layer._weight_attr
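        # Prefix the parameter name with 'super_' so supernet parameters do not
        # collide with the original ones, e.g. a ParamAttr named 'conv1_w'
        # (a hypothetical name) becomes 'super_conv1_w'.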

        if isinstance(w_attr, ParamAttr):
            w_attr.name = 'super_' + w_attr.name

        if has_bias and isinstance(layer._bias_attr, ParamAttr):
            layer._bias_attr.name = 'super_' + layer._bias_attr.name

    def convert(self, network):
        """
        The function to convert the network to a supernet.
        Parameters:
            network(paddle.nn.Layer|list(paddle.nn.Layer)): instance of the model or list of instance of layers.
        Examples:
            .. code-block:: python
              from paddle.vision.models import mobilenet_v1
              from paddleslim.nas.ofa import supernet, Convert
              sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
              convert = Convert(sp_net_config).convert(mobilenet_v1())
        """
        # Find the first and last weight layers: don't change the input channel
        # of the first weight layer or the output channel of the last one.
        model = []
        if isinstance(network, Layer):
            for name, sublayer in network.named_sublayers():
                model.append(sublayer)
        else:
            model = network

        first_weight_layer_idx = -1
        last_weight_layer_idx = -1
        weight_layer_count = 0
        # NOTE: pre_channel is stored for shortcut modules
        pre_channel = None
        cur_channel = None
        for idx, layer in enumerate(model):
            cls_name = layer.__class__.__name__.lower()
            if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
                weight_layer_count += 1
                last_weight_layer_idx = idx
                if first_weight_layer_idx == -1:
                    first_weight_layer_idx = idx

        if getattr(self.context, 'channel', None) != None:
            assert len(
                self.context.channel
            ) == weight_layer_count, "length of channel must be the same as the number of weight layers."
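        # Each entry of `channel` is consumed by one weight layer below and is an
        # iterable of candidate channel numbers, e.g. channel=((4, 8), (8, 16))
        # (a hypothetical config) for a network with exactly two weight layers.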

        for idx, layer in enumerate(model):
            if isinstance(layer, Conv2D):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']

                new_attr_name = [
                    'stride', 'padding', 'dilation', 'groups', 'bias_attr'
                ]
                if pd_ver == 185:
                    new_attr_name += ['param_attr', 'use_cudnn', 'act', 'dtype']
                else:
                    new_attr_name += [
                        'weight_attr', 'data_format', 'padding_mode'
                    ]

                self._change_name(layer, pd_ver, conv=True)
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                    new_attr_dict['num_filters'] = None
                    new_attr_dict['filter_size'] = None
                else:
                    new_attr_dict['in_channels'] = None
                    new_attr_dict['out_channels'] = None
                    new_attr_dict['kernel_size'] = None
                self.kernel_size = getattr(self.context, 'kernel_size', None)

                # if the kernel_size of conv is 1, don't change it.
                fks = '_filter_size' if '_filter_size' in attr_dict.keys(
                ) else '_kernel_size'

                ks = [attr_dict[fks]] if isinstance(
                    attr_dict[fks], numbers.Integral) else attr_dict[fks]

                if self.kernel_size and int(ks[0]) != 1:
                    new_attr_dict['transform_kernel'] = True
                    new_attr_dict[fks[1:]] = max(self.kernel_size)
                    new_attr_dict['candidate_config'].update({
                        'kernel_size': self.kernel_size
                    })
                else:
                    new_attr_dict[fks[1:]] = attr_dict[fks]
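                # At this point candidate_config holds the kernel search space,
                # e.g. {'kernel_size': (3, 5, 7)} for supernet(kernel_size=(3, 5, 7));
                # the expand_ratio/channel options are added below.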

                in_key = '_num_channels' if '_num_channels' in attr_dict.keys(
                ) else '_in_channels'
                out_key = '_num_filters' if '_num_filters' in attr_dict.keys(
                ) else '_out_channels'
                if self.context.expand:
                    ### first super convolution
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])

                    ### last super convolution
                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    if attr_dict['_groups'] != None and (
                            int(attr_dict['_groups']) == int(attr_dict[in_key])
                    ):
                        ### depthwise conv: use the previous layer's channel as cur_channel
                        _logger.warning(
                        "If the convolution is a depthwise conv, its output channel is changed" \
                        " to match the input channel, so the output channel in the search space is not used."
                        )
                        cur_channel = pre_channel
                    else:
                        cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    new_attr_dict[out_key[1:]] = attr_dict[out_key]

                for attr in new_attr_name:
                    if attr == 'weight_attr':
                        new_attr_dict[attr] = attr_dict['_param_attr']
                    else:
                        new_attr_dict[attr] = attr_dict['_' + attr]

                del layer

                if attr_dict['_groups'] == None or int(attr_dict[
                        '_groups']) == 1:
                    ### standard conv
                    layer = Block(SuperConv2D(**new_attr_dict), key=key)
                elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
                    # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
                    # channel in candidate_config = in_channel_list
                    if 'channel' in new_attr_dict['candidate_config']:
                        new_attr_dict[in_key[1:]] = max(cur_channel)
                        new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
                        new_attr_dict['candidate_config'][
                            'channel'] = cur_channel
                    new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
                    layer = Block(
                        SuperDepthwiseConv2D(**new_attr_dict), key=key)
                else:
                    ### group conv
                    layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(layer,
                            getattr(nn, 'BatchNorm2D', nn.BatchNorm)) and (
                                getattr(self.context, 'expand', None) != None or
                                getattr(self.context, 'channel', None) != None):
                # num_features of BatchNorm doesn't change after the last weight layer
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                new_attr_name = ['momentum', 'epsilon', 'bias_attr']

                if pd_ver == 185:
                    new_attr_name += [
                        'param_attr', 'act', 'dtype', 'in_place', 'data_layout',
                        'is_test', 'use_global_stats', 'trainable_statistics'
                    ]
                else:
                    new_attr_name += ['weight_attr', 'data_format', 'name']

                self._change_name(layer, pd_ver)
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                else:
                    new_attr_dict['num_features'] = None
                new_key = 'num_channels' if 'num_channels' in new_attr_dict.keys(
                ) else 'num_features'
                if self.context.expand:
                    new_attr_dict[new_key] = int(
                        self.context.expand *
                        layer._parameters['weight'].shape[0])
                elif self.context.channel:
                    new_attr_dict[new_key] = max(cur_channel)
                else:
                    new_attr_dict[new_key] = attr_dict[
                        '_num_channels'] if '_num_channels' in attr_dict.keys(
                        ) else attr_dict['_num_features']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                layer = layers.SuperBatchNorm(
                    **new_attr_dict
                ) if pd_ver == 185 else layers.SuperBatchNorm2D(**new_attr_dict)
                model[idx] = layer

            ### assume output_size == None and filter_size != None
            ### NOTE: output_size != None may raise an error; handle it when it happens.
            elif isinstance(layer, Conv2DTranspose):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']

                new_attr_name = [
                    'stride', 'padding', 'dilation', 'groups', 'bias_attr'
                ]
                assert attr_dict.get(
                    '_filter_size', attr_dict.get('_kernel_size')
                ) != None, "Conv2DTranspose only supports kernel_size != None for now"

                if pd_ver == 185:
                    new_attr_name += [
                        'output_size', 'param_attr', 'use_cudnn', 'act', 'dtype'
                    ]
                else:
                    new_attr_name += [
                        'output_padding', 'weight_attr', 'data_format'
                    ]

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                    new_attr_dict['num_filters'] = None
                    new_attr_dict['filter_size'] = None
                else:
                    new_attr_dict['in_channels'] = None
                    new_attr_dict['out_channels'] = None
                    new_attr_dict['kernel_size'] = None

                self._change_name(layer, pd_ver, conv=True)
                self.kernel_size = getattr(self.context, 'kernel_size', None)

                # if the kernel_size of conv transpose is 1, don't change it.
                fks = '_filter_size' if '_filter_size' in attr_dict.keys(
                ) else '_kernel_size'
                ks = [attr_dict[fks]] if isinstance(
                    attr_dict[fks], numbers.Integral) else attr_dict[fks]

                if self.kernel_size and int(ks[0]) != 1:
                    new_attr_dict['transform_kernel'] = True
                    new_attr_dict[fks[1:]] = max(self.kernel_size)
                    new_attr_dict['candidate_config'].update({
                        'kernel_size': self.kernel_size
                    })
                else:
                    new_attr_dict[fks[1:]] = attr_dict[fks]

                in_key = '_num_channels' if '_num_channels' in attr_dict.keys(
                ) else '_in_channels'
                out_key = '_num_filters' if '_num_filters' in attr_dict.keys(
                ) else '_out_channels'
                if self.context.expand:
                    ### first super convolution transpose
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])
                    ### last super convolution transpose
                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    if attr_dict['_groups'] != None and (
                            int(attr_dict['_groups']) == int(attr_dict[in_key])
                    ):
                        ### depthwise conv_transpose: use the previous layer's channel as cur_channel
                        _logger.warning(
                        "If the convolution is a depthwise conv_transpose, its output channel " \
                        "is changed to match the input channel, so the output channel in the search space is not used."
                        )
                        cur_channel = pre_channel
                    else:
                        cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    new_attr_dict[out_key[1:]] = attr_dict[out_key]

                for attr in new_attr_name:
                    if attr == 'weight_attr':
                        new_attr_dict[attr] = attr_dict['_param_attr']
                    elif attr == 'output_padding':
                        new_attr_dict[attr] = attr_dict[attr]
                    else:
                        new_attr_dict[attr] = attr_dict['_' + attr]

                del layer

                if new_attr_dict.get('output_size', None) == []:
                    new_attr_dict['output_size'] = None

                if attr_dict['_groups'] == None or int(attr_dict[
                        '_groups']) == 1:
                    ### standard conv_transpose
                    layer = Block(
                        SuperConv2DTranspose(**new_attr_dict), key=key)
                elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
                    # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
                    # channel in candidate_config = in_channel_list
                    if 'channel' in new_attr_dict['candidate_config']:
                        new_attr_dict[in_key[1:]] = max(cur_channel)
                        new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
                        new_attr_dict['candidate_config'][
                            'channel'] = cur_channel
                    new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
                    layer = Block(
                        SuperDepthwiseConv2DTranspose(**new_attr_dict), key=key)
                else:
                    ### group conv_transpose
                    layer = Block(
                        SuperGroupConv2DTranspose(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(layer, Linear) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']
                if pd_ver == 185:
                    new_attr_name = ['act', 'dtype']
                else:
                    new_attr_name = ['weight_attr', 'bias_attr']
                in_nc, out_nc = layer._parameters['weight'].shape
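                # The Linear weight is stored with shape [in_features, out_features]
                # (or [input_dim, output_dim] in 1.8.5), hence the unpack order.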

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['input_dim'] = None
                    new_attr_dict['output_dim'] = None
                else:
                    new_attr_dict['in_features'] = None
                    new_attr_dict['out_features'] = None

                in_key = '_input_dim' if pd_ver == 185 else '_in_features'
                out_key = '_output_dim' if pd_ver == 185 else '_out_features'
                attr_dict[in_key] = in_nc
                attr_dict[out_key] = out_nc
                if self.context.expand:
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    new_attr_dict[out_key[1:]] = int(attr_dict[out_key])

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                layer = Block(SuperLinear(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(
                    layer,
                    getattr(nn, 'InstanceNorm2D',
                            paddle.fluid.dygraph.nn.InstanceNorm)) and (
                                getattr(self.context, 'expand', None) != None or
                                getattr(self.context, 'channel', None) != None):
                # num_features of InstanceNorm doesn't change after the last weight layer
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                if pd_ver == 185:
                    new_attr_name = [
                        'bias_attr', 'epsilon', 'param_attr', 'dtype'
                    ]
                else:
                    new_attr_name = ['bias_attr', 'epsilon', 'weight_attr']

                self._change_name(layer, pd_ver)
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                else:
                    new_attr_dict['num_features'] = None
                new_key = '_num_channels' if '_num_channels' in new_attr_dict.keys(
                ) else '_num_features'
                ### 10 is a default channel in the case of weight_attr=False; the number of channels is unused then, so it is set arbitrarily.
                attr_dict[new_key] = layer._parameters['scale'].shape[0] if len(
                    layer._parameters) != 0 else 10

                if self.context.expand:
                    new_attr_dict[new_key[1:]] = int(self.context.expand *
                                                     attr_dict[new_key])
                elif self.context.channel:
                    new_attr_dict[new_key[1:]] = max(cur_channel)
                else:
                    new_attr_dict[new_key[1:]] = attr_dict[new_key]

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                layer = layers.SuperInstanceNorm(
                    **new_attr_dict
                ) if pd_ver == 185 else layers.SuperInstanceNorm2D(
                    **new_attr_dict)
                model[idx] = layer

            elif isinstance(layer, LayerNorm) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                ### TODO(ceci3): fix when normalized_shape != last_dim_of_input
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                new_attr_name = ['epsilon', 'bias_attr']
                if pd_ver == 185:
                    new_attr_name += [
                        'scale', 'shift', 'param_attr', 'act', 'dtype'
                    ]
                else:
                    new_attr_name += ['weight_attr']

                self._change_name(layer, pd_ver)
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['normalized_shape'] = None
                if self.context.expand:
                    new_attr_dict['normalized_shape'] = int(
                        self.context.expand * attr_dict['_normalized_shape'][0])
                elif self.context.channel:
                    new_attr_dict['normalized_shape'] = max(cur_channel)
                else:
                    new_attr_dict['normalized_shape'] = attr_dict[
                        '_normalized_shape']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict
                layer = SuperLayerNorm(**new_attr_dict)
                model[idx] = layer

            elif isinstance(layer, Embedding) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']
                new_attr_name = []
                if pd_ver == 185:
                    new_attr_name += [
                        'is_sparse', 'is_distributed', 'param_attr', 'dtype'
                    ]
                else:
                    new_attr_name += ['sparse', 'weight_attr', 'name']

                self._change_name(layer, pd_ver, has_bias=False)

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                bef_size = attr_dict.get('_size', None)  # '_size' is only present on the 1.8.5 Embedding
                if self.context.expand:
                    if pd_ver == 185:
                        new_attr_dict['size'] = [
                            bef_size[0], int(self.context.expand * bef_size[1])
                        ]
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = int(
                            self.context.expand * attr_dict['_embedding_dim'])

                    new_attr_dict['candidate_config'].update({
                        'expand_ratio': self.context.expand_ratio
                    })

                elif self.context.channel:
                    cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if pd_ver == 185:
                        new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = max(cur_channel)

                    new_attr_dict['candidate_config'].update({
                        'channel': cur_channel
                    })
                    pre_channel = cur_channel
                else:
                    if pd_ver == 185:
                        new_attr_dict['size'] = bef_size
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = attr_dict[
                            '_embedding_dim']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                new_attr_dict['padding_idx'] = None if attr_dict[
                    '_padding_idx'] == -1 else attr_dict['_padding_idx']

                del layer, attr_dict

                layer = Block(SuperEmbedding(**new_attr_dict), key=key)
                model[idx] = layer

        def split_prefix(net, name_list):
            if len(name_list) > 1:
                net = split_prefix(getattr(net, name_list[0]), name_list[1:])
            elif len(name_list) == 1:
                net = getattr(net, name_list[0])
            else:
                raise NotImplementedError("name error")
            return net
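        # For example, split_prefix(network, ['block1', 'conv1']) returns
        # network.block1.conv1 ('block1'/'conv1' being hypothetical sublayer
        # names); it fetches the direct parent of a named sublayer so the
        # converted layer can be set on it below.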

        if isinstance(network, Layer):
            for idx, (name, sublayer) in enumerate(network.named_sublayers()):
                if len(name.split('.')) > 1:
                    net = split_prefix(network, name.split('.')[:-1])
                else:
                    net = network
                setattr(net, name.split('.')[-1], model[idx])

        return network


class supernet:
    """
    Search space of the network.
    Parameters:
        kernel_size(list|tuple, optional): search space for the kernel size of the Conv2D.
        expand_ratio(list|tuple, optional): search space for the expand ratio of the number of channels of Conv2D and of the output dimensions of the Embedding or Linear. The channels of each OP in the converted super network are derived from the channels of the corresponding OP in the original model, so a single list of ratios covers the whole network. Set only one of ``expand_ratio`` and ``channel``.
        channel(list|tuple, optional): search space for the number of channels of Conv2D and the output dimensions of the Embedding or Linear. This parameter sets the channels of each OP in the super network directly, so its length must equal the total number of Conv2D, Embedding, and Linear layers in the network. Set only one of ``channel`` and ``expand_ratio``.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

        assert (
            getattr(self, 'expand_ratio', None) == None or
            getattr(self, 'channel', None) == None
        ), "expand_ratio and channel CANNOT both be set at the same time."

        self.expand = None
        if 'expand_ratio' in kwargs.keys():
            if isinstance(self.expand_ratio, list) or isinstance(
                    self.expand_ratio, tuple):
                self.expand = max(self.expand_ratio)
            elif isinstance(self.expand_ratio, int):
                self.expand = self.expand_ratio
        if 'channel' not in kwargs.keys():
            self.channel = None

    def __enter__(self):
        return Convert(self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.expand = None
        self.channel = None
        self.kernel_size = None


#def ofa_supernet(kernel_size, expand_ratio):
#    def _ofa_supernet(func):
#        @functools.wraps(func)
#        def convert(*args, **kwargs):
#            supernet_convert(*args, **kwargs)
#        return convert
#    return _ofa_supernet