Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
a0d17e44
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0d17e44
编写于
11月 11, 2019
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'rename_nas' into 'develop'
add multi search_space support See merge request
!17
上级
af0eb732
907c7473
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
231 addition
and
62 deletion
+231
-62
paddleslim/nas/search_space/__init__.py
paddleslim/nas/search_space/__init__.py
+3
-0
paddleslim/nas/search_space/base_layer.py
paddleslim/nas/search_space/base_layer.py
+2
-0
paddleslim/nas/search_space/combine_search_space.py
paddleslim/nas/search_space/combine_search_space.py
+99
-0
paddleslim/nas/search_space/mobilenetv2.py
paddleslim/nas/search_space/mobilenetv2.py
+43
-45
paddleslim/nas/search_space/resnet.py
paddleslim/nas/search_space/resnet.py
+58
-0
paddleslim/nas/search_space/search_space_base.py
paddleslim/nas/search_space/search_space_base.py
+1
-1
paddleslim/nas/search_space/search_space_factory.py
paddleslim/nas/search_space/search_space_factory.py
+5
-12
tests/test_searchspace.py
tests/test_searchspace.py
+20
-4
未找到文件。
paddleslim/nas/search_space/__init__.py
浏览文件 @
a0d17e44
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
import
mobilenetv2
import
mobilenetv2
from
.mobilenetv2
import
*
from
.mobilenetv2
import
*
import
resnet
from
.resnet
import
*
import
search_space_registry
import
search_space_registry
from
search_space_registry
import
*
from
search_space_registry
import
*
import
search_space_factory
import
search_space_factory
...
@@ -26,3 +28,4 @@ __all__ += mobilenetv2.__all__
...
@@ -26,3 +28,4 @@ __all__ += mobilenetv2.__all__
__all__
+=
search_space_registry
.
__all__
__all__
+=
search_space_registry
.
__all__
__all__
+=
search_space_factory
.
__all__
__all__
+=
search_space_factory
.
__all__
__all__
+=
search_space_base
.
__all__
__all__
+=
search_space_base
.
__all__
paddleslim/nas/search_space/base_layer.py
浏览文件 @
a0d17e44
...
@@ -59,5 +59,7 @@ def conv_bn_layer(input,
...
@@ -59,5 +59,7 @@ def conv_bn_layer(input,
moving_variance_name
=
bn_name
+
'_variance'
)
moving_variance_name
=
bn_name
+
'_variance'
)
if
act
==
'relu6'
:
if
act
==
'relu6'
:
return
fluid
.
layers
.
relu6
(
bn
)
return
fluid
.
layers
.
relu6
(
bn
)
elif
act
==
'sigmoid'
:
return
fluid
.
layers
.
sigmoid
(
bn
)
else
:
else
:
return
bn
return
bn
paddleslim/nas/search_space/combine_search_space.py
0 → 100644
浏览文件 @
a0d17e44
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
.search_space_base
import
SearchSpaceBase
from
.search_space_registry
import
SEARCHSPACE
from
.base_layer
import
conv_bn_layer
__all__
=
[
"CombineSearchSpace"
]
class
CombineSearchSpace
(
object
):
"""
Combine Search Space.
Args:
configs(list<tuple>): multi config.
"""
def
__init__
(
self
,
config_lists
):
self
.
lens
=
len
(
config_lists
)
self
.
spaces
=
[]
for
config_list
in
config_lists
:
key
,
config
=
config_list
self
.
spaces
.
append
(
self
.
_get_single_search_space
(
key
,
config
))
def
_get_single_search_space
(
self
,
key
,
config
):
"""
get specific model space based on key and config.
Args:
key(str): model space name.
config(dict): basic config information.
return:
model space(class)
"""
cls
=
SEARCHSPACE
.
get
(
key
)
space
=
cls
(
config
[
'input_size'
],
config
[
'output_size'
],
config
[
'block_num'
])
return
space
def
init_tokens
(
self
):
"""
Combine init tokens.
"""
tokens
=
[]
self
.
single_token_num
=
[]
for
space
in
self
.
spaces
:
tokens
.
extend
(
space
.
init_tokens
())
self
.
single_token_num
.
append
(
len
(
space
.
init_tokens
()))
return
tokens
def
range_table
(
self
):
"""
Combine range table.
"""
range_tables
=
[]
for
space
in
self
.
spaces
:
range_tables
.
extend
(
space
.
range_table
())
return
range_tables
def
token2arch
(
self
,
tokens
=
None
):
"""
Combine model arch
"""
if
tokens
is
None
:
tokens
=
self
.
init_tokens
()
token_list
=
[]
start_idx
=
0
end_idx
=
0
for
i
in
range
(
len
(
self
.
single_token_num
)):
end_idx
+=
self
.
single_token_num
[
i
]
token_list
.
append
(
tokens
[
start_idx
:
end_idx
])
start_idx
=
end_idx
model_archs
=
[]
for
space
,
token
in
zip
(
self
.
spaces
,
token_list
):
model_archs
.
append
(
space
.
token2arch
(
token
))
return
model_archs
paddleslim/nas/search_space/mobilenetv2.py
浏览文件 @
a0d17e44
...
@@ -52,6 +52,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -52,6 +52,7 @@ class MobileNetV2Space(SearchSpaceBase):
self
.
scale
=
scale
self
.
scale
=
scale
self
.
class_dim
=
class_dim
self
.
class_dim
=
class_dim
def
init_tokens
(
self
):
def
init_tokens
(
self
):
"""
"""
The initial token send to controller.
The initial token send to controller.
...
@@ -60,7 +61,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -60,7 +61,7 @@ class MobileNetV2Space(SearchSpaceBase):
"""
"""
# original MobileNetV2
# original MobileNetV2
# yapf: disable
# yapf: disable
return
[
4
,
# 1, 16, 1
init_token_base
=
[
4
,
# 1, 16, 1
4
,
5
,
1
,
0
,
# 6, 24, 1
4
,
5
,
1
,
0
,
# 6, 24, 1
4
,
5
,
1
,
0
,
# 6, 24, 2
4
,
5
,
1
,
0
,
# 6, 24, 2
4
,
4
,
2
,
0
,
# 6, 32, 3
4
,
4
,
2
,
0
,
# 6, 32, 3
...
@@ -70,13 +71,20 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -70,13 +71,20 @@ class MobileNetV2Space(SearchSpaceBase):
4
,
9
,
0
,
0
]
# 6, 320, 1
4
,
9
,
0
,
0
]
# 6, 320, 1
# yapf: enable
# yapf: enable
if
self
.
block_num
<
5
:
self
.
token_len
=
1
+
(
self
.
block_num
-
1
)
*
4
else
:
self
.
token_len
=
1
+
(
self
.
block_num
+
2
*
(
self
.
block_num
-
5
))
*
4
return
init_token_base
[:
self
.
token_len
]
def
range_table
(
self
):
def
range_table
(
self
):
"""
"""
get range table of current search space
get range table of current search space
"""
"""
# head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# yapf: disable
# yapf: disable
r
eturn
[
7
,
r
ange_table_base
=
[
7
,
5
,
8
,
6
,
2
,
5
,
8
,
6
,
2
,
5
,
8
,
6
,
2
,
5
,
8
,
6
,
2
,
5
,
8
,
6
,
2
,
5
,
8
,
6
,
2
,
...
@@ -85,48 +93,38 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -85,48 +93,38 @@ class MobileNetV2Space(SearchSpaceBase):
5
,
10
,
6
,
2
,
5
,
10
,
6
,
2
,
5
,
12
,
6
,
2
]
5
,
12
,
6
,
2
]
# yapf: enable
# yapf: enable
return
range_table_base
[:
self
.
token_len
]
def
token2arch
(
self
,
tokens
=
None
):
def
token2arch
(
self
,
tokens
=
None
):
"""
"""
return net_arch function
return net_arch function
"""
"""
if
tokens
is
None
:
tokens
=
self
.
init_tokens
()
base_bottleneck_params_list
=
[
(
1
,
self
.
head_num
[
tokens
[
0
]],
1
,
1
,
3
),
(
self
.
multiply
[
tokens
[
1
]],
self
.
filter_num1
[
tokens
[
2
]],
self
.
repeat
[
tokens
[
3
]],
2
,
self
.
k_size
[
tokens
[
4
]]),
(
self
.
multiply
[
tokens
[
5
]],
self
.
filter_num1
[
tokens
[
6
]],
self
.
repeat
[
tokens
[
7
]],
2
,
self
.
k_size
[
tokens
[
8
]]),
(
self
.
multiply
[
tokens
[
9
]],
self
.
filter_num2
[
tokens
[
10
]],
self
.
repeat
[
tokens
[
11
]],
2
,
self
.
k_size
[
tokens
[
12
]]),
(
self
.
multiply
[
tokens
[
13
]],
self
.
filter_num3
[
tokens
[
14
]],
self
.
repeat
[
tokens
[
15
]],
2
,
self
.
k_size
[
tokens
[
16
]]),
(
self
.
multiply
[
tokens
[
17
]],
self
.
filter_num3
[
tokens
[
18
]],
self
.
repeat
[
tokens
[
19
]],
1
,
self
.
k_size
[
tokens
[
20
]]),
(
self
.
multiply
[
tokens
[
21
]],
self
.
filter_num5
[
tokens
[
22
]],
self
.
repeat
[
tokens
[
23
]],
2
,
self
.
k_size
[
tokens
[
24
]]),
(
self
.
multiply
[
tokens
[
25
]],
self
.
filter_num6
[
tokens
[
26
]],
self
.
repeat
[
tokens
[
27
]],
1
,
self
.
k_size
[
tokens
[
28
]]),
]
assert
self
.
block_num
<
7
,
'block number must less than 7, but receive block number is {}'
.
format
(
assert
self
.
block_num
<
7
,
'block number must less than 7, but receive block number is {}'
.
format
(
self
.
block_num
)
self
.
block_num
)
# the stride = 2 means downsample feature map in the convolution, so only when stride=2, block_num minus 1,
if
tokens
is
None
:
# otherwise, add layers to params_list directly.
tokens
=
self
.
init_tokens
()
bottleneck_params_list
=
[]
for
param_list
in
base_bottleneck_params_list
:
if
param_list
[
3
]
==
1
:
bottleneck_params_list
.
append
(
param_list
)
else
:
if
self
.
block_num
>
1
:
bottleneck_params_list
.
append
(
param_list
)
self
.
block_num
-=
1
else
:
break
bottleneck_params_list
=
[]
if
self
.
block_num
>=
1
:
bottleneck_params_list
.
append
((
1
,
self
.
head_num
[
tokens
[
0
]],
1
,
1
,
3
))
if
self
.
block_num
>=
2
:
bottleneck_params_list
.
append
((
self
.
multiply
[
tokens
[
1
]],
self
.
filter_num1
[
tokens
[
2
]],
self
.
repeat
[
tokens
[
3
]],
2
,
self
.
k_size
[
tokens
[
4
]]))
if
self
.
block_num
>=
3
:
bottleneck_params_list
.
append
((
self
.
multiply
[
tokens
[
5
]],
self
.
filter_num1
[
tokens
[
6
]],
self
.
repeat
[
tokens
[
7
]],
2
,
self
.
k_size
[
tokens
[
8
]]))
if
self
.
block_num
>=
4
:
bottleneck_params_list
.
append
((
self
.
multiply
[
tokens
[
9
]],
self
.
filter_num2
[
tokens
[
10
]],
self
.
repeat
[
tokens
[
11
]],
2
,
self
.
k_size
[
tokens
[
12
]]))
if
self
.
block_num
>=
5
:
bottleneck_params_list
.
append
((
self
.
multiply
[
tokens
[
13
]],
self
.
filter_num3
[
tokens
[
14
]],
self
.
repeat
[
tokens
[
15
]],
2
,
self
.
k_size
[
tokens
[
16
]]))
bottleneck_params_list
.
append
((
self
.
multiply
[
tokens
[
17
]],
self
.
filter_num3
[
tokens
[
18
]],
self
.
repeat
[
tokens
[
19
]],
1
,
self
.
k_size
[
tokens
[
20
]]))
if
self
.
block_num
>=
6
:
bottleneck_params_list
.
append
((
self
.
multiply
[
tokens
[
21
]],
self
.
filter_num5
[
tokens
[
22
]],
self
.
repeat
[
tokens
[
23
]],
2
,
self
.
k_size
[
tokens
[
24
]]))
bottleneck_params_list
.
append
((
self
.
multiply
[
tokens
[
25
]],
self
.
filter_num6
[
tokens
[
26
]],
self
.
repeat
[
tokens
[
27
]],
1
,
self
.
k_size
[
tokens
[
28
]]))
def
net_arch
(
input
):
def
net_arch
(
input
):
#conv1
#conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
...
@@ -137,7 +135,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -137,7 +135,7 @@ class MobileNetV2Space(SearchSpaceBase):
stride
=
2
,
stride
=
2
,
padding
=
'SAME'
,
padding
=
'SAME'
,
act
=
'relu6'
,
act
=
'relu6'
,
name
=
'conv1_1'
)
name
=
'
mobilenetv2_
conv1_1'
)
# bottleneck sequences
# bottleneck sequences
i
=
1
i
=
1
...
@@ -145,7 +143,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -145,7 +143,7 @@ class MobileNetV2Space(SearchSpaceBase):
for
layer_setting
in
bottleneck_params_list
:
for
layer_setting
in
bottleneck_params_list
:
t
,
c
,
n
,
s
,
k
=
layer_setting
t
,
c
,
n
,
s
,
k
=
layer_setting
i
+=
1
i
+=
1
input
=
self
.
invresi_blocks
(
input
=
self
.
_
invresi_blocks
(
input
=
input
,
input
=
input
,
in_c
=
in_c
,
in_c
=
in_c
,
t
=
t
,
t
=
t
,
...
@@ -153,7 +151,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -153,7 +151,7 @@ class MobileNetV2Space(SearchSpaceBase):
n
=
n
,
n
=
n
,
s
=
s
,
s
=
s
,
k
=
k
,
k
=
k
,
name
=
'conv'
+
str
(
i
))
name
=
'
mobilenetv2_
conv'
+
str
(
i
))
in_c
=
int
(
c
*
self
.
scale
)
in_c
=
int
(
c
*
self
.
scale
)
# if output_size is 1, add fc layer in the end
# if output_size is 1, add fc layer in the end
...
@@ -161,8 +159,8 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -161,8 +159,8 @@ class MobileNetV2Space(SearchSpaceBase):
input
=
fluid
.
layers
.
fc
(
input
=
fluid
.
layers
.
fc
(
input
=
input
,
input
=
input
,
size
=
self
.
class_dim
,
size
=
self
.
class_dim
,
param_attr
=
ParamAttr
(
name
=
'
fc10
_weights'
),
param_attr
=
ParamAttr
(
name
=
'
mobilenetv2_fc
_weights'
),
bias_attr
=
ParamAttr
(
name
=
'
fc10
_offset'
))
bias_attr
=
ParamAttr
(
name
=
'
mobilenetv2_fc
_offset'
))
else
:
else
:
assert
self
.
output_size
==
input
.
shape
[
2
],
\
assert
self
.
output_size
==
input
.
shape
[
2
],
\
(
"output_size must EQUAL to input_size / (2^block_num)."
(
"output_size must EQUAL to input_size / (2^block_num)."
...
@@ -173,7 +171,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -173,7 +171,7 @@ class MobileNetV2Space(SearchSpaceBase):
return
net_arch
return
net_arch
def
shortcut
(
self
,
input
,
data_residual
):
def
_
shortcut
(
self
,
input
,
data_residual
):
"""Build shortcut layer.
"""Build shortcut layer.
Args:
Args:
input(Variable): input.
input(Variable): input.
...
@@ -183,7 +181,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -183,7 +181,7 @@ class MobileNetV2Space(SearchSpaceBase):
"""
"""
return
fluid
.
layers
.
elementwise_add
(
input
,
data_residual
)
return
fluid
.
layers
.
elementwise_add
(
input
,
data_residual
)
def
inverted_residual_unit
(
self
,
def
_
inverted_residual_unit
(
self
,
input
,
input
,
num_in_filter
,
num_in_filter
,
num_filters
,
num_filters
,
...
@@ -240,10 +238,10 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -240,10 +238,10 @@ class MobileNetV2Space(SearchSpaceBase):
name
=
name
+
'_linear'
)
name
=
name
+
'_linear'
)
out
=
linear_out
out
=
linear_out
if
ifshortcut
:
if
ifshortcut
:
out
=
self
.
shortcut
(
input
=
input
,
data_residual
=
out
)
out
=
self
.
_
shortcut
(
input
=
input
,
data_residual
=
out
)
return
out
return
out
def
invresi_blocks
(
self
,
input
,
in_c
,
t
,
c
,
n
,
s
,
k
,
name
=
None
):
def
_
invresi_blocks
(
self
,
input
,
in_c
,
t
,
c
,
n
,
s
,
k
,
name
=
None
):
"""Build inverted residual blocks.
"""Build inverted residual blocks.
Args:
Args:
input: Variable, input.
input: Variable, input.
...
@@ -257,7 +255,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -257,7 +255,7 @@ class MobileNetV2Space(SearchSpaceBase):
Returns:
Returns:
Variable, layers output.
Variable, layers output.
"""
"""
first_block
=
self
.
inverted_residual_unit
(
first_block
=
self
.
_
inverted_residual_unit
(
input
=
input
,
input
=
input
,
num_in_filter
=
in_c
,
num_in_filter
=
in_c
,
num_filters
=
c
,
num_filters
=
c
,
...
@@ -271,7 +269,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -271,7 +269,7 @@ class MobileNetV2Space(SearchSpaceBase):
last_c
=
c
last_c
=
c
for
i
in
range
(
1
,
n
):
for
i
in
range
(
1
,
n
):
last_residual_block
=
self
.
inverted_residual_unit
(
last_residual_block
=
self
.
_
inverted_residual_unit
(
input
=
last_residual_block
,
input
=
last_residual_block
,
num_in_filter
=
last_c
,
num_in_filter
=
last_c
,
num_filters
=
c
,
num_filters
=
c
,
...
...
paddleslim/nas/search_space/resnet.py
0 → 100644
浏览文件 @
a0d17e44
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
.search_space_base
import
SearchSpaceBase
from
.base_layer
import
conv_bn_layer
from
.search_space_registry
import
SEARCHSPACE
__all__
=
[
"ResNetSpace"
]
@
SEARCHSPACE
.
register
class
ResNetSpace
(
SearchSpaceBase
):
def
__init__
(
self
,
input_size
,
output_size
,
block_num
,
scale
=
1.0
,
class_dim
=
1000
):
super
(
ResNetSpace
,
self
).
__init__
(
input_size
,
output_size
,
block_num
)
pass
def
init_tokens
(
self
):
return
[
0
,
0
,
0
,
0
,
0
,
0
]
def
range_table
(
self
):
return
[
3
,
3
,
3
,
3
,
3
,
3
]
def
token2arch
(
self
,
tokens
=
None
):
if
tokens
is
None
:
self
.
init_tokens
()
def
net_arch
(
input
):
input
=
conv_bn_layer
(
input
,
num_filters
=
32
,
filter_size
=
3
,
stride
=
2
,
padding
=
'SAME'
,
act
=
'sigmoid'
,
name
=
'resnet_conv1_1'
)
return
input
return
net_arch
paddleslim/nas/search_space/search_space_base.py
浏览文件 @
a0d17e44
...
@@ -39,6 +39,6 @@ class SearchSpaceBase(object):
...
@@ -39,6 +39,6 @@ class SearchSpaceBase(object):
Args:
Args:
tokens(list<int>): The tokens which represent a network.
tokens(list<int>): The tokens which represent a network.
Return:
Return:
list<layers>
model arch
"""
"""
raise
NotImplementedError
(
'Abstract method.'
)
raise
NotImplementedError
(
'Abstract method.'
)
paddleslim/nas/search_space/search_space_factory.py
浏览文件 @
a0d17e44
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
search_space_registry
import
SEARCHSPACE
from
.combine_search_space
import
CombineSearchSpace
__all__
=
[
"SearchSpaceFactory"
]
__all__
=
[
"SearchSpaceFactory"
]
...
@@ -21,18 +21,11 @@ class SearchSpaceFactory(object):
...
@@ -21,18 +21,11 @@ class SearchSpaceFactory(object):
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
get_search_space
(
self
,
key
,
config
):
def
get_search_space
(
self
,
config_lists
):
"""
"""
get
specific model space based on key and config.
get
model spaces based on list(key, config).
Args:
key(str): model space name.
config(dict): basic config information.
return:
model space(class)
"""
"""
cls
=
SEARCHSPACE
.
get
(
key
)
assert
isinstance
(
config_lists
,
list
),
"configs must be a list"
space
=
cls
(
config
[
'input_size'
],
config
[
'output_size'
],
config
[
'block_num'
])
return
space
return
CombineSearchSpace
(
config_lists
)
tests/test_searchspace.py
浏览文件 @
a0d17e44
...
@@ -24,7 +24,7 @@ class TestSearchSpaceFactory(unittest.TestCase):
...
@@ -24,7 +24,7 @@ class TestSearchSpaceFactory(unittest.TestCase):
config
=
{
'input_size'
:
224
,
'output_size'
:
7
,
'block_num'
:
5
}
config
=
{
'input_size'
:
224
,
'output_size'
:
7
,
'block_num'
:
5
}
space
=
SearchSpaceFactory
()
space
=
SearchSpaceFactory
()
my_space
=
space
.
get_search_space
(
'MobileNetV2Space'
,
config
)
my_space
=
space
.
get_search_space
(
[(
'MobileNetV2Space'
,
config
)]
)
model_arch
=
my_space
.
token2arch
()
model_arch
=
my_space
.
token2arch
()
train_prog
=
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
...
@@ -36,10 +36,26 @@ class TestSearchSpaceFactory(unittest.TestCase):
...
@@ -36,10 +36,26 @@ class TestSearchSpaceFactory(unittest.TestCase):
shape
=
[
1
,
3
,
input_size
,
input_size
],
shape
=
[
1
,
3
,
input_size
,
input_size
],
dtype
=
'float32'
,
dtype
=
'float32'
,
append_batch_size
=
False
)
append_batch_size
=
False
)
print
(
'input shape'
,
model_input
.
shape
)
predict
=
model_arch
[
0
](
model_input
)
predict
=
model_arch
(
model_input
)
self
.
assertTrue
(
predict
.
shape
[
2
]
==
config
[
'output_size'
])
print
(
'output shape'
,
predict
.
shape
)
class
TestMultiSearchSpace
(
unittest
.
TestCase
):
space
=
SearchSpaceFactory
()
config0
=
{
'input_size'
:
224
,
'output_size'
:
7
,
'block_num'
:
5
}
config1
=
{
'input_size'
:
7
,
'output_size'
:
1
,
'block_num'
:
2
}
my_space
=
space
.
get_search_space
([(
'MobileNetV2Space'
,
config0
),
(
'ResNetSpace'
,
config1
)])
model_archs
=
my_space
.
token2arch
()
train_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
train_prog
,
startup_prog
):
input_size
=
config0
[
'input_size'
]
model_input
=
fluid
.
layers
.
data
(
name
=
'model_in'
,
shape
=
[
1
,
3
,
input_size
,
input_size
],
dtype
=
'float32'
,
append_batch_size
=
False
)
for
model_arch
in
model_archs
:
predict
=
model_arch
(
model_input
)
model_input
=
predict
print
(
predict
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录