combine_search_space.py 4.6 KB
Newer Older
C
ceci3 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
C
ceci3 已提交
22 23
import logging
from ...common import get_logger
C
ceci3 已提交
24 25 26 27 28 29
from .search_space_base import SearchSpaceBase
from .search_space_registry import SEARCHSPACE
from .base_layer import conv_bn_layer

__all__ = ["CombineSearchSpace"]

# Module-level logger; CombineSearchSpace uses it to warn when a config is
# supplied for a space that is not a Block space.
_logger = get_logger(__name__, level=logging.INFO)


class CombineSearchSpace(object):
    """
    Combine several search spaces into a single search space.

    The tokens of the combined space are the concatenation of the tokens of
    every sub-space, in the order the sub-spaces were given; the range table
    is concatenated the same way.

    Args:
        config_lists(list<tuple|str>): one entry per sub-space. Each entry is
            either a registered space name (str), or a ``(name, config)``
            tuple where ``config`` is a dict that may provide ``input_size``,
            ``output_size``, ``block_num`` and ``block_mask``.

    Raises:
        NotImplementedError: if an entry is neither a str nor a tuple, or if a
            Block space gets neither ``block_mask`` nor all of
            ``block_num``/``input_size``/``output_size``.
        AssertionError: if a space name is not found in ``SEARCHSPACE``.
    """

    def __init__(self, config_lists):
        self.lens = len(config_lists)
        self.spaces = []
        for config_list in config_lists:
            if isinstance(config_list, tuple):
                key, config = config_list
            elif isinstance(config_list, str):
                key = config_list
                config = None
            else:
                raise NotImplementedError(
                    'the type of config is Error!!! Please check the config information. Receive the type of config is {}'.format(
                        type(config_list)))
            self.spaces.append(self._get_single_search_space(key, config))
        # Populates self.single_token_num, which token2arch relies on.
        self.init_tokens()

    def _get_single_search_space(self, key, config):
        """
        Build one sub-space from its registered name and optional config.

        Args:
            key(str): name of the space in the ``SEARCHSPACE`` registry.
            config(dict|None): settings for the space. Only the keys
                ``input_size``, ``output_size``, ``block_num`` and
                ``block_mask`` are read; missing keys default to None.
        Returns:
            The instantiated search space object.
        """
        cls = SEARCHSPACE.get(key)
        assert cls is not None, '{} is NOT a correct space, the space we support is {}'.format(
            key, SEARCHSPACE)

        if config is None:
            # No config: every optional setting falls back to None.
            config = {}
        elif 'Block' not in cls.__name__:
            _logger.warning(
                'if space is not a Block space, config is useless, current space is {}'.
                format(cls.__name__))

        block_mask = config.get('block_mask')
        input_size = config.get('input_size')
        output_size = config.get('output_size')
        block_num = config.get('block_num')

        if 'Block' in cls.__name__:
            # `is None` (not `== None`): block_mask may be array-like, whose
            # `==` is element-wise and would make this condition ambiguous.
            if block_mask is None and (block_num is None or
                                       input_size is None or
                                       output_size is None):
                raise NotImplementedError(
                    "block_mask or (block num and input_size and output_size) can NOT be None at the same time in Block SPACE!"
                )

        return cls(input_size, output_size, block_num, block_mask=block_mask)

    def init_tokens(self):
        """
        Return the initial tokens of the combined space.

        Side effect: resets ``self.single_token_num`` to the number of tokens
        each sub-space contributes, in order.
        """
        tokens = []
        self.single_token_num = []
        for space in self.spaces:
            # Query each sub-space once; the same result gives both the
            # tokens and their count.
            space_tokens = space.init_tokens()
            tokens.extend(space_tokens)
            self.single_token_num.append(len(space_tokens))
        return tokens

    def range_table(self):
        """
        Return the concatenated range tables of all sub-spaces.
        """
        range_tables = []
        for space in self.spaces:
            range_tables.extend(space.range_table())
        return range_tables

    def token2arch(self, tokens=None):
        """
        Convert combined tokens into one architecture per sub-space.

        Args:
            tokens(list|None): tokens of the combined space. When None, the
                initial tokens from ``init_tokens`` are used.
        Returns:
            list: the ``token2arch`` result of each sub-space, in order. Each
            sub-space receives the slice of ``tokens`` whose length is its
            entry in ``self.single_token_num``.
        """
        if tokens is None:
            tokens = self.init_tokens()

        model_archs = []
        start_idx = 0
        for space, token_num in zip(self.spaces, self.single_token_num):
            end_idx = start_idx + token_num
            model_archs.append(space.token2arch(tokens[start_idx:end_idx]))
            start_idx = end_idx
        return model_archs