combine_search_space.py 2.9 KB
Newer Older
C
ceci3 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .search_space_registry import SEARCHSPACE
from .base_layer import conv_bn_layer

__all__ = ["CombineSearchSpace"]

C
ceci3 已提交
28

C
ceci3 已提交
29 30 31 32 33 34
class CombineSearchSpace(object):
    """
    Combine Search Space.

    Chains several registered search spaces into a single one. The token
    vector of the combined space is the concatenation of the token vectors
    of the individual spaces, in the order given by ``config_lists``.

    Args:
        config_lists(list<tuple>): list of ``(key, config)`` pairs. ``key`` is
            the name of a search space registered in ``SEARCHSPACE``;
            ``config`` is a dict holding the keys ``input_size``,
            ``output_size``, ``block_num`` and ``block_mask``.
    """

    def __init__(self, config_lists):
        self.lens = len(config_lists)
        self.spaces = []
        for config_list in config_lists:
            key, config = config_list
            self.spaces.append(self._get_single_search_space(key, config))

    def _get_single_search_space(self, key, config):
        """
        Build one search space from its registry key and config.

        Args:
            key(str): name of the search space in ``SEARCHSPACE``.
            config(dict): must contain ``input_size``, ``output_size``,
                ``block_num`` and ``block_mask``.
        Returns:
            An instance of the search-space class registered under ``key``.
        Raises:
            ValueError: if the registry lookup yields nothing for ``key``.
        """
        cls = SEARCHSPACE.get(key)
        # Without this check an unknown key would surface as an opaque
        # "'NoneType' object is not callable" error on the next line.
        if cls is None:
            raise ValueError(
                "search space '{}' is not registered in SEARCHSPACE".format(
                    key))
        space = cls(config['input_size'], config['output_size'],
                    config['block_num'], config['block_mask'])
        return space

    def init_tokens(self):
        """
        Combine init tokens.

        Concatenates ``init_tokens()`` of every sub-space. As a side effect,
        records in ``self.single_token_num`` how many tokens each sub-space
        contributed; ``token2arch`` uses it to split a combined token list.

        Returns:
            list: the concatenated initial tokens.
        """
        tokens = []
        self.single_token_num = []
        for space in self.spaces:
            # Call each sub-space once; using the same list for both the
            # tokens and their count keeps the two consistent.
            space_tokens = space.init_tokens()
            tokens.extend(space_tokens)
            self.single_token_num.append(len(space_tokens))
        return tokens

    def range_table(self):
        """
        Combine range table.

        Returns:
            list: the concatenation of every sub-space's ``range_table()``,
            in the same order as the tokens.
        """
        range_tables = []
        for space in self.spaces:
            range_tables.extend(space.range_table())
        return range_tables

    def token2arch(self, tokens=None):
        """
        Combine model arch.

        Splits ``tokens`` into per-sub-space slices, using the token count of
        each sub-space. Each slice goes to that sub-space's ``token2arch``.

        Args:
            tokens(list|None): combined token list. When None, the combined
                initial tokens from ``init_tokens()`` are used.
        Returns:
            list: one architecture per sub-space, in sub-space order.
        """
        if tokens is None:
            tokens = self.init_tokens()
        elif not hasattr(self, 'single_token_num'):
            # The per-space token counts are only recorded by init_tokens();
            # compute them now so explicit tokens work on a fresh instance.
            self.init_tokens()

        model_archs = []
        start_idx = 0
        for space, token_num in zip(self.spaces, self.single_token_num):
            end_idx = start_idx + token_num
            model_archs.append(space.token2arch(tokens[start_idx:end_idx]))
            start_idx = end_idx
        return model_archs