test_sa_nas.py 4.2 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
B
Bai Yifan 已提交
14 15
import sys
sys.path.append("../")
C
ceci3 已提交
16
import os
W
wanghaoshuang 已提交
17 18 19 20
import sys
import unittest
import paddle.fluid as fluid
from paddleslim.nas import SANAS
W
wanghaoshuang 已提交
21
from paddleslim.analysis import flops
C
ceci3 已提交
22
import numpy as np
W
wanghaoshuang 已提交
23

B
Bai Yifan 已提交
24

C
ceci3 已提交
25 26 27 28 29
def compute_op_num(program):
    """Gather every rank-4 parameter of ``program``.

    All blocks of the program are scanned, and only parameters whose shape has
    exactly four dimensions are kept (presumably convolution filters, laid out
    as [out_channels, in_channels/groups, kh, kw] -- confirm for other layers).

    Args:
        program: a program object exposing ``.blocks``, each of which exposes
            ``all_parameters()`` returning objects with ``.name``/``.shape``.

    Returns:
        tuple: ``(params, ch_list)`` -- ``params`` maps each kept parameter's
        name to its shape; ``ch_list`` holds ``int(shape[0])`` of every kept
        parameter, in the order the parameters were visited.
    """
    shapes_by_name = {}
    leading_dims = []
    rank4_params = (
        candidate for blk in program.blocks
        for candidate in blk.all_parameters() if len(candidate.shape) == 4)
    for p in rank4_params:
        shapes_by_name[p.name] = p.shape
        leading_dims.append(int(p.shape[0]))
    return shapes_by_name, leading_dims
W
wanghaoshuang 已提交
34

B
Bai Yifan 已提交
35

W
wanghaoshuang 已提交
36
class TestSANAS(unittest.TestCase):
    """Tests for ``SANAS`` configured with ``MobileNetV2BlockSpace`` and
    ``block_mask=[0]``: next_archs, reward, tokens2arch and current_info."""

    def setUp(self):
        self.init_test_case()
        # Randomized port so repeated/parallel runs are less likely to clash
        # on the controller server address.
        port = np.random.randint(8337, 8773)
        # NOTE(review): save_checkpoint=None presumably disables checkpoint
        # writing during the test -- confirm against SANAS.
        self.sanas = SANAS(
            configs=self.configs, server_addr=("", port), save_checkpoint=None)

    def init_test_case(self):
        """Define the search-space config and the candidate value tables.

        ``check_chnum_convnum`` indexes these tables with the current SANAS
        tokens (token 0 -> ``multiply``, token 1 -> ``filter_num``,
        token 2 -> ``repeat``). They presumably mirror the tables inside
        ``MobileNetV2BlockSpace``; keep them in sync if that space changes.
        """
        self.configs = [('MobileNetV2BlockSpace', {'block_mask': [0]})]
        self.filter_num = np.array([
            3, 4, 8, 12, 16, 24, 32, 48, 64, 80, 96, 128, 144, 160, 192, 224,
            256, 320, 384, 512
        ])
        self.k_size = np.array([3, 5])
        self.multiply = np.array([1, 2, 3, 4, 5, 6])
        self.repeat = np.array([1, 2, 3, 4, 5, 6])

    def check_chnum_convnum(self, program):
        """Assert ``program``'s conv count and channel list match the tokens.

        The expected values are derived from
        ``self.sanas.current_info()['current_tokens']``. The actual values
        come from ``compute_op_num(program)``.
        """
        current_tokens = self.sanas.current_info()['current_tokens']
        # Convert the NumPy scalars to Python ints so the expected channel
        # list compares exactly with the plain ints from compute_op_num. A
        # str() comparison of the lists breaks under NumPy 2, whose scalar
        # repr is e.g. "np.int64(32)".
        channel_exp = int(self.multiply[current_tokens[0]])
        filter_num = int(self.filter_num[current_tokens[1]])
        repeat_num = int(self.repeat[current_tokens[2]])

        conv_params, ch_pro = compute_op_num(program)

        ### assert conv number
        # Three convs per repeated unit (presumably expand 1x1, depthwise,
        # project 1x1 of an inverted-residual unit).
        expected_conv_num = repeat_num * 3
        self.assertEqual(
            expected_conv_num,
            len(conv_params),
            "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}".
            format(expected_conv_num, len(conv_params)))

        ### assert number of channels
        ch_token = []
        # Hard-coded input channel count of the first unit -- confirm it
        # matches the search space's default.
        init_ch_num = 32
        for _ in range(repeat_num):
            ch_token.append(init_ch_num * channel_exp)
            ch_token.append(init_ch_num * channel_exp)
            ch_token.append(filter_num)
            # The next unit's input is this unit's projected output.
            init_ch_num = filter_num

        self.assertEqual(
            ch_token,
            ch_pro,
            "channel num is WRONG, channel num from token is {}, channel num come fom program is {}".
            format(str(ch_token), str(ch_pro)))

    def test_all_function(self):
        ### unittest for next_archs
        next_program = fluid.Program()
        startup_program = fluid.Program()
        token2arch_program = fluid.Program()

        with fluid.program_guard(next_program, startup_program):
            inputs = fluid.data(
                name='input', shape=[None, 3, 32, 32], dtype='float32')
            archs = self.sanas.next_archs()
            for arch in archs:
                inputs = arch(inputs)
        self.check_chnum_convnum(next_program)

        ### unittest for reward
        self.assertTrue(self.sanas.reward(1.0), "reward is False")

        ### unittest for tokens2arch
        with fluid.program_guard(token2arch_program, startup_program):
            inputs = fluid.data(
                name='input', shape=[None, 3, 32, 32], dtype='float32')
            # Build from tokens2arch's own result. Iterating the ``archs``
            # left over from next_archs would never exercise tokens2arch.
            token_archs = self.sanas.tokens2arch(self.sanas.current_info()[
                'current_tokens'])
            for arch in token_archs:
                inputs = arch(inputs)
        self.check_chnum_convnum(token2arch_program)

        ### unittest for current_info
        current_info = self.sanas.current_info()
        self.assertIsInstance(
            current_info, dict,
            "the type of current info must be dict, but now is {}".format(
                type(current_info)))

W
wanghaoshuang 已提交
116 117 118

if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in it.
    unittest.main(module="__main__")