# blazefacespace_nas.py
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from paddleslim.nas.search_space.search_space_base import SearchSpaceBase
from paddleslim.nas.search_space.search_space_registry import SEARCHSPACE
from ppdet.modeling.backbones.blazenet import BlazeNet
from ppdet.modeling.architectures.blazeface import BlazeFace


@SEARCHSPACE.register
class BlazeFaceNasSpace(SearchSpaceBase):
    """PaddleSlim search space over the channel widths of a BlazeFace backbone.

    A token list holds one index per position of ``range_table()``. Each
    index selects an entry from one of the candidate tables set in
    ``__init__``:

    * tokens[0]    -> ``blaze_filter_num1``
    * tokens[1]    -> ``blaze_filter_num2``
    * tokens[2..3] -> ``double_filter_num``
    * tokens[4..7] -> ``mid_filter_num``
    * tokens[8]    -> ``use_5x5kernel``

    ``get_nas_cnf`` turns tokens into BlazeNet arguments.
    ``token2arch`` turns them into a network-building function.
    """

    def __init__(self, input_size, output_size, block_num, block_mask):
        super(BlazeFaceNasSpace, self).__init__(input_size, output_size,
                                                block_num, block_mask)
        # Candidate channel counts for the single-blaze filter entries.
        self.blaze_filter_num1 = np.array([4, 8, 12, 16, 24, 32])
        self.blaze_filter_num2 = np.array([8, 12, 16, 24, 32, 40, 48, 64])
        # Candidates for the second element of each double-blaze entry
        # (presumably its inner/middle width -- confirm against BlazeNet).
        self.mid_filter_num = np.array([8, 12, 16, 20, 24, 32])
        # Candidates for the input/output widths of the double-blaze entries.
        self.double_filter_num = np.array(
            [8, 12, 16, 24, 32, 40, 48, 64, 72, 80, 88, 96])
        # Kernel-size choices, read by get_nas_cnf: 0 -> 3x3, 1 -> 5x5.
        # Only 3x3 is offered, which suits a latency constraint. Set this to
        # np.array([0, 1]) to also search over 5x5 kernels.
        self.use_5x5kernel = np.array([0])

    def init_tokens(self):
        """Return the starting token list for the search (9 indices).

        NOTE(review): the last token (1) is outside ``range(len(use_5x5kernel))``
        while only 3x3 is offered. ``get_nas_cnf`` treats such an index as 3x3.
        """
        return [2, 1, 3, 8, 2, 1, 2, 1, 1]

    def range_table(self):
        """Return the number of choices for each token position."""
        return [
            len(self.blaze_filter_num1), len(self.blaze_filter_num2),
            len(self.double_filter_num), len(self.double_filter_num),
            len(self.mid_filter_num), len(self.mid_filter_num),
            len(self.mid_filter_num), len(self.mid_filter_num),
            len(self.use_5x5kernel)
        ]

    def get_nas_cnf(self, tokens=None):
        """Decode ``tokens`` into BlazeNet backbone settings.

        Args:
            tokens: one index per position of ``range_table()``.
                ``None`` means ``init_tokens()``.

        Returns:
            tuple: ``(blaze_filters, double_blaze_filters, is_5x5kernel)``.
            The two filter lists contain plain Python ints. Some entries end
            with a literal ``2``, presumably a stride (confirm against
            BlazeNet). ``is_5x5kernel`` is a bool.
        """
        if tokens is None:
            tokens = self.init_tokens()

        def pick(table, pos):
            # Convert to a plain int so the structure printed by
            # print_nas_structure shows "16", not "np.int64(16)" (NumPy >= 2).
            return int(table[tokens[pos]])

        b1 = pick(self.blaze_filter_num1, 0)
        b2 = pick(self.blaze_filter_num2, 1)
        d1 = pick(self.double_filter_num, 2)
        d2 = pick(self.double_filter_num, 3)
        m1, m2, m3, m4 = [pick(self.mid_filter_num, pos) for pos in range(4, 8)]

        blaze_filters = [[b1, b1], [b1, b2, 2], [b2, b2]]
        double_blaze_filters = [
            [b2, m1, d1, 2],
            [d1, m2, d1],
            [d1, m3, d2, 2],
            [d2, m4, d2],
        ]

        # tokens[8] indexes use_5x5kernel. An index outside the table (such
        # as the 1 in init_tokens() while only [0] is offered), or a missing
        # ninth token, falls back to 3x3. That keeps the old always-3x3
        # behavior for the current table, and widening use_5x5kernel alone
        # enables 5x5 kernels.
        kernel_choices = self.use_5x5kernel
        kernel_pos = tokens[8] if len(tokens) > 8 else 0
        is_5x5kernel = bool(0 <= kernel_pos < len(kernel_choices) and
                            kernel_choices[kernel_pos])
        return blaze_filters, double_blaze_filters, is_5x5kernel

    def token2arch(self, tokens=None):
        """Return a network-building function for the architecture in ``tokens``.

        Prints the decoded structure as a side effect. The returned
        ``net_arch(input, mode, cfg)`` reads ``output_decoder``, ``min_sizes``
        and ``use_density_prior_box`` from ``cfg.BlazeFace``. It stores them
        on ``self``, builds a BlazeNet backbone wrapped in BlazeFace, and
        returns ``BlazeFace.build(input, mode=mode)``.
        """
        blaze_filters, double_blaze_filters, is_5x5kernel = self.get_nas_cnf(
            tokens)
        self.print_nas_structure(tokens)

        def net_arch(input, mode, cfg):
            # cfg is presumably the ppdet global config object -- confirm.
            self.output_decoder = cfg.BlazeFace['output_decoder']
            self.min_sizes = cfg.BlazeFace['min_sizes']
            self.use_density_prior_box = cfg.BlazeFace['use_density_prior_box']

            my_backbone = BlazeNet(
                blaze_filters=blaze_filters,
                double_blaze_filters=double_blaze_filters,
                use_5x5kernel=is_5x5kernel)
            my_blazeface = BlazeFace(
                my_backbone,
                output_decoder=self.output_decoder,
                min_sizes=self.min_sizes,
                use_density_prior_box=self.use_density_prior_box)
            return my_blazeface.build(input, mode=mode)

        return net_arch

    def print_nas_structure(self, tokens=None):
        """Print the decoded backbone settings in a YAML-like layout.

        ``with_extra_blocks`` and ``lite_edition`` are printed as fixed
        values; they are not derived from ``tokens``.
        """
        blaze_filters, double_filters, is_5x5kernel = self.get_nas_cnf(tokens)
        print('---------->>> BlazeFace-NAS structure start: <<<------------')
        print('BlazeNet:')
        print('  blaze_filters: {}'.format(blaze_filters))
        print('  double_blaze_filters: {}'.format(double_filters))
        print('  use_5x5kernel: {}'.format(is_5x5kernel))
        print('  with_extra_blocks: true')
        print('  lite_edition: false')
        print('---------->>> BlazeFace-NAS structure end! <<<------------')