Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
73019e56
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
73019e56
编写于
1月 22, 2020
作者:
C
ceci3
提交者:
GitHub
1月 22, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix for python3.7 (#45)
* update fix * fix init token * fix bug * update * update doc
上级
501ab9d4
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
85 addition
and
111 deletion
+85
-111
demo/nas/README.md
demo/nas/README.md
+14
-58
docs/docs/api/nas_api.md
docs/docs/api/nas_api.md
+22
-22
paddleslim/common/controller_client.py
paddleslim/common/controller_client.py
+10
-0
paddleslim/common/controller_server.py
paddleslim/common/controller_server.py
+8
-0
paddleslim/common/sa_controller.py
paddleslim/common/sa_controller.py
+1
-1
paddleslim/nas/sa_nas.py
paddleslim/nas/sa_nas.py
+1
-4
paddleslim/nas/search_space/base_layer.py
paddleslim/nas/search_space/base_layer.py
+7
-7
paddleslim/nas/search_space/inception_block.py
paddleslim/nas/search_space/inception_block.py
+5
-5
paddleslim/nas/search_space/mobilenet_block.py
paddleslim/nas/search_space/mobilenet_block.py
+9
-7
paddleslim/nas/search_space/mobilenetv1.py
paddleslim/nas/search_space/mobilenetv1.py
+1
-1
paddleslim/nas/search_space/mobilenetv2.py
paddleslim/nas/search_space/mobilenetv2.py
+1
-1
paddleslim/nas/search_space/resnet_block.py
paddleslim/nas/search_space/resnet_block.py
+6
-5
未找到文件。
demo/nas/README.md
浏览文件 @
73019e56
...
@@ -2,69 +2,25 @@
...
@@ -2,69 +2,25 @@
本示例介绍如何使用网络结构搜索接口,搜索到一个更小或者精度更高的模型,该文档仅介绍paddleslim中SANAS的使用及如何利用SANAS得到模型结构,完整示例代码请参考sa_nas_mobilenetv2.py或者block_sa_nas_mobilenetv2.py。
本示例介绍如何使用网络结构搜索接口,搜索到一个更小或者精度更高的模型,该文档仅介绍paddleslim中SANAS的使用及如何利用SANAS得到模型结构,完整示例代码请参考sa_nas_mobilenetv2.py或者block_sa_nas_mobilenetv2.py。
## 接口介绍
## 数据准备
请参考。
本示例默认使用cifar10数据,cifar10数据会根据调用的paddle接口自动下载,无需额外准备。
### 1. 配置搜索空间
详细的搜索空间配置可以参考
<a
href=
'../../../paddleslim/nas/nas_api.md'
>
神经网络搜索API文档
</a>
。
```
config = [('MobileNetV2Space')]
```
### 2. 利用搜索空间初始化SANAS实例
```
from paddleslim.nas import SANAS
sa_nas = SANAS(
config,
server_addr=("", 8881),
init_temperature=10.24,
reduce_rate=0.85,
search_steps=300,
is_server=True)
```
### 3. 根据实例化的NAS得到当前的网络结构
## 接口介绍
```
请参考
<a
href=
'../../docs/docs/api/nas_api.md'
>
神经网络搜索API文档
</a>
。
archs = sa_nas.next_archs()
```
### 4. 根据得到的网络结构和输入构造训练和测试program
```
import paddle.fluid as fluid
train_program = fluid.Program()
test_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
本示例为在MobileNetV2的搜索空间上搜索FLOPs更小的模型。
data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32')
## 1 搜索空间配置
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
默认搜索空间为
`MobileNetV2`
,详细的搜索空间配置请参考
<a
href=
'../../docs/docs/search_space.md'
>
搜索空间配置文档
</a>
。
for arch in archs:
data = arch(data)
output = fluid.layers.fc(data, 10)
softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
test_program = train_program.clone(for_test=True)
## 2 启动训练
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(avg_cost)
### 2.1 启动基于MobileNetV2初始模型结构构造搜索空间的实验
```
shell
CUDA_VISIBLE_DEVICES
=
0 python sa_nas_mobilenetv2.py
```
```
### 5. 根据构造的训练program添加限制条件
```
from paddleslim.analysis import flops
if flops(train_program) > 321208544:
### 2.2 启动基于MobileNetV2的block构造搜索空间的实验
continue
```
shell
```
CUDA_VISIBLE_DEVICES
=
0 python block_sa_nas_mobilenetv2.py
### 6. 回传score
```
sa_nas.reward(score)
```
```
docs/docs/api/nas_api.md
浏览文件 @
73019e56
...
@@ -32,7 +32,7 @@ paddleslim.nas.SANAS(configs, server_addr=("", 8881), init_temperature=None, red
...
@@ -32,7 +32,7 @@ paddleslim.nas.SANAS(configs, server_addr=("", 8881), init_temperature=None, red
```
python
```
python
from
paddleslim.nas
import
SANAS
from
paddleslim.nas
import
SANAS
config
=
[(
'MobileNetV2Space'
)]
config
=
[(
'MobileNetV2Space'
)]
sanas
=
SANAS
(
config
=
config
)
sanas
=
SANAS
(
config
s
=
config
)
```
```
!!! note "Note"
!!! note "Note"
...
@@ -48,26 +48,6 @@ sanas = SANAS(config=config)
...
@@ -48,26 +48,6 @@ sanas = SANAS(config=config)
-
初始化token如果是随机生成的话,代表初始化token是一个比较差的token,SA算法可以处于一种不稳定的阶段进行搜索,尽可能的随机探索所有可能得token,从而找到一个较好的token。初始温度可以设置的高一些,例如设置为1000,退火率相对设置的小一些。
-
初始化token如果是随机生成的话,代表初始化token是一个比较差的token,SA算法可以处于一种不稳定的阶段进行搜索,尽可能的随机探索所有可能得token,从而找到一个较好的token。初始温度可以设置的高一些,例如设置为1000,退火率相对设置的小一些。
paddlesim.nas.SANAS.tokens2arch(tokens)
: 通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组token对应唯一的一个网络结构。
**参数:**
-
**tokens(list):**
- 一组token。
**返回:**
根据传入的token得到一个模型结构实例。
**示例代码:**
```
python
import
paddle.fluid
as
fluid
input
=
fluid
.
data
(
name
=
'input'
,
shape
=
[
None
,
3
,
32
,
32
],
dtype
=
'float32'
)
archs
=
sanas
.
token2arch
(
tokens
)
for
arch
in
archs
:
output
=
arch
(
input
)
input
=
output
```
paddleslim.nas.SANAS.next_archs()
paddleslim.nas.SANAS.next_archs()
: 获取下一组模型结构。
: 获取下一组模型结构。
...
@@ -84,7 +64,6 @@ for arch in archs:
...
@@ -84,7 +64,6 @@ for arch in archs:
input
=
output
input
=
output
```
```
paddleslim.nas.SANAS.reward(score)
paddleslim.nas.SANAS.reward(score)
: 把当前模型结构的得分情况回传。
: 把当前模型结构的得分情况回传。
...
@@ -95,6 +74,27 @@ paddleslim.nas.SANAS.reward(score)
...
@@ -95,6 +74,27 @@ paddleslim.nas.SANAS.reward(score)
**返回:**
**返回:**
模型结构更新成功或者失败,成功则返回
`True`
,失败则返回
`False`
。
模型结构更新成功或者失败,成功则返回
`True`
,失败则返回
`False`
。
paddlesim.nas.SANAS.tokens2arch(tokens)
: 通过一组token得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组token对应唯一的一个网络结构。
**参数:**
-
**tokens(list):**
- 一组token。
**返回:**
根据传入的token得到一个模型结构实例。
**示例代码:**
```
python
import
paddle.fluid
as
fluid
input
=
fluid
.
data
(
name
=
'input'
,
shape
=
[
None
,
3
,
32
,
32
],
dtype
=
'float32'
)
archs
=
sanas
.
token2arch
(
tokens
)
for
arch
in
archs
:
output
=
arch
(
input
)
input
=
output
```
paddleslim.nas.SANAS.current_info()
paddleslim.nas.SANAS.current_info()
: 返回当前token和搜索过程中最好的token和reward。
: 返回当前token和搜索过程中最好的token和reward。
...
...
paddleslim/common/controller_client.py
浏览文件 @
73019e56
...
@@ -71,3 +71,13 @@ class ControllerClient(object):
...
@@ -71,3 +71,13 @@ class ControllerClient(object):
tokens
=
socket_client
.
recv
(
1024
).
decode
()
tokens
=
socket_client
.
recv
(
1024
).
decode
()
tokens
=
[
int
(
token
)
for
token
in
tokens
.
strip
(
"
\n
"
).
split
(
","
)]
tokens
=
[
int
(
token
)
for
token
in
tokens
.
strip
(
"
\n
"
).
split
(
","
)]
return
tokens
return
tokens
def
request_current_info
(
self
):
"""
Request for current information.
"""
socket_client
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
socket_client
.
connect
((
self
.
server_ip
,
self
.
server_port
))
socket_client
.
send
(
"current_info"
.
encode
())
current_info
=
socket_client
.
recv
(
1024
).
decode
()
return
eval
(
current_info
)
paddleslim/common/controller_server.py
浏览文件 @
73019e56
...
@@ -90,10 +90,18 @@ class ControllerServer(object):
...
@@ -90,10 +90,18 @@ class ControllerServer(object):
(
self
.
_search_steps
)))
and
not
self
.
_closed
:
(
self
.
_search_steps
)))
and
not
self
.
_closed
:
conn
,
addr
=
self
.
_socket_server
.
accept
()
conn
,
addr
=
self
.
_socket_server
.
accept
()
message
=
conn
.
recv
(
1024
).
decode
()
message
=
conn
.
recv
(
1024
).
decode
()
_logger
.
debug
(
message
)
if
message
.
strip
(
"
\n
"
)
==
"next_tokens"
:
if
message
.
strip
(
"
\n
"
)
==
"next_tokens"
:
tokens
=
self
.
_controller
.
next_tokens
()
tokens
=
self
.
_controller
.
next_tokens
()
tokens
=
","
.
join
([
str
(
token
)
for
token
in
tokens
])
tokens
=
","
.
join
([
str
(
token
)
for
token
in
tokens
])
conn
.
send
(
tokens
.
encode
())
conn
.
send
(
tokens
.
encode
())
elif
message
.
strip
(
"
\n
"
)
==
"current_info"
:
current_info
=
dict
()
current_info
[
'best_tokens'
]
=
self
.
_controller
.
best_tokens
current_info
[
'best_reward'
]
=
self
.
_controller
.
max_reward
current_info
[
'current_tokens'
]
=
self
.
_controller
.
current_tokens
conn
.
send
(
str
(
current_info
).
encode
())
else
:
else
:
_logger
.
debug
(
"recv message from {}: [{}]"
.
format
(
addr
,
_logger
.
debug
(
"recv message from {}: [{}]"
.
format
(
addr
,
message
))
message
))
...
...
paddleslim/common/sa_controller.py
浏览文件 @
73019e56
...
@@ -81,7 +81,7 @@ class SAController(EvolutionaryController):
...
@@ -81,7 +81,7 @@ class SAController(EvolutionaryController):
self
.
_iter
=
iters
self
.
_iter
=
iters
self
.
_checkpoints
=
checkpoints
self
.
_checkpoints
=
checkpoints
self
.
_searched
=
searched
if
searched
!=
None
else
dict
()
self
.
_searched
=
searched
if
searched
!=
None
else
dict
()
self
.
_current_token
=
init_tokens
self
.
_current_token
s
=
init_tokens
def
__getstate__
(
self
):
def
__getstate__
(
self
):
d
=
{}
d
=
{}
...
...
paddleslim/nas/sa_nas.py
浏览文件 @
73019e56
...
@@ -162,10 +162,7 @@ class SANAS(object):
...
@@ -162,10 +162,7 @@ class SANAS(object):
Returns:
Returns:
dict<name, value>: a dictionary include best tokens, best reward and current reward.
dict<name, value>: a dictionary include best tokens, best reward and current reward.
"""
"""
current_dict
=
dict
()
current_dict
=
self
.
_controller_client
.
request_current_info
()
current_dict
[
'best_tokens'
]
=
self
.
_controller
.
best_tokens
current_dict
[
'best_reward'
]
=
self
.
_controller
.
max_reward
current_dict
[
'current_tokens'
]
=
self
.
_controller
.
current_tokens
return
current_dict
return
current_dict
def
next_archs
(
self
):
def
next_archs
(
self
):
...
...
paddleslim/nas/search_space/base_layer.py
浏览文件 @
73019e56
...
@@ -19,7 +19,7 @@ from paddle.fluid.param_attr import ParamAttr
...
@@ -19,7 +19,7 @@ from paddle.fluid.param_attr import ParamAttr
def
conv_bn_layer
(
input
,
def
conv_bn_layer
(
input
,
filter_size
,
filter_size
,
num_filters
,
num_filters
,
stride
,
stride
=
1
,
padding
=
'SAME'
,
padding
=
'SAME'
,
num_groups
=
1
,
num_groups
=
1
,
act
=
None
,
act
=
None
,
...
@@ -53,7 +53,7 @@ def conv_bn_layer(input,
...
@@ -53,7 +53,7 @@ def conv_bn_layer(input,
bn_name
=
name
+
'_bn'
bn_name
=
name
+
'_bn'
return
fluid
.
layers
.
batch_norm
(
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
input
=
conv
,
act
=
act
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
'_offset'
),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
'_offset'
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_mean_name
=
bn_name
+
'_mean'
,
...
...
paddleslim/nas/search_space/inception_block.py
浏览文件 @
73019e56
...
@@ -58,7 +58,7 @@ class InceptionABlockSpace(SearchSpaceBase):
...
@@ -58,7 +58,7 @@ class InceptionABlockSpace(SearchSpaceBase):
"""
"""
The initial token.
The initial token.
"""
"""
return
get_random_tokens
(
self
.
range_table
)
return
get_random_tokens
(
self
.
range_table
()
)
def
range_table
(
self
):
def
range_table
(
self
):
"""
"""
...
@@ -175,7 +175,7 @@ class InceptionABlockSpace(SearchSpaceBase):
...
@@ -175,7 +175,7 @@ class InceptionABlockSpace(SearchSpaceBase):
input
=
self
.
_inceptionA
(
input
=
self
.
_inceptionA
(
input
,
input
,
A_tokens
=
filter_nums
,
A_tokens
=
filter_nums
,
filter_size
=
filter_size
,
filter_size
=
int
(
filter_size
)
,
stride
=
stride
,
stride
=
stride
,
pool_type
=
pool_type
,
pool_type
=
pool_type
,
name
=
'inceptionA_{}'
.
format
(
i
+
1
))
name
=
'inceptionA_{}'
.
format
(
i
+
1
))
...
@@ -287,7 +287,7 @@ class InceptionCBlockSpace(SearchSpaceBase):
...
@@ -287,7 +287,7 @@ class InceptionCBlockSpace(SearchSpaceBase):
"""
"""
The initial token.
The initial token.
"""
"""
return
get_random_tokens
(
self
.
range_table
)
return
get_random_tokens
(
self
.
range_table
()
)
def
range_table
(
self
):
def
range_table
(
self
):
"""
"""
...
@@ -408,13 +408,13 @@ class InceptionCBlockSpace(SearchSpaceBase):
...
@@ -408,13 +408,13 @@ class InceptionCBlockSpace(SearchSpaceBase):
pool_type
=
'avg'
if
layer_setting
[
11
]
==
0
else
'max'
pool_type
=
'avg'
if
layer_setting
[
11
]
==
0
else
'max'
if
stride
==
2
:
if
stride
==
2
:
layer_count
+=
1
layer_count
+=
1
if
check_points
((
layer_count
-
1
)
in
return_block
):
if
check_points
((
layer_count
-
1
)
,
return_block
):
mid_layer
[
layer_count
-
1
]
=
input
mid_layer
[
layer_count
-
1
]
=
input
input
=
self
.
_inceptionC
(
input
=
self
.
_inceptionC
(
input
,
input
,
C_tokens
=
filter_nums
,
C_tokens
=
filter_nums
,
filter_size
=
filter_size
,
filter_size
=
int
(
filter_size
)
,
stride
=
stride
,
stride
=
stride
,
pool_type
=
pool_type
,
pool_type
=
pool_type
,
name
=
'inceptionC_{}'
.
format
(
i
+
1
))
name
=
'inceptionC_{}'
.
format
(
i
+
1
))
...
...
paddleslim/nas/search_space/mobilenet_block.py
浏览文件 @
73019e56
...
@@ -60,7 +60,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
...
@@ -60,7 +60,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
self
.
scale
=
scale
self
.
scale
=
scale
def
init_tokens
(
self
):
def
init_tokens
(
self
):
return
get_random_tokens
(
self
.
range_table
)
return
get_random_tokens
(
self
.
range_table
()
)
def
range_table
(
self
):
def
range_table
(
self
):
range_table_base
=
[]
range_table_base
=
[]
...
@@ -153,7 +153,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
...
@@ -153,7 +153,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
c
=
int
(
c
*
self
.
scale
),
c
=
int
(
c
*
self
.
scale
),
n
=
n
,
n
=
n
,
s
=
s
,
s
=
s
,
k
=
k
,
k
=
int
(
k
)
,
name
=
'mobilenetv2_'
+
str
(
i
+
1
))
name
=
'mobilenetv2_'
+
str
(
i
+
1
))
in_c
=
int
(
c
*
self
.
scale
)
in_c
=
int
(
c
*
self
.
scale
)
...
@@ -289,6 +289,8 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
...
@@ -289,6 +289,8 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
scale
=
1.0
):
scale
=
1.0
):
super
(
MobileNetV1BlockSpace
,
self
).
__init__
(
input_size
,
output_size
,
super
(
MobileNetV1BlockSpace
,
self
).
__init__
(
input_size
,
output_size
,
block_num
,
block_mask
)
block_num
,
block_mask
)
if
self
.
block_mask
==
None
:
# use input_size and output_size to compute self.downsample_num
# use input_size and output_size to compute self.downsample_num
self
.
downsample_num
=
compute_downsample_num
(
self
.
input_size
,
self
.
downsample_num
=
compute_downsample_num
(
self
.
input_size
,
self
.
output_size
)
self
.
output_size
)
...
@@ -305,7 +307,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
...
@@ -305,7 +307,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
self
.
scale
=
scale
self
.
scale
=
scale
def
init_tokens
(
self
):
def
init_tokens
(
self
):
return
get_random_tokens
(
self
.
range_table
)
return
get_random_tokens
(
self
.
range_table
()
)
def
range_table
(
self
):
def
range_table
(
self
):
range_table_base
=
[]
range_table_base
=
[]
...
@@ -383,7 +385,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
...
@@ -383,7 +385,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
num_filters2
=
filter_num2
,
num_filters2
=
filter_num2
,
stride
=
stride
,
stride
=
stride
,
scale
=
self
.
scale
,
scale
=
self
.
scale
,
kernel_size
=
kernel_size
,
kernel_size
=
int
(
kernel_size
)
,
name
=
'mobilenetv1_{}'
.
format
(
str
(
i
+
1
)))
name
=
'mobilenetv1_{}'
.
format
(
str
(
i
+
1
)))
if
return_mid_layer
:
if
return_mid_layer
:
...
...
paddleslim/nas/search_space/mobilenetv1.py
浏览文件 @
73019e56
...
@@ -191,7 +191,7 @@ class MobileNetV1Space(SearchSpaceBase):
...
@@ -191,7 +191,7 @@ class MobileNetV1Space(SearchSpaceBase):
num_groups
=
filter_num1
,
num_groups
=
filter_num1
,
stride
=
stride
,
stride
=
stride
,
scale
=
self
.
scale
,
scale
=
self
.
scale
,
kernel_size
=
kernel_size
,
kernel_size
=
int
(
kernel_size
)
,
name
=
'mobilenetv1_{}'
.
format
(
str
(
i
+
1
)))
name
=
'mobilenetv1_{}'
.
format
(
str
(
i
+
1
)))
### return_block and end_points means block num
### return_block and end_points means block num
...
...
paddleslim/nas/search_space/mobilenetv2.py
浏览文件 @
73019e56
...
@@ -182,7 +182,7 @@ class MobileNetV2Space(SearchSpaceBase):
...
@@ -182,7 +182,7 @@ class MobileNetV2Space(SearchSpaceBase):
c
=
int
(
c
*
self
.
scale
),
c
=
int
(
c
*
self
.
scale
),
n
=
n
,
n
=
n
,
s
=
s
,
s
=
s
,
k
=
k
,
k
=
int
(
k
)
,
name
=
'mobilenetv2_conv'
+
str
(
i
))
name
=
'mobilenetv2_conv'
+
str
(
i
))
in_c
=
int
(
c
*
self
.
scale
)
in_c
=
int
(
c
*
self
.
scale
)
...
...
paddleslim/nas/search_space/resnet_block.py
浏览文件 @
73019e56
...
@@ -32,6 +32,7 @@ class ResNetBlockSpace(SearchSpaceBase):
...
@@ -32,6 +32,7 @@ class ResNetBlockSpace(SearchSpaceBase):
def
__init__
(
self
,
input_size
,
output_size
,
block_num
,
block_mask
=
None
):
def
__init__
(
self
,
input_size
,
output_size
,
block_num
,
block_mask
=
None
):
super
(
ResNetBlockSpace
,
self
).
__init__
(
input_size
,
output_size
,
super
(
ResNetBlockSpace
,
self
).
__init__
(
input_size
,
output_size
,
block_num
,
block_mask
)
block_num
,
block_mask
)
if
self
.
block_mask
==
None
:
# use input_size and output_size to compute self.downsample_num
# use input_size and output_size to compute self.downsample_num
self
.
downsample_num
=
compute_downsample_num
(
self
.
input_size
,
self
.
downsample_num
=
compute_downsample_num
(
self
.
input_size
,
self
.
output_size
)
self
.
output_size
)
...
@@ -44,7 +45,7 @@ class ResNetBlockSpace(SearchSpaceBase):
...
@@ -44,7 +45,7 @@ class ResNetBlockSpace(SearchSpaceBase):
self
.
k_size
=
np
.
array
([
3
,
5
])
self
.
k_size
=
np
.
array
([
3
,
5
])
def
init_tokens
(
self
):
def
init_tokens
(
self
):
return
get_random_tokens
(
self
.
range_table
)
return
get_random_tokens
(
self
.
range_table
()
)
def
range_table
(
self
):
def
range_table
(
self
):
range_table_base
=
[]
range_table_base
=
[]
...
@@ -133,7 +134,7 @@ class ResNetBlockSpace(SearchSpaceBase):
...
@@ -133,7 +134,7 @@ class ResNetBlockSpace(SearchSpaceBase):
num_filters1
=
filter_num1
,
num_filters1
=
filter_num1
,
num_filters2
=
filter_num3
,
num_filters2
=
filter_num3
,
num_filters3
=
filter_num3
,
num_filters3
=
filter_num3
,
kernel_size
=
k_size
,
kernel_size
=
int
(
k_size
)
,
repeat1
=
repeat1
,
repeat1
=
repeat1
,
repeat2
=
repeat2
,
repeat2
=
repeat2
,
stride
=
stride
,
stride
=
stride
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录