#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import decorator
import logging
import numbers
import paddle
from ...common import get_logger
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
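# pd_ver == 185 selects the paddle 1.8.x (fluid dygraph) code path below;
# any other version takes the paddle 2.x (paddle.nn) path.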
if pd_ver == 185:
    import paddle.fluid.dygraph.nn as nn
    from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
    from .layers import *
    from . import layers
    Layer = paddle.fluid.dygraph.Layer
else:
    import paddle.nn as nn
    from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
    from .layers_new import *
    from . import layers_new as layers
    Layer = paddle.nn.Layer

_logger = get_logger(__name__, level=logging.INFO)

__all__ = ['supernet', 'Convert']

WEIGHT_LAYER = ['conv', 'linear', 'embedding']


class Convert:
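    """
    Convert a dygraph network into a supernet by replacing its supported
    layers (Conv2D, Linear, BatchNorm, Embedding, ...) with the searchable
    Super* counterparts defined in ``layers``.

    Args:
        context(supernet): the ``supernet`` instance that carries the search
            options, such as ``kernel_size``, ``expand_ratio`` or ``channel``.
    """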
    def __init__(self, context):
        self.context = context

    def convert(self, network):
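        """
        Rewrite ``network`` in place and return it. Each supported sublayer
        (Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm,
        Embedding) is swapped for its Super* counterpart, configured from the
        options stored in ``self.context``.

        Args:
            network(Layer|list): a dygraph model or a list of layers.
        """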
        # Find the first and last weight layers: the in channels of the first
        # weight layer and the out channels of the last weight layer must not change.
        model = []
        if isinstance(network, Layer):
            for name, sublayer in network.named_sublayers():
                model.append(sublayer)
        else:
            model = network

        first_weight_layer_idx = -1
        last_weight_layer_idx = -1
        weight_layer_count = 0
        # NOTE: pre_channel stores the previous layer's channel list for shortcut modules
        pre_channel = None
        cur_channel = None
        for idx, layer in enumerate(model):
            cls_name = layer.__class__.__name__.lower()
            if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
                weight_layer_count += 1
                last_weight_layer_idx = idx
                if first_weight_layer_idx == -1:
                    first_weight_layer_idx = idx

        if getattr(self.context, 'channel', None) is not None:
            assert len(
                self.context.channel
            ) == weight_layer_count, "length of channel must equal the number of weight layers."

        for idx, layer in enumerate(model):
            if isinstance(layer, Conv2D):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']

                new_attr_name = [
                    'stride', 'padding', 'dilation', 'groups', 'bias_attr'
                ]
                if pd_ver == 185:
                    new_attr_name += ['param_attr', 'use_cudnn', 'act', 'dtype']
                else:
                    new_attr_name += [
                        'weight_attr', 'data_format', 'padding_mode'
                    ]

                new_attr_dict = dict.fromkeys(new_attr_name, None)
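                # candidate_config collects the searchable options chosen for
                # this layer, e.g. {'kernel_size': (3, 5, 7)} or
                # {'expand_ratio': (0.5, 1.0)} (illustrative values).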
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                    new_attr_dict['num_filters'] = None
                    new_attr_dict['filter_size'] = None
                else:
                    new_attr_dict['in_channels'] = None
                    new_attr_dict['out_channels'] = None
                    new_attr_dict['kernel_size'] = None
                self.kernel_size = getattr(self.context, 'kernel_size', None)

                # if the kernel_size of conv is 1, don't change it.
                fks = '_filter_size' if '_filter_size' in attr_dict.keys(
                ) else '_kernel_size'

                ks = [attr_dict[fks]] if isinstance(
                    attr_dict[fks], numbers.Integral) else attr_dict[fks]

                if self.kernel_size and int(ks[0]) != 1:
                    new_attr_dict['transform_kernel'] = True
                    new_attr_dict[fks[1:]] = max(self.kernel_size)
                    new_attr_dict['candidate_config'].update({
                        'kernel_size': self.kernel_size
                    })
                else:
                    new_attr_dict[fks[1:]] = attr_dict[fks]

                in_key = '_num_channels' if '_num_channels' in attr_dict.keys(
                ) else '_in_channels'
                out_key = '_num_filters' if '_num_filters' in attr_dict.keys(
                ) else '_out_channels'
                if self.context.expand:
                    ### first super convolution
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])

                    ### last super convolution
                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    if attr_dict['_groups'] != None and (
                            int(attr_dict['_groups']) == int(attr_dict[in_key])
                    ):
                        ### depthwise conv: use pre_channel as cur_channel
                        _logger.warning(
                            "If the convolution is a depthwise conv, its output "
                            "channel is changed to match the input channel; the "
                            "output channel from the search space is not used.")
                        cur_channel = pre_channel
                    else:
                        cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    new_attr_dict[out_key[1:]] = attr_dict[out_key]

                for attr in new_attr_name:
                    if attr == 'weight_attr':
                        new_attr_dict[attr] = attr_dict['_param_attr']
                    else:
                        new_attr_dict[attr] = attr_dict['_' + attr]

                del layer

                if attr_dict['_groups'] == None or int(attr_dict[
                        '_groups']) == 1:
                    ### standard conv
                    layer = Block(SuperConv2D(**new_attr_dict), key=key)
                elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
                    # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
                    # channel in candidate_config = in_channel_list
                    if 'channel' in new_attr_dict['candidate_config']:
                        new_attr_dict[in_key[1:]] = max(cur_channel)
                        new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
                        new_attr_dict['candidate_config'][
                            'channel'] = cur_channel
                    new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
                    layer = Block(
                        SuperDepthwiseConv2D(**new_attr_dict), key=key)
                else:
                    ### group conv
                    layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(layer,
                            getattr(nn, 'BatchNorm2D', nn.BatchNorm)) and (
                                getattr(self.context, 'expand', None) != None or
                                getattr(self.context, 'channel', None) != None):
                # num_features of a BatchNorm after the last weight layer doesn't change, so skip it
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                new_attr_name = ['momentum', 'epsilon', 'bias_attr']

                if pd_ver == 185:
                    new_attr_name += [
                        'param_attr', 'act', 'dtype', 'in_place', 'data_layout',
                        'is_test', 'use_global_stats', 'trainable_statistics'
                    ]
                else:
                    new_attr_name += ['weight_attr', 'data_format', 'name']

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                else:
                    new_attr_dict['num_features'] = None
                new_key = 'num_channels' if 'num_channels' in new_attr_dict.keys(
                ) else 'num_features'
                if self.context.expand:
                    new_attr_dict[new_key] = int(
                        self.context.expand *
                        layer._parameters['weight'].shape[0])
                elif self.context.channel:
                    new_attr_dict[new_key] = max(cur_channel)
                else:
                    new_attr_dict[new_key] = attr_dict[
                        '_num_channels'] if '_num_channels' in attr_dict.keys(
                        ) else attr_dict['_num_features']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                bn_class = getattr(layers, 'SuperBatchNorm', None) or getattr(
                    layers, 'SuperBatchNorm2D')
                layer = bn_class(**new_attr_dict)
                model[idx] = layer

            ### assume output_size == None and filter_size != None
            ### NOTE: output_size != None may raise an error; handle it when it happens.
            elif isinstance(layer, Conv2DTranspose):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']

                new_attr_name = [
                    'stride', 'padding', 'dilation', 'groups', 'bias_attr'
                ]
                assert attr_dict.get(
                    '_filter_size', attr_dict.get('_kernel_size')
                ) is not None, "Conv2DTranspose only supports kernel_size != None now"

                if pd_ver == 185:
                    new_attr_name += [
                        'output_size', 'param_attr', 'use_cudnn', 'act', 'dtype'
                    ]
                else:
                    new_attr_name += [
                        'output_padding', 'weight_attr', 'data_format'
                    ]

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                    new_attr_dict['num_filters'] = None
                    new_attr_dict['filter_size'] = None
                else:
                    new_attr_dict['in_channels'] = None
                    new_attr_dict['out_channels'] = None
                    new_attr_dict['kernel_size'] = None
                self.kernel_size = getattr(self.context, 'kernel_size', None)

                # if the kernel_size of conv transpose is 1, don't change it.
                fks = '_filter_size' if '_filter_size' in attr_dict.keys(
                ) else '_kernel_size'
                ks = [attr_dict[fks]] if isinstance(
                    attr_dict[fks], numbers.Integral) else attr_dict[fks]

                if self.kernel_size and int(ks[0]) != 1:
                    new_attr_dict['transform_kernel'] = True
                    new_attr_dict[fks[1:]] = max(self.kernel_size)
                    new_attr_dict['candidate_config'].update({
                        'kernel_size': self.kernel_size
                    })
                else:
                    new_attr_dict[fks[1:]] = attr_dict[fks]

                in_key = '_num_channels' if '_num_channels' in attr_dict.keys(
                ) else '_in_channels'
                out_key = '_num_filters' if '_num_filters' in attr_dict.keys(
                ) else '_out_channels'
                if self.context.expand:
                    ### first super convolution transpose
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])
                    ### last super convolution transpose
                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    if attr_dict['_groups'] != None and (
                            int(attr_dict['_groups']) == int(attr_dict[in_key])
                    ):
                        ### depthwise conv_transpose: use pre_channel as cur_channel
                        _logger.warning(
                            "If the convolution is a depthwise conv_transpose, "
                            "its output channel is changed to match the input "
                            "channel; the output channel from the search space "
                            "is not used.")
                        cur_channel = pre_channel
                    else:
                        cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = attr_dict[in_key]
                    new_attr_dict[out_key[1:]] = attr_dict[out_key]

                for attr in new_attr_name:
                    if attr == 'weight_attr':
                        new_attr_dict[attr] = attr_dict['_param_attr']
                    elif attr == 'output_padding':
                        new_attr_dict[attr] = attr_dict[attr]
                    else:
                        new_attr_dict[attr] = attr_dict['_' + attr]

                del layer

                if new_attr_dict.get('output_size', None) == []:
                    new_attr_dict['output_size'] = None

                if attr_dict['_groups'] == None or int(attr_dict[
                        '_groups']) == 1:
                    ### standard conv_transpose
                    layer = Block(
                        SuperConv2DTranspose(**new_attr_dict), key=key)
                elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
                    # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
                    # channel in candidate_config = in_channel_list
                    if 'channel' in new_attr_dict['candidate_config']:
                        new_attr_dict[in_key[1:]] = max(cur_channel)
                        new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
                        new_attr_dict['candidate_config'][
                            'channel'] = cur_channel
                    new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
                    layer = Block(
                        SuperDepthwiseConv2DTranspose(**new_attr_dict), key=key)
                else:
                    ### group conv_transpose
                    layer = Block(
                        SuperGroupConv2DTranspose(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(layer, Linear) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']
                if pd_ver == 185:
                    new_attr_name = ['param_attr', 'bias_attr', 'act', 'dtype']
                else:
                    new_attr_name = ['weight_attr', 'bias_attr']
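                # Paddle stores the Linear weight as [in_features, out_features].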
                in_nc, out_nc = layer._parameters['weight'].shape

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                if pd_ver == 185:
                    new_attr_dict['input_dim'] = None
                    new_attr_dict['output_dim'] = None
                else:
                    new_attr_dict['in_features'] = None
                    new_attr_dict['out_features'] = None

                in_key = '_input_dim' if '_input_dim' in attr_dict.keys(
                ) else '_in_features'
                out_key = '_output_dim' if '_output_dim' in attr_dict.keys(
                ) else '_out_features'
                attr_dict[in_key] = in_nc
                attr_dict[out_key] = out_nc
                if self.context.expand:
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    else:
                        new_attr_dict[in_key[1:]] = int(self.context.expand *
                                                        attr_dict[in_key])

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
                    else:
                        new_attr_dict[out_key[1:]] = int(self.context.expand *
                                                         attr_dict[out_key])
                        new_attr_dict['candidate_config'].update({
                            'expand_ratio': self.context.expand_ratio
                        })
                elif self.context.channel:
                    cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if idx == first_weight_layer_idx:
                        new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    else:
                        new_attr_dict[in_key[1:]] = max(pre_channel)

                    if idx == last_weight_layer_idx:
                        new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
                    else:
                        new_attr_dict[out_key[1:]] = max(cur_channel)
                        new_attr_dict['candidate_config'].update({
                            'channel': cur_channel
                        })
                        pre_channel = cur_channel
                else:
                    new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                    new_attr_dict[out_key[1:]] = int(attr_dict[out_key])

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                layer = Block(SuperLinear(**new_attr_dict), key=key)
                model[idx] = layer

            elif isinstance(
                    layer,
                    getattr(nn, 'InstanceNorm2D',
                            paddle.fluid.dygraph.nn.InstanceNorm)) and (
                                getattr(self.context, 'expand', None) != None or
                                getattr(self.context, 'channel', None) != None):
                # num_features of an InstanceNorm after the last weight layer doesn't change, so skip it
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                if pd_ver == 185:
                    new_attr_name = [
                        'bias_attr', 'epsilon', 'param_attr', 'dtype'
                    ]
                else:
                    new_attr_name = ['bias_attr', 'epsilon', 'weight_attr']
                new_attr_dict = dict.fromkeys(new_attr_name, None)
                if pd_ver == 185:
                    new_attr_dict['num_channels'] = None
                else:
                    new_attr_dict['num_features'] = None
                new_key = '_num_channels' if 'num_channels' in new_attr_dict.keys(
                ) else '_num_features'
                ### 10 is a placeholder channel count for the weight_attr=False case;
                ### the number of channels is unused in that condition, so any value works.
                attr_dict[new_key] = layer._parameters['scale'].shape[0] if len(
                    layer._parameters) != 0 else 10

                if self.context.expand:
                    new_attr_dict[new_key[1:]] = int(self.context.expand *
                                                     attr_dict[new_key])
                elif self.context.channel:
                    new_attr_dict[new_key[1:]] = max(cur_channel)
                else:
                    new_attr_dict[new_key[1:]] = attr_dict[new_key]

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                in_class = getattr(layers, 'SuperInstanceNorm2D', None) or getattr(
                    layers, 'SuperInstanceNorm')
                layer = in_class(**new_attr_dict)
                model[idx] = layer

            elif isinstance(layer, LayerNorm) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                ### TODO(ceci3): fix when normalized_shape != last_dim_of_input
                if idx > last_weight_layer_idx:
                    continue

                attr_dict = layer.__dict__
                new_attr_name = ['epsilon', 'bias_attr']
                if pd_ver == 185:
                    new_attr_name += [
                        'scale', 'shift', 'param_attr', 'act', 'dtype'
                    ]
                else:
                    new_attr_name += ['weight_attr']

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['normalized_shape'] = None
                if self.context.expand:
                    new_attr_dict['normalized_shape'] = int(
                        self.context.expand * attr_dict['_normalized_shape'][0])
                elif self.context.channel:
                    new_attr_dict['normalized_shape'] = max(cur_channel)
                else:
                    new_attr_dict['normalized_shape'] = attr_dict[
                        '_normalized_shape']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict
                layer = SuperLayerNorm(**new_attr_dict)
                model[idx] = layer

            elif isinstance(layer, Embedding) and (
                    getattr(self.context, 'expand', None) != None or
                    getattr(self.context, 'channel', None) != None):
                attr_dict = layer.__dict__
                key = attr_dict['_full_name']
                ### NOTE: the size-related keys ('size' / 'num_embeddings' /
                ### 'embedding_dim') are set explicitly below, so they are left
                ### out of new_attr_name to keep the copy loop from overwriting them.
                new_attr_name = ['padding_idx', ]
                if pd_ver == 185:
                    new_attr_name += [
                        'is_sparse', 'is_distributed', 'param_attr', 'dtype'
                    ]
                else:
                    new_attr_name += ['sparse', 'weight_attr', 'name']

                new_attr_dict = dict.fromkeys(new_attr_name, None)
                new_attr_dict['candidate_config'] = dict()
                bef_size = attr_dict.get('_size', None)  # '_size' only exists for paddle 1.8.5
                if self.context.expand:
                    if pd_ver == 185:
                        new_attr_dict['size'] = [
                            bef_size[0], int(self.context.expand * bef_size[1])
                        ]
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = int(
                            self.context.expand * attr_dict['_embedding_dim'])

                    new_attr_dict['candidate_config'].update({
                        'expand_ratio': self.context.expand_ratio
                    })

                elif self.context.channel:
                    cur_channel = self.context.channel[0]
                    self.context.channel = self.context.channel[1:]
                    if pd_ver == 185:
                        new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = max(cur_channel)

                    new_attr_dict['candidate_config'].update({
                        'channel': cur_channel
                    })
                    pre_channel = cur_channel
                else:
                    if pd_ver == 185:
                        new_attr_dict['size'] = bef_size
                    else:
                        new_attr_dict['num_embeddings'] = attr_dict[
                            '_num_embeddings']
                        new_attr_dict['embedding_dim'] = attr_dict[
                            '_embedding_dim']

                for attr in new_attr_name:
                    new_attr_dict[attr] = attr_dict['_' + attr]

                del layer, attr_dict

                layer = Block(SuperEmbedding(**new_attr_dict), key=key)
                model[idx] = layer

        def split_prefix(net, name_list):
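            """
            Resolve a dotted sublayer path, e.g. split_prefix(net, ['block', 'conv'])
            returns net.block.conv; used below to fetch the parent of each converted
            sublayer so it can be replaced via setattr.
            """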
            if len(name_list) > 1:
                net = split_prefix(getattr(net, name_list[0]), name_list[1:])
            elif len(name_list) == 1:
                net = getattr(net, name_list[0])
            else:
                raise NotImplementedError("name error")
            return net

        if isinstance(network, Layer):
            for idx, (name, sublayer) in enumerate(network.named_sublayers()):
                if len(name.split('.')) > 1:
                    net = split_prefix(network, name.split('.')[:-1])
                else:
                    net = network
                setattr(net, name.split('.')[-1], model[idx])

        return network


class supernet:
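    """
    Context manager that holds the search-space configuration
    (``kernel_size``, ``expand_ratio`` or ``channel``) and yields a
    ``Convert`` instance on entry.

    A minimal usage sketch (the option values are illustrative):

        with supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4]) as ofa_super:
            net = ofa_super.convert(net)
    """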
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

        assert (
            getattr(self, 'expand_ratio', None) is None or
            getattr(self, 'channel', None) is None
        ), "expand_ratio and channel cannot both be set at the same time."

        self.expand = None
        if 'expand_ratio' in kwargs.keys():
            if isinstance(self.expand_ratio, (list, tuple)):
                self.expand = max(self.expand_ratio)
            elif isinstance(self.expand_ratio, int):
                self.expand = self.expand_ratio
        if 'channel' not in kwargs.keys():
            self.channel = None

    def __enter__(self):
        return Convert(self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.expand = None
        self.channel = None
        self.kernel_size = None
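
# A minimal end-to-end sketch (the model and option values are illustrative):
#
#     import paddle.nn as nn
#
#     class Net(nn.Layer):
#         def __init__(self):
#             super(Net, self).__init__()
#             self.conv = nn.Conv2D(3, 8, 3)
#             self.fc = nn.Linear(8, 10)
#
#     with supernet(expand_ratio=[1, 2, 4]) as ofa_super:
#         net = ofa_super.convert(Net())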


#def ofa_supernet(kernel_size, expand_ratio):
#    def _ofa_supernet(func):
#        @functools.wraps(func)
#        def convert(*args, **kwargs):
#            supernet_convert(*args, **kwargs)
#        return convert
#    return _ofa_supernet