combine_search_space.py 4.3 KB
Newer Older
C
ceci3 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
C
ceci3 已提交
22 23
import logging
from ...common import get_logger
C
ceci3 已提交
24 25 26 27 28 29
from .search_space_base import SearchSpaceBase
from .search_space_registry import SEARCHSPACE
from .base_layer import conv_bn_layer

__all__ = ["CombineSearchSpace"]

# Module-level logger; CombineSearchSpace uses it to warn when a config is
# supplied for a space that will ignore it.
_logger = get_logger(__name__, level=logging.INFO)


class CombineSearchSpace(object):
    """
    Combine several search spaces into a single one.

    The tokens of all sub-spaces are concatenated in the order the sub-spaces
    were given. ``token2arch`` splits a combined token list back into
    per-space slices and returns one arch per sub-space.

    Args:
        config_lists(list): one entry per sub-space. Each entry is either the
            registered space name (str), or a ``(name, config)`` tuple where
            ``config`` is a dict that may contain the keys 'block_mask',
            'input_size', 'output_size' and 'block_num'.
    """

    def __init__(self, config_lists):
        # Number of sub-spaces requested.
        self.lens = len(config_lists)
        self.spaces = []
        for config_list in config_lists:
            # `elif` is required: with two independent `if`s a tuple entry
            # would fall through to the `else` below and always raise.
            if isinstance(config_list, tuple):
                key, config = config_list
            elif isinstance(config_list, str):
                key = config_list
                config = None
            else:
                raise NotImplementedError('the type of config is Error!!! Please check the config information. Receive the type of config is {}'.format(type(config_list)))
            self.spaces.append(self._get_single_search_space(key, config))
        # Populates self.single_token_num, which token2arch relies on.
        self.init_tokens()

    def _get_single_search_space(self, key, config):
        """
        Instantiate one sub-space from its registered name and config.

        Args:
            key(str): name the space is registered under in SEARCHSPACE.
            config(dict|None): optional config. Recognized keys are
                'block_mask', 'input_size', 'output_size' and 'block_num';
                missing keys default to None.
        Returns:
            the instantiated search space.
        Raises:
            NotImplementedError: if ``key`` resolves to no space, or if a Block
                space gets neither ``block_mask`` nor all of ``block_num``,
                ``input_size`` and ``output_size``.
        """
        cls = SEARCHSPACE.get(key)
        # NOTE(review): assumes the registry's get() returns None for unknown
        # keys; if it raises instead, this check is simply never reached.
        if cls is None:
            raise NotImplementedError(
                'search space {} is not registered'.format(key))

        is_block_space = 'Block' in cls.__name__
        if config is None:
            config = {}
        elif not is_block_space:
            _logger.warning('if space is not a Block space, config is useless, current space is {}'.format(cls.__name__))

        block_mask = config.get('block_mask')
        input_size = config.get('input_size')
        output_size = config.get('output_size')
        block_num = config.get('block_num')

        # A Block space needs either an explicit block_mask or enough
        # information (block_num, input_size, output_size) to build one.
        if is_block_space and block_mask is None and (
                block_num is None or input_size is None or
                output_size is None):
            raise NotImplementedError("block_mask or (block num and input_size and output_size) can NOT be None at the same time in Block SPACE!")

        return cls(input_size, output_size, block_num, block_mask=block_mask)

    def init_tokens(self):
        """
        Concatenate the initial tokens of every sub-space.

        Also records in ``self.single_token_num`` how many tokens each
        sub-space contributes, so ``token2arch`` can split them again.

        Returns:
            list: the combined initial tokens.
        """
        tokens = []
        self.single_token_num = []
        for space in self.spaces:
            # Query each sub-space once so the recorded length always matches
            # the tokens actually appended.
            space_tokens = space.init_tokens()
            tokens.extend(space_tokens)
            self.single_token_num.append(len(space_tokens))
        return tokens

    def range_table(self):
        """
        Concatenate the range tables of every sub-space, in sub-space order.
        """
        range_tables = []
        for space in self.spaces:
            range_tables.extend(space.range_table())
        return range_tables

    def token2arch(self, tokens=None):
        """
        Convert combined tokens into one model arch per sub-space.

        Args:
            tokens(list|None): combined tokens laid out as produced by
                ``init_tokens``. If None, the initial tokens are used.
        Returns:
            list: ``space.token2arch(slice)`` for every sub-space, in order.
        """
        if tokens is None:
            tokens = self.init_tokens()

        model_archs = []
        start_idx = 0
        for space, token_num in zip(self.spaces, self.single_token_num):
            end_idx = start_idx + token_num
            model_archs.append(space.token2arch(tokens[start_idx:end_idx]))
            start_idx = end_idx
        return model_archs