提交 962ac804 编写于 作者: W wanghaoshuang

Add unittest for light-nas.

上级 97f1776b
...@@ -21,9 +21,12 @@ import controller_server ...@@ -21,9 +21,12 @@ import controller_server
from controller_server import * from controller_server import *
import controller_client import controller_client
from controller_client import * from controller_client import *
import lock_utils
from lock_utils import *
__all__ = [] __all__ = []
__all__ += controller.__all__ __all__ += controller.__all__
__all__ += sa_controller.__all__ __all__ += sa_controller.__all__
__all__ += controller_server.__all__ __all__ += controller_server.__all__
__all__ += controller_client.__all__ __all__ += controller_client.__all__
__all__ += lock_utils.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
__all__ = ['lock', 'unlock']
if os.name == 'nt':
def lock(file):
raise NotImplementedError('Windows is not supported.')
def unlock(file):
raise NotImplementedError('Windows is not supported.')
elif os.name == 'posix':
from fcntl import flock, LOCK_EX, LOCK_UN
def lock(file):
"""Lock the file in local file system."""
flock(file.fileno(), LOCK_EX)
def unlock(file):
"""Unlock the file in local file system."""
flock(file.fileno(), LOCK_UN)
else:
raise RuntimeError("File Locker only support NT and Posix platforms!")
...@@ -23,6 +23,7 @@ from ..analysis import flops ...@@ -23,6 +23,7 @@ from ..analysis import flops
from ..common import ControllerServer from ..common import ControllerServer
from ..common import ControllerClient from ..common import ControllerClient
from .search_space import SearchSpaceFactory
__all__ = ["SANAS"] __all__ = ["SANAS"]
...@@ -32,8 +33,8 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -32,8 +33,8 @@ _logger = get_logger(__name__, level=logging.INFO)
class SANAS(object): class SANAS(object):
def __init__(self, def __init__(self,
configs, configs,
flops=None, max_flops=None,
latency=None, max_latency=None,
server_addr=("", 0), server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
...@@ -73,6 +74,9 @@ class SANAS(object): ...@@ -73,6 +74,9 @@ class SANAS(object):
self._search_space = factory.get_search_space(configs) self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens() init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table() range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
print range_table
controller = SAController(range_table, self._reduce_rate, controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number, self._init_temperature, self._max_try_number,
...@@ -112,6 +116,7 @@ class SANAS(object): ...@@ -112,6 +116,7 @@ class SANAS(object):
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
i = 0 i = 0
for config, arch in zip(self._configs, archs): for config, arch in zip(self._configs, archs):
input_size = config[1]["input_size"]
input = fluid.data( input = fluid.data(
name="data_{}".format(i), name="data_{}".format(i),
shape=[None, 3, input_size, input_size], shape=[None, 3, input_size, input_size],
......
...@@ -52,7 +52,6 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -52,7 +52,6 @@ class MobileNetV2Space(SearchSpaceBase):
self.scale = scale self.scale = scale
self.class_dim = class_dim self.class_dim = class_dim
def init_tokens(self): def init_tokens(self):
""" """
The initial token send to controller. The initial token send to controller.
...@@ -71,10 +70,11 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -71,10 +70,11 @@ class MobileNetV2Space(SearchSpaceBase):
4, 9, 0, 0] # 6, 320, 1 4, 9, 0, 0] # 6, 320, 1
# yapf: enable # yapf: enable
if self.block_num < 5: if self.block_num < 5:
self.token_len = 1 + (self.block_num - 1) * 4 self.token_len = 1 + (self.block_num - 1) * 4
else: else:
self.token_len = 1 + (self.block_num + 2 * (self.block_num - 5)) * 4 self.token_len = 1 + (self.block_num + 2 *
(self.block_num - 5)) * 4
return init_token_base[:self.token_len] return init_token_base[:self.token_len]
...@@ -92,6 +92,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -92,6 +92,7 @@ class MobileNetV2Space(SearchSpaceBase):
5, 10, 6, 2, 5, 10, 6, 2,
5, 10, 6, 2, 5, 10, 6, 2,
5, 12, 6, 2] 5, 12, 6, 2]
range_table_base = list(np.array(range_table_base) - 1)
# yapf: enable # yapf: enable
return range_table_base[:self.token_len] return range_table_base[:self.token_len]
...@@ -107,24 +108,36 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -107,24 +108,36 @@ class MobileNetV2Space(SearchSpaceBase):
tokens = self.init_tokens() tokens = self.init_tokens()
bottleneck_params_list = [] bottleneck_params_list = []
if self.block_num >= 1: bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1, 3)) if self.block_num >= 1:
if self.block_num >= 2: bottleneck_params_list.append((self.multiply[tokens[1]], self.filter_num1[tokens[2]], bottleneck_params_list.append(
self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) (1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 3: bottleneck_params_list.append((self.multiply[tokens[5]], self.filter_num1[tokens[6]], if self.block_num >= 2:
self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) bottleneck_params_list.append(
if self.block_num >= 4: bottleneck_params_list.append((self.multiply[tokens[9]], self.filter_num2[tokens[10]], (self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
if self.block_num >= 5: if self.block_num >= 3:
bottleneck_params_list.append((self.multiply[tokens[13]], self.filter_num3[tokens[14]], bottleneck_params_list.append(
self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) (self.multiply[tokens[5]], self.filter_num1[tokens[6]],
bottleneck_params_list.append((self.multiply[tokens[17]], self.filter_num3[tokens[18]], self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) if self.block_num >= 4:
if self.block_num >= 6: bottleneck_params_list.append(
bottleneck_params_list.append((self.multiply[tokens[21]], self.filter_num5[tokens[22]], (self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
bottleneck_params_list.append((self.multiply[tokens[25]], self.filter_num6[tokens[26]], if self.block_num >= 5:
self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) bottleneck_params_list.append(
(self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
bottleneck_params_list.append(
(self.multiply[tokens[17]], self.filter_num3[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
if self.block_num >= 6:
bottleneck_params_list.append(
(self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
bottleneck_params_list.append(
(self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input): def net_arch(input):
#conv1 #conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic. # all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
...@@ -182,15 +195,15 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -182,15 +195,15 @@ class MobileNetV2Space(SearchSpaceBase):
return fluid.layers.elementwise_add(input, data_residual) return fluid.layers.elementwise_add(input, data_residual)
def _inverted_residual_unit(self, def _inverted_residual_unit(self,
input, input,
num_in_filter, num_in_filter,
num_filters, num_filters,
ifshortcut, ifshortcut,
stride, stride,
filter_size, filter_size,
expansion_factor, expansion_factor,
reduction_ratio=4, reduction_ratio=4,
name=None): name=None):
"""Build inverted residual unit. """Build inverted residual unit.
Args: Args:
input(Variable), input. input(Variable), input.
......
...@@ -25,19 +25,25 @@ from .search_space_registry import SEARCHSPACE ...@@ -25,19 +25,25 @@ from .search_space_registry import SEARCHSPACE
__all__ = ["ResNetSpace"] __all__ = ["ResNetSpace"]
@SEARCHSPACE.register @SEARCHSPACE.register
class ResNetSpace(SearchSpaceBase): class ResNetSpace(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000): def __init__(self,
input_size,
output_size,
block_num,
scale=1.0,
class_dim=1000):
super(ResNetSpace, self).__init__(input_size, output_size, block_num) super(ResNetSpace, self).__init__(input_size, output_size, block_num)
pass pass
def init_tokens(self): def init_tokens(self):
return [0,0,0,0,0,0] return [0, 0, 0, 0, 0, 0]
def range_table(self): def range_table(self):
return [3,3,3,3,3,3] return [2, 2, 2, 2, 2, 2]
def token2arch(self,tokens=None): def token2arch(self, tokens=None):
if tokens is None: if tokens is None:
self.init_tokens() self.init_tokens()
...@@ -54,5 +60,3 @@ class ResNetSpace(SearchSpaceBase): ...@@ -54,5 +60,3 @@ class ResNetSpace(SearchSpaceBase):
return input return input
return net_arch return net_arch
...@@ -16,6 +16,8 @@ sys.path.append("../") ...@@ -16,6 +16,8 @@ sys.path.append("../")
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.nas import SANAS from paddleslim.nas import SANAS
from paddleslim.nas import SearchSpaceFactory
from paddleslim.analysis import flops
class TestSANAS(unittest.TestCase): class TestSANAS(unittest.TestCase):
...@@ -27,27 +29,27 @@ class TestSANAS(unittest.TestCase): ...@@ -27,27 +29,27 @@ class TestSANAS(unittest.TestCase):
configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)] configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)]
space = factory.get_search_space([('MobileNetV2Space', config0)]) space = factory.get_search_space([('MobileNetV2Space', config0)])
origin_arch = space.token2arch() origin_arch = space.token2arch()[0]
main_program = fluid.Program() main_program = fluid.Program()
s_program = fluid.Program() s_program = fluid.Program()
with fluid.program_guard(main_program, s_program): with fluid.program_guard(main_program, s_program):
input = fluid.data( input = fluid.data(
name="input", shape=[3, 224, 224], dtype="float32") name="input", shape=[None, 3, 224, 224], dtype="float32")
origin_arch(input) origin_arch(input)
base_flops = flops(main_program) base_flops = flops(main_program)
serch_steps = 3 search_steps = 3
sa_nas = SANAS( sa_nas = SANAS(
configs, max_flops=base_flops, search_steps=search_steps) configs, max_flops=base_flops, search_steps=search_steps)
for i in range(serch_steps): for i in range(search_steps):
archs = sa_nas.next_archs() archs = sa_nas.next_archs()
main_program = fluid.Program() main_program = fluid.Program()
s_program = fluid.Program() s_program = fluid.Program()
with fluid.program_guard(main_program, s_program): with fluid.program_guard(main_program, s_program):
input = fluid.data( input = fluid.data(
name="input", shape=[3, 224, 224], dtype="float32") name="input", shape=[None, 3, 224, 224], dtype="float32")
archs[0](input) archs[0](input)
sa_nas.reward(1) sa_nas.reward(1)
self.assertTrue(flops(main_program) < base_flops) self.assertTrue(flops(main_program) < base_flops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册