提交 962ac804 编写于 作者: W wanghaoshuang

Add unittest for light-nas.

上级 97f1776b
......@@ -21,9 +21,12 @@ import controller_server
from controller_server import *
import controller_client
from controller_client import *
import lock_utils
from lock_utils import *
__all__ = []
__all__ += controller.__all__
__all__ += sa_controller.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
__all__ += lock_utils.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
__all__ = ['lock', 'unlock']
if os.name == 'nt':
def lock(file):
raise NotImplementedError('Windows is not supported.')
def unlock(file):
raise NotImplementedError('Windows is not supported.')
elif os.name == 'posix':
from fcntl import flock, LOCK_EX, LOCK_UN
def lock(file):
"""Lock the file in local file system."""
flock(file.fileno(), LOCK_EX)
def unlock(file):
"""Unlock the file in local file system."""
flock(file.fileno(), LOCK_UN)
else:
raise RuntimeError("File Locker only support NT and Posix platforms!")
......@@ -23,6 +23,7 @@ from ..analysis import flops
from ..common import ControllerServer
from ..common import ControllerClient
from .search_space import SearchSpaceFactory
__all__ = ["SANAS"]
......@@ -32,8 +33,8 @@ _logger = get_logger(__name__, level=logging.INFO)
class SANAS(object):
def __init__(self,
configs,
flops=None,
latency=None,
max_flops=None,
max_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
......@@ -73,6 +74,9 @@ class SANAS(object):
self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
print range_table
controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
......@@ -112,6 +116,7 @@ class SANAS(object):
with fluid.program_guard(main_program, startup_program):
i = 0
for config, arch in zip(self._configs, archs):
input_size = config[1]["input_size"]
input = fluid.data(
name="data_{}".format(i),
shape=[None, 3, input_size, input_size],
......
......@@ -52,7 +52,6 @@ class MobileNetV2Space(SearchSpaceBase):
self.scale = scale
self.class_dim = class_dim
def init_tokens(self):
"""
The initial token send to controller.
......@@ -74,7 +73,8 @@ class MobileNetV2Space(SearchSpaceBase):
if self.block_num < 5:
self.token_len = 1 + (self.block_num - 1) * 4
else:
self.token_len = 1 + (self.block_num + 2 * (self.block_num - 5)) * 4
self.token_len = 1 + (self.block_num + 2 *
(self.block_num - 5)) * 4
return init_token_base[:self.token_len]
......@@ -92,6 +92,7 @@ class MobileNetV2Space(SearchSpaceBase):
5, 10, 6, 2,
5, 10, 6, 2,
5, 12, 6, 2]
range_table_base = list(np.array(range_table_base) - 1)
# yapf: enable
return range_table_base[:self.token_len]
......@@ -107,22 +108,34 @@ class MobileNetV2Space(SearchSpaceBase):
tokens = self.init_tokens()
bottleneck_params_list = []
if self.block_num >= 1: bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 2: bottleneck_params_list.append((self.multiply[tokens[1]], self.filter_num1[tokens[2]],
if self.block_num >= 1:
bottleneck_params_list.append(
(1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 2:
bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
if self.block_num >= 3: bottleneck_params_list.append((self.multiply[tokens[5]], self.filter_num1[tokens[6]],
if self.block_num >= 3:
bottleneck_params_list.append(
(self.multiply[tokens[5]], self.filter_num1[tokens[6]],
self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
if self.block_num >= 4: bottleneck_params_list.append((self.multiply[tokens[9]], self.filter_num2[tokens[10]],
if self.block_num >= 4:
bottleneck_params_list.append(
(self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
if self.block_num >= 5:
bottleneck_params_list.append((self.multiply[tokens[13]], self.filter_num3[tokens[14]],
bottleneck_params_list.append(
(self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
bottleneck_params_list.append((self.multiply[tokens[17]], self.filter_num3[tokens[18]],
bottleneck_params_list.append(
(self.multiply[tokens[17]], self.filter_num3[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
if self.block_num >= 6:
bottleneck_params_list.append((self.multiply[tokens[21]], self.filter_num5[tokens[22]],
bottleneck_params_list.append(
(self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
bottleneck_params_list.append((self.multiply[tokens[25]], self.filter_num6[tokens[26]],
bottleneck_params_list.append(
(self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input):
......
......@@ -25,19 +25,25 @@ from .search_space_registry import SEARCHSPACE
__all__ = ["ResNetSpace"]
@SEARCHSPACE.register
class ResNetSpace(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000):
def __init__(self,
input_size,
output_size,
block_num,
scale=1.0,
class_dim=1000):
super(ResNetSpace, self).__init__(input_size, output_size, block_num)
pass
def init_tokens(self):
return [0,0,0,0,0,0]
return [0, 0, 0, 0, 0, 0]
def range_table(self):
return [3,3,3,3,3,3]
return [2, 2, 2, 2, 2, 2]
def token2arch(self,tokens=None):
def token2arch(self, tokens=None):
if tokens is None:
self.init_tokens()
......@@ -54,5 +60,3 @@ class ResNetSpace(SearchSpaceBase):
return input
return net_arch
......@@ -16,6 +16,8 @@ sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.nas import SANAS
from paddleslim.nas import SearchSpaceFactory
from paddleslim.analysis import flops
class TestSANAS(unittest.TestCase):
......@@ -27,27 +29,27 @@ class TestSANAS(unittest.TestCase):
configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)]
space = factory.get_search_space([('MobileNetV2Space', config0)])
origin_arch = space.token2arch()
origin_arch = space.token2arch()[0]
main_program = fluid.Program()
s_program = fluid.Program()
with fluid.program_guard(main_program, s_program):
input = fluid.data(
name="input", shape=[3, 224, 224], dtype="float32")
name="input", shape=[None, 3, 224, 224], dtype="float32")
origin_arch(input)
base_flops = flops(main_program)
serch_steps = 3
search_steps = 3
sa_nas = SANAS(
configs, max_flops=base_flops, search_steps=search_steps)
for i in range(serch_steps):
for i in range(search_steps):
archs = sa_nas.next_archs()
main_program = fluid.Program()
s_program = fluid.Program()
with fluid.program_guard(main_program, s_program):
input = fluid.data(
name="input", shape=[3, 224, 224], dtype="float32")
name="input", shape=[None, 3, 224, 224], dtype="float32")
archs[0](input)
sa_nas.reward(1)
self.assertTrue(flops(main_program) < base_flops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册