Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
02144bca
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
02144bca
编写于
2月 04, 2020
作者:
W
whs
提交者:
GitHub
2月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add one-shot NAS API and mnasnet based search space. (#17)
上级
1664a758
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
871 addition
and
4 deletion
+871
-4
demo/one_shot/train.py
demo/one_shot/train.py
+207
-0
docs/docs/api/one_shot_api.md
docs/docs/api/one_shot_api.md
+154
-0
docs/docs/tutorials/one_shot_nas_demo.md
docs/docs/tutorials/one_shot_nas_demo.md
+102
-0
docs/mkdocs.yml
docs/mkdocs.yml
+3
-0
paddleslim/nas/__init__.py
paddleslim/nas/__init__.py
+7
-3
paddleslim/nas/one_shot/__init__.py
paddleslim/nas/one_shot/__init__.py
+22
-0
paddleslim/nas/one_shot/one_shot_nas.py
paddleslim/nas/one_shot/one_shot_nas.py
+114
-0
paddleslim/nas/one_shot/super_mnasnet.py
paddleslim/nas/one_shot/super_mnasnet.py
+257
-0
paddleslim/nas/search_space/__init__.py
paddleslim/nas/search_space/__init__.py
+0
-1
paddleslim/nas/search_space/search_space_base.py
paddleslim/nas/search_space/search_space_base.py
+5
-0
未找到文件。
demo/one_shot/train.py
0 → 100644
浏览文件 @
02144bca
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
argparse
import
ast
import
numpy
as
np
from
PIL
import
Image
import
os
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.optimizer
import
AdamOptimizer
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
paddleslim.nas.one_shot
import
SuperMnasnet
from
paddleslim.nas.one_shot
import
OneShotSearch
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"Training for Mnist."
)
parser
.
add_argument
(
"--use_data_parallel"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"The flag indicating whether to use data parallel mode to train the model."
)
parser
.
add_argument
(
"-e"
,
"--epoch"
,
default
=
5
,
type
=
int
,
help
=
"set epoch"
)
parser
.
add_argument
(
"--ce"
,
action
=
"store_true"
,
help
=
"run ce"
)
args
=
parser
.
parse_args
()
return
args
class
SimpleImgConv
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
filter_size
,
conv_stride
=
1
,
conv_padding
=
0
,
conv_dilation
=
1
,
conv_groups
=
1
,
act
=
None
,
use_cudnn
=
False
,
param_attr
=
None
,
bias_attr
=
None
):
super
(
SimpleImgConv
,
self
).
__init__
()
self
.
_conv2d
=
Conv2D
(
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
conv_stride
,
padding
=
conv_padding
,
dilation
=
conv_dilation
,
groups
=
conv_groups
,
param_attr
=
None
,
bias_attr
=
None
,
act
=
act
,
use_cudnn
=
use_cudnn
)
def
forward
(
self
,
inputs
):
x
=
self
.
_conv2d
(
inputs
)
return
x
class
MNIST
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
):
super
(
MNIST
,
self
).
__init__
()
self
.
_simple_img_conv_pool_1
=
SimpleImgConv
(
1
,
20
,
2
,
act
=
"relu"
)
self
.
arch
=
SuperMnasnet
(
name_scope
=
"super_net"
,
input_channels
=
20
,
out_channels
=
20
)
self
.
_simple_img_conv_pool_2
=
SimpleImgConv
(
20
,
50
,
2
,
act
=
"relu"
)
self
.
pool_2_shape
=
50
*
13
*
13
SIZE
=
10
scale
=
(
2.0
/
(
self
.
pool_2_shape
**
2
*
SIZE
))
**
0.5
self
.
_fc
=
Linear
(
self
.
pool_2_shape
,
10
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
scale
)),
act
=
"softmax"
)
def
forward
(
self
,
inputs
,
label
=
None
,
tokens
=
None
):
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
arch
(
x
,
tokens
=
tokens
)
# addddddd
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
fluid
.
layers
.
reshape
(
x
,
shape
=
[
-
1
,
self
.
pool_2_shape
])
x
=
self
.
_fc
(
x
)
if
label
is
not
None
:
acc
=
fluid
.
layers
.
accuracy
(
input
=
x
,
label
=
label
)
return
x
,
acc
else
:
return
x
def
test_mnist
(
model
,
tokens
=
None
):
acc_set
=
[]
avg_loss_set
=
[]
batch_size
=
64
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
,
drop_last
=
True
)
for
batch_id
,
data
in
enumerate
(
test_reader
()):
dy_x_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
batch_size
,
1
)
img
=
to_variable
(
dy_x_data
)
label
=
to_variable
(
y_data
)
label
.
stop_gradient
=
True
prediction
,
acc
=
model
.
forward
(
img
,
label
,
tokens
=
tokens
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
acc_set
.
append
(
float
(
acc
.
numpy
()))
avg_loss_set
.
append
(
float
(
avg_loss
.
numpy
()))
if
batch_id
%
100
==
0
:
print
(
"Test - batch_id: {}"
.
format
(
batch_id
))
# get test acc and loss
acc_val_mean
=
np
.
array
(
acc_set
).
mean
()
avg_loss_val_mean
=
np
.
array
(
avg_loss_set
).
mean
()
return
acc_val_mean
def
train_mnist
(
args
,
model
,
tokens
=
None
):
epoch_num
=
args
.
epoch
BATCH_SIZE
=
64
adam
=
AdamOptimizer
(
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
())
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
if
args
.
use_data_parallel
:
train_reader
=
fluid
.
contrib
.
reader
.
distributed_batch_reader
(
train_reader
)
for
epoch
in
range
(
epoch_num
):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
dy_x_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
-
1
,
1
)
img
=
to_variable
(
dy_x_data
)
label
=
to_variable
(
y_data
)
label
.
stop_gradient
=
True
cost
,
acc
=
model
.
forward
(
img
,
label
,
tokens
=
tokens
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
if
args
.
use_data_parallel
:
avg_loss
=
model
.
scale_loss
(
avg_loss
)
avg_loss
.
backward
()
model
.
apply_collective_grads
()
else
:
avg_loss
.
backward
()
adam
.
minimize
(
avg_loss
)
# save checkpoint
model
.
clear_gradients
()
if
batch_id
%
1
==
0
:
print
(
"Loss at epoch {} step {}: {:}"
.
format
(
epoch
,
batch_id
,
avg_loss
.
numpy
()))
model
.
eval
()
test_acc
=
test_mnist
(
model
,
tokens
=
tokens
)
model
.
train
()
print
(
"Loss at epoch {} , acc is: {}"
.
format
(
epoch
,
test_acc
))
save_parameters
=
(
not
args
.
use_data_parallel
)
or
(
args
.
use_data_parallel
and
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
)
if
save_parameters
:
fluid
.
save_dygraph
(
model
.
state_dict
(),
"save_temp"
)
print
(
"checkpoint saved"
)
if
__name__
==
'__main__'
:
args
=
parse_args
()
place
=
fluid
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
place
):
model
=
MNIST
()
# step 1: training super net
#train_mnist(args, model)
# step 2: search
best_tokens
=
OneShotSearch
(
model
,
test_mnist
)
# step 3: final training
# train_mnist(args, model, best_tokens)
docs/docs/api/one_shot_api.md
0 → 100644
浏览文件 @
02144bca
## OneShotSearch
paddleslim.nas.one_shot.OneShotSearch(model, eval_func, strategy='sa', search_steps=100)
[
代码
](
)
: 从超级网络中搜索出一个最佳的子网络。
**参数:**
-
**model(fluid.dygraph.layer):**
通过在
`OneShotSuperNet`
前后添加若该模块构建的动态图模块。因为
`OneShotSuperNet`
是一个超网络,所以
`model`
也是一个超网络。换句话说,在
`model`
模块的子模块中,至少有一个是
`OneShotSuperNet`
的实例。该方法从
`model`
超网络中搜索得到一个最佳的子网络。超网络
`model`
需要先被训练,具体细节请参考
[
OneShotSuperNet
](
)。
-
**eval_func:**
用于评估子网络性能的回调函数。该回调函数需要接受
`model`
为参数,并调用
`model`
的
`forward`
方法进行性能评估。
-
**strategy(str):**
搜索策略的名称。默认为'sa', 当前仅支持'sa'.
-
**search_steps(int):**
搜索轮次数。默认为100。
**返回:**
-
**best_tokens:**
表示最佳子网络的编码信息(tokens)。
**示例代码:**
请参考
[
one-shot NAS示例
](
)
## OneShotSuperNet
用于
`OneShot`
搜索策略的超级网络的基类,所有超级网络的实现要继承该类。
paddleslim.nas.one_shot.OneShotSuperNet(name_scope)
: 构造方法。
**参数:**
-
**name_scope:(str) **
超级网络的命名空间。
**返回:**
-
**super_net:**
一个
`OneShotSuperNet`
实例。
init_tokens()
: 获得当前超级网络的初始化子网络的编码,主要用于搜索。
**返回:**
-
**tokens(list<int>):**
一个子网络的编码。
range_table()
: 超级网络中各个子网络由一组整型数字编码表示,该方法返回编码每个位置的取值范围。
**返回:**
-
**range_table(tuple):**
子网络编码每一位的取值范围。
`range_table`
格式为
`(min_values, max_values)`
,其中,
`min_values`
为一个整型数组,表示每个编码位置可选取的最小值;
`max_values`
表示每个编码位置可选取的最大值。
_forward_
impl(input, tokens)
: 前向计算函数。
`OneShotSuperNet`
的子类需要实现该函数。
**参数:**
-
**input(Variable):**
超级网络的输入。
-
**tokens(list<int>):**
执行前向计算所用的子网络的编码。默认为
`None`
,即随机选取一个子网络执行前向。
**返回:**
-
**output(Variable):**
前向计算的输出
forward(self, input, tokens=None)
: 执行前向计算。
**参数:**
-
**input(Variable):**
超级网络的输入。
-
**tokens(list<int>):**
执行前向计算所用的子网络的编码。默认为
`None`
,即随机选取一个子网络执行前向。
**返回:**
-
**output(Variable):**
前向计算的输出
_random_
tokens()
: 随机选取一个子网络,并返回其编码。
**返回:**
-
**tokens(list<int>):**
一个子网络的编码。
## SuperMnasnet
在
[
Mnasnet
](
https://arxiv.org/abs/1807.11626
)
基础上修改得到的超级网络, 该类继承自
`OneShotSuperNet`
.
paddleslim.nas.one_shot.SuperMnasnet(name_scope, input_channels=3, out_channels=1280, repeat_times=[6, 6, 6, 6, 6, 6], stride=[1, 1, 1, 1, 2, 1], channels=[16, 24, 40, 80, 96, 192, 320], use_auxhead=False)
: 构造函数。
**参数:**
-
**name_scope(str):**
命名空间。
-
**input_channels(str):**
当前超级网络的输入的特征图的通道数量。
-
**out_channels(str):**
当前超级网络的输出的特征图的通道数量。
-
**repeat_times(list):**
每种
`block`
重复的次数。
-
**stride(list):**
一种
`block`
重复堆叠成
`repeat_block`
,
`stride`
表示每个
`repeat_block`
的下采样比例。
-
**channels(list):**
channels[i]和channels[i+1]分别表示第i个
`repeat_block`
的输入特征图的通道数和输出特征图的通道数。
-
**use_auxhead(bool):**
是否使用辅助特征图。如果设置为
`True`
,则
`SuperMnasnet`
除了返回输出特征图,还还返回辅助特征图。默认为False.
**返回:**
-
**instance(SuperMnasnet):**
一个
`SuperMnasnet`
实例
**示例:**
```
import paddle
import paddle.fluid as fluid
class MNIST(fluid.dygraph.Layer):
def __init__(self):
super(MNIST, self).__init__()
self.arch = SuperMnasnet(
name_scope="super_net", input_channels=20, out_channels=20)
self.pool_2_shape = 50 * 13 * 13
SIZE = 10
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(
self.pool_2_shape,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs, label=None, tokens=None):
x = self.arch(inputs, tokens=tokens)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
return x, acc
else:
return x
```
docs/docs/tutorials/one_shot_nas_demo.md
0 → 100644
浏览文件 @
02144bca
# One Shot NAS 示例
>该示例依赖Paddle1.7.0或Paddle develop版本。
该示例使用MNIST数据,介绍了如何使用PaddleSlim的OneShotNAS接口搜索出一个分类网络。OneShotNAS仅支持动态图,所以该示例完全使用Paddle动态图模式。
## 关键代码介绍
One-shot网络结构搜索策略包含以下步骤:
1.
定义超网络
2.
训练超网络
3.
基于超网络搜索子网络
4.
训练最佳子网络
以下按序介绍各个步骤的关键代码。
### 定义超级网络
按照动态图教程,定义一个分类网络模块,该模块包含4个子模块:
`_simple_img_conv_pool_1`
,
`_simple_img_conv_pool_2`
,
`super_net`
和
`fc`
,其中
`super_net`
为
`SuperMnasnet`
的一个实例。
在前向计算过程中,输入图像先后经过子模块
`_simple_img_conv_pool_1`
、
`super_net`
、
`_simple_img_conv_pool_2`
和
`fc`
的前向计算。
代码如下所示:
```
class MNIST(fluid.dygraph.Layer):
def __init__(self):
super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConv(1, 20, 2, act="relu")
self.arch = SuperMnasnet(
name_scope="super_net", input_channels=20, out_channels=20)
self._simple_img_conv_pool_2 = SimpleImgConv(20, 50, 2, act="relu")
self.pool_2_shape = 50 * 13 * 13
SIZE = 10
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(
self.pool_2_shape,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs, label=None, tokens=None):
x = self._simple_img_conv_pool_1(inputs)
x = self.arch(x, tokens=tokens) # addddddd
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
return x, acc
else:
return x
```
动态图模块MNIST的forward函数接受一个参数
`tokens`
,用于指定在前向计算中使用的子网络,如果
`tokens`
为None,则随机选取一个子网络进行前向计算。
### 训练超级网络
网络训练的逻辑定义在
`train_mnist`
函数中,将
`tokens`
参数设置为None,进行超网络训练,即在每个batch选取一个超网络进行训练。
代码如下所示:
```
with
fluid
.
dygraph
.
guard
(
place
):
model
=
MNIST
()
train_mnist
(
args
,
model
)
```
### 搜索最佳子网络
使用PaddleSlim提供的
`OneShotSearch`
接口搜索最佳子网络。传入已定义且训练好的超网络实例
`model`
和一个用于评估子网络的回调函数
`test_mnist`
.
代码如下:
```
best_tokens = OneShotSearch(model, test_mnist)
```
### 训练最佳子网络
获得最佳的子网络的编码
`best_tokens`
后,调用之前定义的
`train_mnist`
方法进行子网络的训练。代码如下:
```
train_mnist(args, model, best_tokens)
```
## 启动示例
执行以下代码运行示例:
```
python train.py
```
执行
`python train.py --help`
查看更多可配置选项。
## FAQ
docs/mkdocs.yml
浏览文件 @
02144bca
...
@@ -9,6 +9,7 @@ nav:
...
@@ -9,6 +9,7 @@ nav:
-
量化训练
:
tutorials/quant_aware_demo.md
-
量化训练
:
tutorials/quant_aware_demo.md
-
Embedding量化
:
tutorials/quant_embedding_demo.md
-
Embedding量化
:
tutorials/quant_embedding_demo.md
-
SA搜索
:
tutorials/nas_demo.md
-
SA搜索
:
tutorials/nas_demo.md
-
One-shot搜索
:
tutorials/one_shot_nas_demo.md
-
搜索空间
:
search_space.md
-
搜索空间
:
search_space.md
-
知识蒸馏
:
tutorials/distillation_demo.md
-
知识蒸馏
:
tutorials/distillation_demo.md
-
API
:
-
API
:
...
@@ -17,6 +18,8 @@ nav:
...
@@ -17,6 +18,8 @@ nav:
-
模型分析
:
api/analysis_api.md
-
模型分析
:
api/analysis_api.md
-
知识蒸馏
:
api/single_distiller_api.md
-
知识蒸馏
:
api/single_distiller_api.md
-
SA搜索
:
api/nas_api.md
-
SA搜索
:
api/nas_api.md
-
One-shot搜索
:
api/one_shot_api.md
-
搜索空间
:
search_space.md
-
硬件延时评估表
:
table_latency.md
-
硬件延时评估表
:
table_latency.md
-
算法原理
:
algo/algo.md
-
算法原理
:
algo/algo.md
...
...
paddleslim/nas/__init__.py
浏览文件 @
02144bca
...
@@ -11,8 +11,12 @@
...
@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
absolute_import
from
..nas
import
search_space
from
.search_space
import
*
from
.search_space
import
*
from
.sa_nas
import
SANAS
from
..nas
import
sa_nas
from
.sa_nas
import
*
__all__
=
[
'SANAS'
]
__all__
=
[]
__all__
+=
sa_nas
.
__all__
__all__
+=
search_space
.
__all__
paddleslim/nas/one_shot/__init__.py
0 → 100644
浏览文件 @
02144bca
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
..one_shot
import
one_shot_nas
from
.one_shot_nas
import
*
from
..one_shot
import
super_mnasnet
from
.super_mnasnet
import
*
__all__
=
[]
__all__
+=
one_shot_nas
.
__all__
__all__
+=
super_mnasnet
.
__all__
paddleslim/nas/one_shot/one_shot_nas.py
0 → 100644
浏览文件 @
02144bca
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle.fluid
as
fluid
from
...common
import
SAController
__all__
=
[
'OneShotSuperNet'
,
'OneShotSearch'
]
def
OneShotSearch
(
model
,
eval_func
,
strategy
=
'sa'
,
search_steps
=
100
):
"""
Search a best tokens which represents a sub-network.
Archs:
model(fluid.dygraph.Layer): A dynamic graph module whose sub-modules should contain
one instance of `OneShotSuperNet` at least.
eval_func(function): A callback function which accept model and tokens as arguments.
strategy(str): The name of strategy used to search. Default: 'sa'.
search_steps(int): The total steps for searching.
Returns:
tokens(list): The best tokens searched.
"""
super_net
=
None
for
layer
in
model
.
sublayers
(
include_sublayers
=
False
):
print
(
"layer: {}"
.
format
(
layer
))
if
isinstance
(
layer
,
OneShotSuperNet
):
super_net
=
layer
break
assert
super_net
is
not
None
controller
=
None
if
strategy
==
"sa"
:
contoller
=
SAController
(
range_table
=
super_net
.
range_table
(),
init_tokens
=
super_net
.
init_tokens
())
assert
(
controller
is
not
None
,
"Unsupported searching strategy."
)
for
i
in
range
(
search_steps
):
tokens
=
contoller
.
next_tokens
()
reward
=
eval_func
(
model
,
tokens
)
contoller
.
update
(
tokens
,
reward
,
i
)
return
contoller
.
best_tokens
()
class
OneShotSuperNet
(
fluid
.
dygraph
.
Layer
):
"""
The base class of super net used in one-shot searching strategy.
A super net is a dygraph layer.
Args:
name_scope(str): The name scope of super net.
"""
def
__init__
(
self
,
name_scope
):
super
(
OneShotSuperNet
,
self
).
__init__
(
name_scope
)
def
init_tokens
(
self
):
"""Get init tokens in search space.
Return:
tokens(list): The init tokens which is a list of integer.
"""
raise
NotImplementedError
(
'Abstract method.'
)
def
range_table
(
self
):
"""Get range table of current search space.
Return:
range_table(tuple): The maximum value and minimum value in each position of tokens
with format `(min_values, max_values)`. The `min_values` is
a list of integers indicating the minimum values while `max_values`
indicating the maximum values.
"""
raise
NotImplementedError
(
'Abstract method.'
)
def
_forward_impl
(
self
,
*
inputs
,
**
kwargs
):
"""
Defines the computation performed at every call.
Should be overridden by all subclasses.
Args:
inputs(tuple): unpacked tuple arguments
kwargs(dict): unpacked dict arguments
"""
raise
NotImplementedError
(
'Abstract method.'
)
def
forward
(
self
,
input
,
tokens
=
None
):
"""
Defines the computation performed at every call.
Args:
input(variable): The input of super net.
tokens(list): The tokens used to generate a sub-network.
None means computing in super net training mode.
Otherwise, it will execute the sub-network generated by tokens.
The `tokens` should be set in searching stage and final training stage.
Default: None.
Returns:
output(varaible): The output of super net.
"""
if
tokens
==
None
:
tokens
=
self
.
_random_tokens
()
return
self
.
_forward_impl
(
input
,
tokens
=
tokens
)
def
_random_tokens
(
self
):
tokens
=
[]
for
min_v
,
max_v
in
zip
(
self
.
range_table
()[
0
],
self
.
range_table
()[
1
]):
tokens
.
append
(
np
.
random
.
randint
(
min_v
,
max_v
))
return
tokens
paddleslim/nas/one_shot/super_mnasnet.py
0 → 100644
浏览文件 @
02144bca
import
paddle
from
paddle
import
fluid
from
paddle.fluid.layer_helper
import
LayerHelper
import
numpy
as
np
from
one_shot_nas
import
OneShotSuperNet
__all__
=
[
'SuperMnasnet'
]
class
DConvBlock
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
in_channels
,
channels
,
expansion
,
stride
,
kernel_size
=
3
,
padding
=
1
):
super
(
DConvBlock
,
self
).
__init__
(
name_scope
)
self
.
expansion
=
expansion
self
.
in_channels
=
in_channels
self
.
channels
=
channels
self
.
stride
=
stride
self
.
flops
=
0
self
.
flops_calculated
=
False
self
.
expand
=
fluid
.
dygraph
.
Conv2D
(
in_channels
,
num_filters
=
in_channels
*
expansion
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
bias_attr
=
False
)
self
.
expand_bn
=
fluid
.
dygraph
.
BatchNorm
(
num_channels
=
in_channels
*
expansion
,
act
=
'relu6'
)
self
.
dconv
=
fluid
.
dygraph
.
Conv2D
(
in_channels
*
expansion
,
num_filters
=
in_channels
*
expansion
,
filter_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
act
=
None
,
bias_attr
=
False
,
groups
=
in_channels
*
expansion
,
use_cudnn
=
False
)
self
.
dconv_bn
=
fluid
.
dygraph
.
BatchNorm
(
num_channels
=
in_channels
*
expansion
,
act
=
'relu6'
)
self
.
project
=
fluid
.
dygraph
.
Conv2D
(
in_channels
*
expansion
,
num_filters
=
channels
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
bias_attr
=
False
)
self
.
project_bn
=
fluid
.
dygraph
.
BatchNorm
(
num_channels
=
channels
,
act
=
None
)
self
.
shortcut
=
fluid
.
dygraph
.
Conv2D
(
in_channels
,
num_filters
=
channels
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
bias_attr
=
False
)
self
.
shortcut_bn
=
fluid
.
dygraph
.
BatchNorm
(
num_channels
=
channels
,
act
=
None
)
def
get_flops
(
self
,
input
,
output
,
op
):
if
not
self
.
flops_calculated
:
flops
=
input
.
shape
[
1
]
*
output
.
shape
[
1
]
*
(
op
.
_filter_size
**
2
)
*
output
.
shape
[
2
]
*
output
.
shape
[
3
]
if
op
.
_groups
:
flops
/=
op
.
_groups
self
.
flops
+=
flops
def
forward
(
self
,
inputs
):
expand_x
=
self
.
expand_bn
(
self
.
expand
(
inputs
))
self
.
get_flops
(
inputs
,
expand_x
,
self
.
expand
)
dconv_x
=
self
.
dconv_bn
(
self
.
dconv
(
expand_x
))
self
.
get_flops
(
expand_x
,
dconv_x
,
self
.
dconv
)
proj_x
=
self
.
project_bn
(
self
.
project
(
dconv_x
))
self
.
get_flops
(
dconv_x
,
proj_x
,
self
.
project
)
if
self
.
in_channels
!=
self
.
channels
and
self
.
stride
==
1
:
shortcut
=
self
.
shortcut_bn
(
self
.
shortcut
(
inputs
))
self
.
get_flops
(
inputs
,
shortcut
,
self
.
shortcut
)
elif
self
.
stride
==
1
:
shortcut
=
inputs
self
.
flops_calculated
=
True
if
self
.
stride
==
1
:
out
=
fluid
.
layers
.
elementwise_add
(
x
=
proj_x
,
y
=
shortcut
)
return
out
return
proj_x
class
SearchBlock
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
in_channels
,
channels
,
stride
,
kernel_size
=
3
,
padding
=
1
):
super
(
SearchBlock
,
self
).
__init__
(
name_scope
)
self
.
_stride
=
stride
self
.
block_list
=
[]
self
.
flops
=
[
0
for
i
in
range
(
10
)]
self
.
flops_calculated
=
[
False
if
i
<
6
else
True
for
i
in
range
(
10
)]
kernels
=
[
3
,
5
,
7
]
expansions
=
[
3
,
6
]
for
k
in
kernels
:
for
e
in
expansions
:
self
.
block_list
.
append
(
DConvBlock
(
self
.
full_name
(),
in_channels
,
channels
,
e
,
stride
,
k
,
(
k
-
1
)
//
2
))
self
.
add_sublayer
(
"expansion_{}_kernel_{}"
.
format
(
e
,
k
),
self
.
block_list
[
-
1
])
def
forward
(
self
,
inputs
,
arch
):
if
arch
>=
6
:
return
inputs
out
=
self
.
block_list
[
arch
](
inputs
)
if
not
self
.
flops_calculated
[
arch
]:
self
.
flops
[
arch
]
=
self
.
block_list
[
arch
].
flops
self
.
flops_calculated
[
arch
]
=
True
return
out
class
AuxiliaryHead
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
num_classes
):
super
(
AuxiliaryHead
,
self
).
__init__
(
name_scope
)
self
.
pool1
=
fluid
.
dygraph
.
Pool2D
(
5
,
'avg'
,
pool_stride
=
3
,
pool_padding
=
0
)
self
.
conv1
=
fluid
.
dygraph
.
Conv2D
(
128
,
1
,
bias_attr
=
False
)
self
.
bn1
=
fluid
.
dygraph
.
BatchNorm
(
128
,
act
=
'relu6'
)
self
.
conv2
=
fluid
.
dygraph
.
Conv2D
(
768
,
2
,
bias_attr
=
False
)
self
.
bn2
=
fluid
.
dygraph
.
BatchNorm
(
768
,
act
=
'relu6'
)
self
.
classifier
=
fluid
.
dygraph
.
FC
(
num_classes
,
act
=
'softmax'
)
self
.
layer_helper
=
LayerHelper
(
self
.
full_name
(),
act
=
'relu6'
)
def
forward
(
self
,
inputs
):
#pylint: disable=arguments-differ
inputs
=
self
.
layer_helper
.
append_activation
(
inputs
)
inputs
=
self
.
pool1
(
inputs
)
inputs
=
self
.
conv1
(
inputs
)
inputs
=
self
.
bn1
(
inputs
)
inputs
=
self
.
conv2
(
inputs
)
inputs
=
self
.
bn2
(
inputs
)
inputs
=
self
.
classifier
(
inputs
)
return
inputs
class
SuperMnasnet
(
OneShotSuperNet
):
def
__init__
(
self
,
name_scope
,
input_channels
=
3
,
out_channels
=
1280
,
repeat_times
=
[
6
,
6
,
6
,
6
,
6
,
6
],
stride
=
[
1
,
1
,
1
,
1
,
2
,
1
],
channels
=
[
16
,
24
,
40
,
80
,
96
,
192
,
320
],
use_auxhead
=
False
):
super
(
SuperMnasnet
,
self
).
__init__
(
name_scope
)
self
.
flops
=
0
self
.
repeat_times
=
repeat_times
self
.
flops_calculated
=
False
self
.
last_tokens
=
None
self
.
_conv
=
fluid
.
dygraph
.
Conv2D
(
input_channels
,
32
,
3
,
1
,
1
,
act
=
None
,
bias_attr
=
False
)
self
.
_bn
=
fluid
.
dygraph
.
BatchNorm
(
32
,
act
=
'relu6'
)
self
.
_sep_conv
=
fluid
.
dygraph
.
Conv2D
(
32
,
32
,
3
,
1
,
1
,
groups
=
32
,
act
=
None
,
use_cudnn
=
False
,
bias_attr
=
False
)
self
.
_sep_conv_bn
=
fluid
.
dygraph
.
BatchNorm
(
32
,
act
=
'relu6'
)
self
.
_sep_project
=
fluid
.
dygraph
.
Conv2D
(
32
,
16
,
1
,
1
,
0
,
act
=
None
,
bias_attr
=
False
)
self
.
_sep_project_bn
=
fluid
.
dygraph
.
BatchNorm
(
16
,
act
=
'relu6'
)
self
.
_final_conv
=
fluid
.
dygraph
.
Conv2D
(
320
,
out_channels
,
1
,
1
,
0
,
act
=
None
,
bias_attr
=
False
)
self
.
_final_bn
=
fluid
.
dygraph
.
BatchNorm
(
out_channels
,
act
=
'relu6'
)
self
.
stride
=
stride
self
.
block_list
=
[]
self
.
use_auxhead
=
use_auxhead
for
_iter
,
_stride
in
enumerate
(
self
.
stride
):
repeat_block
=
[]
for
_ind
in
range
(
self
.
repeat_times
[
_iter
]):
if
_ind
==
0
:
block
=
SearchBlock
(
self
.
full_name
(),
channels
[
_iter
],
channels
[
_iter
+
1
],
_stride
)
else
:
block
=
SearchBlock
(
self
.
full_name
(),
channels
[
_iter
+
1
],
channels
[
_iter
+
1
],
1
)
self
.
add_sublayer
(
"block_{}_{}"
.
format
(
_iter
,
_ind
),
block
)
repeat_block
.
append
(
block
)
self
.
block_list
.
append
(
repeat_block
)
if
self
.
use_auxhead
:
self
.
auxhead
=
AuxiliaryHead
(
self
.
full_name
(),
10
)
def
init_tokens
(
self
):
return
[
3
,
3
,
6
,
6
,
6
,
6
,
3
,
3
,
3
,
6
,
6
,
6
,
3
,
3
,
3
,
3
,
6
,
6
,
3
,
3
,
3
,
6
,
6
,
6
,
3
,
3
,
3
,
6
,
6
,
6
,
3
,
6
,
6
,
6
,
6
,
6
]
def
range_table
(
self
):
max_v
=
[
6
,
6
,
10
,
10
,
10
,
10
,
6
,
6
,
6
,
10
,
10
,
10
,
6
,
6
,
6
,
6
,
10
,
10
,
6
,
6
,
6
,
10
,
10
,
10
,
6
,
6
,
6
,
10
,
10
,
10
,
6
,
10
,
10
,
10
,
10
,
10
]
return
(
len
(
max_v
)
*
[
0
],
max_v
)
def
get_flops
(
self
,
input
,
output
,
op
):
if
not
self
.
flops_calculated
:
flops
=
input
.
shape
[
1
]
*
output
.
shape
[
1
]
*
(
op
.
_filter_size
**
2
)
*
output
.
shape
[
2
]
*
output
.
shape
[
3
]
if
op
.
_groups
:
flops
/=
op
.
_groups
self
.
flops
+=
flops
def
_forward_impl
(
self
,
inputs
,
tokens
=
None
):
if
isinstance
(
tokens
,
np
.
ndarray
)
and
not
(
tokens
==
self
.
last_tokens
).
all
()
\
or
not
isinstance
(
tokens
,
np
.
ndarray
)
and
not
tokens
==
self
.
last_tokens
:
self
.
flops_calculated
=
False
self
.
flops
=
0
self
.
last_tokens
=
tokens
x
=
self
.
_bn
(
self
.
_conv
(
inputs
))
self
.
get_flops
(
inputs
,
x
,
self
.
_conv
)
sep_x
=
self
.
_sep_conv_bn
(
self
.
_sep_conv
(
x
))
self
.
get_flops
(
x
,
sep_x
,
self
.
_sep_conv
)
proj_x
=
self
.
_sep_project_bn
(
self
.
_sep_project
(
sep_x
))
self
.
get_flops
(
sep_x
,
proj_x
,
self
.
_sep_project
)
x
=
proj_x
for
ind
in
range
(
len
(
self
.
block_list
)):
for
b_ind
,
block
in
enumerate
(
self
.
block_list
[
ind
]):
x
=
fluid
.
layers
.
dropout
(
block
(
x
,
tokens
[
ind
*
6
+
b_ind
]),
0.
)
if
not
self
.
flops_calculated
:
self
.
flops
+=
block
.
flops
[
tokens
[
ind
*
6
+
b_ind
]]
if
ind
==
len
(
self
.
block_list
)
*
2
//
3
-
1
and
self
.
use_auxhead
:
fc_aux
=
self
.
auxhead
(
x
)
final_x
=
self
.
_final_bn
(
self
.
_final_conv
(
x
))
self
.
get_flops
(
x
,
final_x
,
self
.
_final_conv
)
# x = self.global_pooling(final_x)
self
.
flops_calculated
=
True
if
self
.
use_auxhead
:
return
final_x
,
fc_aux
return
final_x
paddleslim/nas/search_space/__init__.py
浏览文件 @
02144bca
...
@@ -21,7 +21,6 @@ from .inception_block import InceptionABlockSpace, InceptionCBlockSpace
...
@@ -21,7 +21,6 @@ from .inception_block import InceptionABlockSpace, InceptionCBlockSpace
from
.search_space_registry
import
SEARCHSPACE
from
.search_space_registry
import
SEARCHSPACE
from
.search_space_factory
import
SearchSpaceFactory
from
.search_space_factory
import
SearchSpaceFactory
from
.search_space_base
import
SearchSpaceBase
from
.search_space_base
import
SearchSpaceBase
__all__
=
[
__all__
=
[
'MobileNetV1Space'
,
'MobileNetV2Space'
,
'ResNetSpace'
,
'MobileNetV1Space'
,
'MobileNetV2Space'
,
'ResNetSpace'
,
'MobileNetV1BlockSpace'
,
'MobileNetV2BlockSpace'
,
'ResNetBlockSpace'
,
'MobileNetV1BlockSpace'
,
'MobileNetV2BlockSpace'
,
'ResNetBlockSpace'
,
...
...
paddleslim/nas/search_space/search_space_base.py
浏览文件 @
02144bca
...
@@ -19,6 +19,7 @@ __all__ = ['SearchSpaceBase']
...
@@ -19,6 +19,7 @@ __all__ = ['SearchSpaceBase']
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
class
SearchSpaceBase
(
object
):
class
SearchSpaceBase
(
object
):
"""Controller for Neural Architecture Search.
"""Controller for Neural Architecture Search.
"""
"""
...
@@ -56,3 +57,7 @@ class SearchSpaceBase(object):
...
@@ -56,3 +57,7 @@ class SearchSpaceBase(object):
model arch
model arch
"""
"""
raise
NotImplementedError
(
'Abstract method.'
)
raise
NotImplementedError
(
'Abstract method.'
)
def
super_net
(
self
):
"""This function is just used in one shot NAS strategy. Return a super graph."""
raise
NotImplementedError
(
'Abstract method.'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录