#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import decorator
import logging
import numbers
import paddle
from ...common import get_logger
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
if pd_ver == 185:
    import paddle.fluid.dygraph.nn as nn
    from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
    from paddle.fluid import ParamAttr
    from .layers import *
    from . import layers
    Layer = paddle.fluid.dygraph.Layer
else:
    import paddle.nn as nn
    from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
    from paddle import ParamAttr
    from .layers_new import *
    from . import layers_new as layers
    Layer = paddle.nn.Layer

_logger = get_logger(__name__, level=logging.INFO)

__all__ = ['supernet', 'Convert']
WEIGHT_LAYER = ['conv', 'linear', 'embedding']
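
# Typical usage (illustrative sketch; `mobilenet` is a placeholder model):
#
#     sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
#     with sp_net_config as ofa_super:
#         net = ofa_super.convert(mobilenet)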


class Convert:
    def __init__(self, context):
        self.context = context

    def _change_name(self, layer, pd_ver, has_bias=True, conv=False):
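        # Prefix user-specified param/bias attr names with 'super_' so the converted
        # super layers get distinct parameter names (e.g. a weight attr named
        # 'conv1_w' becomes 'super_conv1_w'; illustrative name).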
        if conv:
            w_attr = layer._param_attr
        else:
            w_attr = layer._param_attr if pd_ver == 185 else layer._weight_attr

        if isinstance(w_attr, ParamAttr):
            if w_attr != None and not isinstance(w_attr, bool):
                w_attr.name = 'super_' + w_attr.name

        if has_bias:
            if isinstance(layer._bias_attr, ParamAttr):
                if layer._bias_attr != None and not isinstance(layer._bias_attr,
                                                               bool):
                    layer._bias_attr.name = 'super_' + layer._bias_attr.name

    def convert(self, network):
        # Search for the first and last weight layers; the in channel of the first
        # and the out channel of the last weight layer must stay unchanged.
        model = []
        if isinstance(network, Layer):
            for name, sublayer in network.named_sublayers():
                model.append(sublayer)
        else:
            model = network

        first_weight_layer_idx = -1
        last_weight_layer_idx = -1
        weight_layer_count = 0
        # NOTE: pre_channel stores the previous layer's candidate channels for shortcut modules
        pre_channel = None
        cur_channel = None
        for idx, layer in enumerate(model):
            cls_name = layer.__class__.__name__.lower()
            if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
                weight_layer_count += 1
                last_weight_layer_idx = idx
                if first_weight_layer_idx == -1:
                    first_weight_layer_idx = idx

        if getattr(self.context, 'channel', None) != None:
            assert len(
                self.context.channel
            ) == weight_layer_count, "length of channel must be the same as the number of weight layers."
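
        # Illustrative (hypothetical config): for a net with three weight layers,
        # channel=((4, 8), (8, 16), (16, 32)) gives each layer its own candidate
        # channel list; each super layer is then built with the max of its list.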

        for idx, layer in enumerate(model):
            if isinstance(layer, Conv2D):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']

                new_attr_name = [
                    'stride', 'padding', 'dilation', 'groups', 'bias_attr'
                ]
                if pd_ver == 185:
                    new_attr_name += ['param_attr', 'use_cudnn', 'act', 'dtype']
                else:
                    new_attr_name += [
                        'weight_attr', 'data_format', 'padding_mode'
                    ]

                self._change_name(layer, pd_ver, conv=True)
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                    new_attr_dict['num_filters'] = None
                    new_attr_dict['filter_size'] = None
                else:
                    new_attr_dict['in_channels'] = None
                    new_attr_dict['out_channels'] = None
                    new_attr_dict['kernel_size'] = None
                self.kernel_size = getattr(self.context, 'kernel_size', None)

                # if the kernel_size of conv is 1, don't change it.
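                # e.g. with supernet(kernel_size=(3, 5, 7)) a 3x3 conv is rebuilt with
                # the max kernel (7) and 'kernel_size' is added to candidate_config
                # (illustrative values, following the transform logic below).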
                fks = '_filter_size' if '_filter_size' in attr_dict.keys(
                ) else '_kernel_size'

                ks = [attr_dict[fks]] if isinstance(
                    attr_dict[fks], numbers.Integral) else attr_dict[fks]

                if self.kernel_size and int(ks[0]) != 1:
                    new_attr_dict['transform_kernel'] = True
                    new_attr_dict[fks[1:]] = max(self.kernel_size)
                    new_attr_dict['candidate_config'].update({
                        'kernel_size': self.kernel_size
                    })
                else:
                    new_attr_dict[fks[1:]] = attr_dict[fks]
                in_key = '_num_channels' if '_num_channels' in attr_dict.keys(
                ) else '_in_channels'
                out_key = '_num_filters' if '_num_filters' in attr_dict.keys(
                ) else '_out_channels'
                if self.context.expand:
                    ### first super convolution
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])

                    ### last super convolution
                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    if attr_dict['_groups'] != None and (
                            int(attr_dict['_groups']) == int(attr_dict[in_key])
                    ):
                        ### depthwise conv: reuse the previous layer's channels as cur_channel
                        _logger.warn(
                            "If convolution is a depthwise conv, the output channel is changed "
                            "to the same channel as the input; the output channel from search is not used."
                        )
                        cur_channel = pre_channel
                    else:
                        cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    new_attr_dict[out_key[1:]] = attr_dict[out_key]

                for attr in new_attr_name:
                    if attr == 'weight_attr':
                        new_attr_dict[attr] = attr_dict['_param_attr']
                    else:
                        new_attr_dict[attr] = attr_dict['_' + attr]

                del layer

                if attr_dict['_groups'] == None or int(attr_dict[
                        '_groups']) == 1:
                    ### standard conv
                    layer = Block(SuperConv2D(**new_attr_dict), key=key)
                elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
                    # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
                    # channel in candidate_config = in_channel_list
                    if 'channel' in new_attr_dict['candidate_config']:
                        new_attr_dict[in_key[1:]] = max(cur_channel)
                        new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
                        new_attr_dict['candidate_config'][
                            'channel'] = cur_channel
                    new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
                    layer = Block(
                        SuperDepthwiseConv2D(**new_attr_dict), key=key)
                else:
                    ### group conv
                    layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(layer,
                            getattr(nn, 'BatchNorm2D', nn.BatchNorm)) and (
                                getattr(self.context, 'expand', None) != None or
                                getattr(self.context, 'channel', None) != None):
                # num_features of a BatchNorm after the last weight layer stays unchanged
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                new_attr_name = ['momentum', 'epsilon', 'bias_attr']

                if pd_ver == 185:
                    new_attr_name += [
                        'param_attr', 'act', 'dtype', 'in_place', 'data_layout',
                        'is_test', 'use_global_stats', 'trainable_statistics'
                    ]
                else:
                    new_attr_name += ['weight_attr', 'data_format', 'name']

                self._change_name(layer, pd_ver)
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                else:
                    new_attr_dict['num_features'] = None
                new_key = 'num_channels' if 'num_channels' in new_attr_dict.keys(
                ) else 'num_features'
                if self.context.expand:
                    new_attr_dict[new_key] = int(
                        self.context.expand *
                        layer._parameters['weight'].shape[0])
                elif self.context.channel:
                    new_attr_dict[new_key] = max(cur_channel)
                else:
                    new_attr_dict[new_key] = attr_dict[
                        '_num_channels'] if '_num_channels' in attr_dict.keys(
                        ) else attr_dict['_num_features']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                layer = layers.SuperBatchNorm(
                    **new_attr_dict
                ) if pd_ver == 185 else layers.SuperBatchNorm2D(**new_attr_dict)
                model[idx] = layer

            ### assume output_size == None and filter_size != None
            ### NOTE: output_size != None may raise an error; handle it when it happens.
            elif isinstance(layer, Conv2DTranspose):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']

                new_attr_name = [
                    'stride', 'padding', 'dilation', 'groups', 'bias_attr'
                ]
                assert attr_dict.get(
                    '_filter_size', attr_dict.get('_kernel_size')
                ) != None, "Conv2DTranspose only supports kernel_size != None now"
                if pd_ver == 185:
                    new_attr_name += [
                        'output_size', 'param_attr', 'use_cudnn', 'act', 'dtype'
                    ]
                else:
                    new_attr_name += [
                        'output_padding', 'weight_attr', 'data_format'
                    ]

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                    new_attr_dict['num_filters'] = None
                    new_attr_dict['filter_size'] = None
                else:
                    new_attr_dict['in_channels'] = None
                    new_attr_dict['out_channels'] = None
                    new_attr_dict['kernel_size'] = None

                self._change_name(layer, pd_ver, conv=True)
                self.kernel_size = getattr(self.context, 'kernel_size', None)

                # if the kernel_size of conv transpose is 1, don't change it.
                fks = '_filter_size' if '_filter_size' in attr_dict.keys(
                ) else '_kernel_size'
                ks = [attr_dict[fks]] if isinstance(
                    attr_dict[fks], numbers.Integral) else attr_dict[fks]

                if self.kernel_size and int(ks[0]) != 1:
                    new_attr_dict['transform_kernel'] = True
                    new_attr_dict[fks[1:]] = max(self.kernel_size)
                    new_attr_dict['candidate_config'].update({
                        'kernel_size': self.kernel_size
                    })
                else:
                    new_attr_dict[fks[1:]] = attr_dict[fks]
                in_key = '_num_channels' if '_num_channels' in attr_dict.keys(
                ) else '_in_channels'
                out_key = '_num_filters' if '_num_filters' in attr_dict.keys(
                ) else '_out_channels'
                if self.context.expand:
                    ### first super convolution transpose
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])
                    ### last super convolution transpose
                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    if attr_dict['_groups'] != None and (
                            int(attr_dict['_groups']) == int(attr_dict[in_key])
                    ):
                        ### depthwise conv_transpose: reuse the previous layer's channels as cur_channel
                        _logger.warn(
                            "If convolution is a depthwise conv_transpose, the output channel is "
                            "changed to the same channel as the input; the output channel from search is not used."
                        )
                        cur_channel = pre_channel
                    else:
                        cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    new_attr_dict[out_key[1:]] = attr_dict[out_key]

                for attr in new_attr_name:
                    if attr == 'weight_attr':
                        new_attr_dict[attr] = attr_dict['_param_attr']
                    elif attr == 'output_padding':
                        new_attr_dict[attr] = attr_dict[attr]
                    else:
                        new_attr_dict[attr] = attr_dict['_' + attr]

                del layer

                if new_attr_dict.get('output_size', None) == []:
                    new_attr_dict['output_size'] = None

                if attr_dict['_groups'] == None or int(attr_dict[
                        '_groups']) == 1:
                    ### standard conv_transpose
                    layer = Block(
                        SuperConv2DTranspose(**new_attr_dict), key=key)
                elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
                    # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
                    # channel in candidate_config = in_channel_list
                    if 'channel' in new_attr_dict['candidate_config']:
                        new_attr_dict[in_key[1:]] = max(cur_channel)
                        new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
                        new_attr_dict['candidate_config'][
                            'channel'] = cur_channel
                    new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
                    layer = Block(
                        SuperDepthwiseConv2DTranspose(**new_attr_dict), key=key)
                else:
                    ### group conv_transpose
                    layer = Block(
                        SuperGroupConv2DTranspose(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(layer, Linear) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']
                if pd_ver == 185:
                    new_attr_name = ['act', 'dtype']
                else:
                    new_attr_name = ['weight_attr', 'bias_attr']
                in_nc, out_nc = layer._parameters['weight'].shape

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['input_dim'] = None
                    new_attr_dict['output_dim'] = None
                else:
                    new_attr_dict['in_features'] = None
                    new_attr_dict['out_features'] = None

                in_key = '_input_dim' if pd_ver == 185 else '_in_features'
                out_key = '_output_dim' if pd_ver == 185 else '_out_features'
                attr_dict[in_key] = in_nc
                attr_dict[out_key] = out_nc
                if self.context.expand:
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    new_attr_dict[out_key[1:]] = int(attr_dict[out_key])

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                layer = Block(SuperLinear(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(
                    layer,
                    getattr(nn, 'InstanceNorm2D',
                            paddle.fluid.dygraph.nn.InstanceNorm)) and (
                                getattr(self.context, 'expand', None) != None or
                                getattr(self.context, 'channel', None) != None):
                # num_features of an InstanceNorm after the last weight layer stays unchanged
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                if pd_ver == 185:
                    new_attr_name = [
                        'bias_attr', 'epsilon', 'param_attr', 'dtype'
                    ]
                else:
                    new_attr_name = ['bias_attr', 'epsilon', 'weight_attr']

                self._change_name(layer, pd_ver)
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                else:
                    new_attr_dict['num_features'] = None
                new_key = '_num_channels' if 'num_channels' in new_attr_dict.keys(
                ) else '_num_features'
                ### 10 is an arbitrary default channel count for the weight_attr=False
                ### case, where the number of channels is unused.
                attr_dict[new_key] = layer._parameters['scale'].shape[0] if len(
                    layer._parameters) != 0 else 10

                if self.context.expand:
                    new_attr_dict[new_key[1:]] = int(self.context.expand *
                                                     attr_dict[new_key])
                elif self.context.channel:
                    new_attr_dict[new_key[1:]] = max(cur_channel)
                else:
                    new_attr_dict[new_key[1:]] = attr_dict[new_key]

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                layer = layers.SuperInstanceNorm(
                    **new_attr_dict
                ) if pd_ver == 185 else layers.SuperInstanceNorm2D(
                    **new_attr_dict)
                model[idx] = layer

            elif isinstance(layer, LayerNorm) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                ### TODO(ceci3): fix when normalized_shape != last_dim_of_input
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                new_attr_name = ['epsilon', 'bias_attr']
                if pd_ver == 185:
                    new_attr_name += [
                        'scale', 'shift', 'param_attr', 'act', 'dtype'
                    ]
                else:
                    new_attr_name += ['weight_attr']

                self._change_name(layer, pd_ver)
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['normalized_shape'] = None
                if self.context.expand:
                    new_attr_dict['normalized_shape'] = int(
                        self.context.expand * attr_dict['_normalized_shape'][0])
                elif self.context.channel:
                    new_attr_dict['normalized_shape'] = max(cur_channel)
                else:
                    new_attr_dict['normalized_shape'] = attr_dict[
                        '_normalized_shape']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict
                layer = SuperLayerNorm(**new_attr_dict)
                model[idx] = layer

            elif isinstance(layer, Embedding) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']
                new_attr_name = []
                if pd_ver == 185:
                    new_attr_name += [
                        'size', 'is_sparse', 'is_distributed', 'param_attr',
                        'dtype'
                    ]
                else:
                    new_attr_name += [
                        'num_embeddings', 'embedding_dim', 'sparse',
                        'weight_attr', 'name'
                    ]
                self._change_name(layer, pd_ver, has_bias=False)

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                bef_size = attr_dict['_size'] if pd_ver == 185 else None
                if self.context.expand:
                    if pd_ver == 185:
                        new_attr_dict['size'] = [
                            bef_size[0], int(self.context.expand * bef_size[1])
                        ]
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = int(
                            self.context.expand * attr_dict['_embedding_dim'])

                    new_attr_dict['candidate_config'].update({
                        'expand_ratio': self.context.expand_ratio
                    })

                elif self.context.channel:
                    cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if pd_ver == 185:
                        new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = max(cur_channel)

                    new_attr_dict['candidate_config'].update({
                        'channel': cur_channel
                    })
                    pre_channel = cur_channel
                else:
                    if pd_ver == 185:
                        new_attr_dict['size'] = bef_size
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = attr_dict[
                            '_embedding_dim']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]
                new_attr_dict['padding_idx'] = None if attr_dict[
                    '_padding_idx'] == -1 else attr_dict['_padding_idx']
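                # (Embedding stores "no padding_idx" internally as -1; the line above
                # maps it back to None for the super layer)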

                del layer, attr_dict

                layer = Block(SuperEmbedding(**new_attr_dict), key=key)
                model[idx] = layer

        def split_prefix(net, name_list):
            if len(name_list) > 1:
                net = split_prefix(getattr(net, name_list[0]), name_list[1:])
            elif len(name_list) == 1:
                net = getattr(net, name_list[0])
            else:
                raise NotImplementedError("name error")
            return net
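
        # Illustrative (hypothetical names): a sublayer registered as 'block0.conv0'
        # is reached via split_prefix(network, ['block0']) -> network.block0, and
        # setattr then swaps the converted super layer in at 'conv0'.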

        if isinstance(network, Layer):
            for idx, (name, sublayer) in enumerate(network.named_sublayers()):
                if len(name.split('.')) > 1:
                    net = split_prefix(network, name.split('.')[:-1])
                else:
                    net = network
                setattr(net, name.split('.')[-1], model[idx])

        return network


class supernet:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

        assert (
            getattr(self, 'expand_ratio', None) == None or
            getattr(self, 'channel', None) == None
        ), "expand_ratio and channel CANNOT both be set at the same time."

        self.expand = None
        if 'expand_ratio' in kwargs.keys():
            if isinstance(self.expand_ratio, list) or isinstance(
                    self.expand_ratio, tuple):
                self.expand = max(self.expand_ratio)
            elif isinstance(self.expand_ratio, int):
                self.expand = self.expand_ratio
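        # e.g. supernet(expand_ratio=[1, 2, 4]) sets self.expand = 4 (illustrative),
        # so super layers are built at the largest width and narrower sub-layers
        # can be sampled from them at run time.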
        if 'channel' not in kwargs.keys():
            self.channel = None

    def __enter__(self):
        return Convert(self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.expand = None
        self.channel = None
        self.kernel_size = None


#def ofa_supernet(kernel_size, expand_ratio):
#    def _ofa_supernet(func):
#        @functools.wraps(func)
#        def convert(*args, **kwargs):
#            supernet_convert(*args, **kwargs)
#        return convert
#    return _ofa_supernet