PaddlePaddle / PaddleSlim

Commit 9f43bbcc (unverified)
Authored on Nov 21, 2020 by ceci3; committed via GitHub on Nov 21, 2020
Parent: c6fdcc3f

fix bug for OFA (#464)

* fix bugs for ernie

Showing 6 changed files with 354 additions and 84 deletions (+354 / -84).
Changed files (6):

    demo/one_shot/ofa_train.py            +22  -16
    paddleslim/nas/__init__.py             +1   -0
    paddleslim/nas/ofa/convert_super.py   +81   -7
    paddleslim/nas/ofa/layers.py         +106  -13
    paddleslim/nas/ofa/ofa.py             +61  -37
    tests/test_ofa.py                     +83  -11
demo/one_shot/ofa_train.py

@@ -14,14 +14,14 @@
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-import paddle.fluid.dygraph.nn as nn
+import paddle.nn as nn
+import paddle.nn.functional as F
 from paddle.nn import ReLU
 from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
 from paddleslim.nas.ofa import supernet


-class Model(fluid.dygraph.Layer):
+class Model(nn.Layer):
     def __init__(self):
         super(Model, self).__init__()
         with supernet(

@@ -50,18 +50,20 @@ class Model(fluid.dygraph.Layer):
         for idx, layer in enumerate(models):
             if idx == 6:
-                inputs = fluid.layers.flatten(inputs, 1)
+                inputs = paddle.flatten(inputs, 1)
             inputs = layer(inputs)
-        inputs = fluid.layers.softmax(inputs)
+        inputs = F.softmax(inputs)
         return inputs


 def test_ofa():
+    model = Model()
+    teacher_model = Model()
     default_run_config = {
         'train_batch_size': 256,
         'eval_batch_size': 64,
         'n_epochs': [[1], [2, 3], [4, 5]],
         'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
         'dynamic_batch_size': [1, 1, 1],

@@ -72,42 +74,46 @@ def test_ofa():
     default_distill_config = {
         'lambda_distill': 0.01,
-        'teacher_model': Model,
+        'teacher_model': teacher_model,
         'mapping_layers': ['models.0.fn']
     }
     distill_config = DistillConfig(**default_distill_config)

-    fluid.enable_dygraph()
-    model = Model()
     ofa_model = OFA(model, run_config, distill_config=distill_config)

-    train_reader = paddle.fluid.io.batch(
-        paddle.dataset.mnist.train(), batch_size=256, drop_last=True)
+    train_dataset = paddle.vision.datasets.MNIST(
+        mode='train', backend='cv2', transform=transform)
+    train_loader = paddle.io.DataLoader(
+        train_dataset,
+        places=place,
+        feed_list=[image, label],
+        drop_last=True,
+        batch_size=64)

     start_epoch = 0
     for idx in range(len(run_config.n_epochs)):
         cur_idx = run_config.n_epochs[idx]
         for ph_idx in range(len(cur_idx)):
             cur_lr = run_config.init_learning_rate[idx][ph_idx]
-            adam = fluid.optimizer.Adam(
+            adam = paddle.optimizer.Adam(
                 learning_rate=cur_lr,
                 parameter_list=(ofa_model.parameters() + ofa_model.netAs_param))
             for epoch_id in range(start_epoch, run_config.n_epochs[idx][ph_idx]):
-                for batch_id, data in enumerate(train_reader()):
+                for batch_id, data in enumerate(train_loader()):
                     dy_x_data = np.array(
                         [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
                     y_data = np.array(
                         [x[1] for x in data]).astype('int64').reshape(-1, 1)
-                    img = fluid.dygraph.to_variable(dy_x_data)
-                    label = fluid.dygraph.to_variable(y_data)
+                    img = paddle.dygraph.to_variable(dy_x_data)
+                    label = paddle.dygraph.to_variable(y_data)
                     label.stop_gradient = True
                     for model_no in range(run_config.dynamic_batch_size[idx]):
                         output, _ = ofa_model(img, label)
-                        loss = fluid.layers.reduce_mean(output)
+                        loss = F.mean(output)
                         dis_loss = ofa_model.calc_distill_loss()
                         loss += dis_loss
                         loss.backward()
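For reference, a minimal sketch of the distillation setup after this change, assembled from the updated demo above: the teacher is now an instantiated model rather than the Model class itself, and any RunConfig field not listed falls back to its None default. The values are copied from the demo's default_run_config; Model is the supernet-wrapped demo network defined above.

    from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig

    model = Model()          # the demo network defined above
    teacher_model = Model()  # an instance, not the Model class -- this is the fix

    run_config = RunConfig(
        train_batch_size=256,
        n_epochs=[[1], [2, 3], [4, 5]],
        init_learning_rate=[[0.001], [0.003, 0.001], [0.003, 0.001]],
        dynamic_batch_size=[1, 1, 1])

    distill_config = DistillConfig(
        lambda_distill=0.01,
        teacher_model=teacher_model,
        mapping_layers=['models.0.fn'])

    ofa_model = OFA(model, run_config, distill_config=distill_config)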
paddleslim/nas/__init__.py

@@ -19,6 +19,7 @@ from .sa_nas import *
 from .rl_nas import *
 from ..nas import darts
 from .darts import *
+from .ofa import *

 __all__ = []
 __all__ += sa_nas.__all__
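With `from .ofa import *` added, the OFA symbols (OFA, RunConfig, DistillConfig appear in the ofa module's __all__ in the diff below) should also be reachable one package level up, assuming the ofa subpackage re-exports them; a small sanity check:

    # a minimal check, assuming a normal PaddleSlim install
    from paddleslim.nas import OFA, RunConfig, DistillConfig
    from paddleslim.nas.ofa import supernet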
paddleslim/nas/ofa/convert_super.py

@@ -16,9 +16,8 @@ import inspect
 import decorator
 import logging
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid import framework
-from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm
-import numbers
+from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm, Embedding
 from .layers import *
 from ...common import get_logger

@@ -26,7 +25,7 @@ _logger = get_logger(__name__, level=logging.INFO)
 __all__ = ['supernet']

-WEIGHT_LAYER = ['conv', 'linear']
+WEIGHT_LAYER = ['conv', 'linear', 'embedding']

 ### TODO: add decorator

@@ -45,7 +44,7 @@ class Convert:
         cur_channel = None
         for idx, layer in enumerate(model):
             cls_name = layer.__class__.__name__.lower()
-            if 'conv' in cls_name or 'linear' in cls_name:
+            if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
                 weight_layer_count += 1
                 last_weight_layer_idx = idx
                 if first_weight_layer_idx == -1:

@@ -63,7 +62,7 @@ class Convert:
                 new_attr_name = [
                     '_stride', '_dilation', '_groups', '_param_attr',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype'
+                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_padding'
                 ]
                 new_attr_dict = dict()

@@ -179,6 +178,8 @@ class Convert:
                         layer._parameters['weight'].shape[0])
+                elif self.context.channel:
+                    new_attr_dict['num_channels'] = max(cur_channel)
                 else:
                     new_attr_dict['num_channels'] = attr_dict['_num_channels']

                 for attr in new_attr_name:
                     new_attr_dict[attr[1:]] = attr_dict[attr]

@@ -196,7 +197,8 @@ class Convert:
                 new_attr_name = [
                     '_stride', '_dilation', '_groups', '_param_attr',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
+                    '_padding', '_bias_attr', '_use_cudnn', '_act', '_dtype',
+                    '_output_size'
                 ]
                 assert attr_dict['_filter_size'] != None, "Conv2DTranspose only support filter size != None now"

@@ -371,6 +373,8 @@ class Convert:
                         layer._parameters['scale'].shape[0])
+                elif self.context.channel:
+                    new_attr_dict['num_channels'] = max(cur_channel)
                 else:
                     new_attr_dict['num_channels'] = attr_dict['_num_channels']

                 for attr in new_attr_name:
                     new_attr_dict[attr[1:]] = attr_dict[attr]

@@ -380,6 +384,76 @@ class Convert:
                 layer = SuperInstanceNorm(**new_attr_dict)
                 model[idx] = layer

+            elif isinstance(layer, LayerNorm) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                ### TODO(ceci3): fix when normalized_shape != last_dim_of_input
+                if idx > last_weight_layer_idx:
+                    continue
+
+                attr_dict = layer.__dict__
+                new_attr_name = [
+                    '_scale', '_shift', '_param_attr', '_bias_attr', '_act',
+                    '_dtype', '_epsilon'
+                ]
+                new_attr_dict = dict()
+                if self.context.expand:
+                    new_attr_dict['normalized_shape'] = self.context.expand * int(
+                        attr_dict['_normalized_shape'][0])
+                elif self.context.channel:
+                    new_attr_dict['normalized_shape'] = max(cur_channel)
+                else:
+                    new_attr_dict['normalized_shape'] = attr_dict['_normalized_shape']
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr[1:]] = attr_dict[attr]
+
+                del layer, attr_dict
+                layer = SuperLayerNorm(**new_attr_dict)
+                model[idx] = layer
+
+            elif isinstance(layer, Embedding) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                attr_dict = layer.__dict__
+                key = attr_dict['_full_name']
+                new_attr_name = [
+                    '_is_sparse', '_is_distributed', '_padding_idx',
+                    '_param_attr', '_dtype'
+                ]
+
+                new_attr_dict = dict()
+                new_attr_dict['candidate_config'] = dict()
+                bef_size = attr_dict['_size']
+                if self.context.expand:
+                    new_attr_dict['size'] = [
+                        bef_size[0], self.context.expand * bef_size[1]
+                    ]
+                    new_attr_dict['candidate_config'].update({
+                        'expand_ratio': self.context.expand_ratio
+                    })
+                elif self.context.channel:
+                    cur_channel = self.context.channel[0]
+                    self.context.channel = self.context.channel[1:]
+                    new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
+                    new_attr_dict['candidate_config'].update({
+                        'channel': cur_channel
+                    })
+                    pre_channel = cur_channel
+                else:
+                    new_attr_dict['size'] = bef_size
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr[1:]] = attr_dict[attr]
+
+                del layer, attr_dict
+                layer = Block(SuperEmbedding(**new_attr_dict), key=key)
+                model[idx] = layer

         return model
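The Convert pass now counts 'embedding' as a weight layer and rewrites Embedding and LayerNorm modules into SuperEmbedding and SuperLayerNorm. A minimal sketch of a network that exercises the new path, adapted from the ModelLinear test added in this commit (the class name TextSupernet is illustrative only):

    import paddle
    import paddle.fluid as fluid
    import paddle.fluid.dygraph.nn as nn
    from paddleslim.nas.ofa import supernet


    class TextSupernet(fluid.dygraph.Layer):
        def __init__(self):
            super(TextSupernet, self).__init__()
            with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
                models = []
                models += [nn.Embedding(size=(64, 64))]  # rewritten to SuperEmbedding
                models += [nn.Linear(64, 128)]
                models += [nn.LayerNorm(128)]            # rewritten to SuperLayerNorm
                models += [nn.Linear(128, 256)]
                models = ofa_super.convert(models)       # the Convert pass shown above
            self.models = paddle.nn.Sequential(*models)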
paddleslim/nas/ofa/layers.py

@@ -28,7 +28,7 @@ __all__ = [
     'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
     'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'Block',
     'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
-    'SuperDepthwiseConv2DTranspose'
+    'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
 ]

 _logger = get_logger(__name__, level=logging.INFO)

@@ -70,9 +70,10 @@ class Block(BaseBlock):
         key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
     """

-    def __init__(self, fn, key=None):
+    def __init__(self, fn, fixed=False, key=None):
         super(Block, self).__init__(key)
         self.fn = fn
+        self.fixed = fixed
         self.candidate_config = self.fn.candidate_config

     def forward(self, *inputs, **kwargs):

@@ -208,7 +209,6 @@ class SuperConv2D(fluid.dygraph.Conv2D):
                  act=None,
                  dtype='float32'):
-        ### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain
         ### TODO: change padding to any padding
         super(SuperConv2D, self).__init__(
             num_channels, num_filters, filter_size, stride, padding, dilation,
             groups, param_attr, bias_attr, use_cudnn, act, dtype)

@@ -228,7 +228,7 @@ class SuperConv2D(fluid.dygraph.Conv2D):
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
         self.channel = candidate_config[
             'channel'] if 'channel' in candidate_config else None
-        self.base_channel = None
+        self.base_channel = self._num_filters
         if self.expand_ratio != None:
             self.base_channel = int(self._num_filters / max(self.expand_ratio))

@@ -296,6 +296,11 @@ class SuperConv2D(fluid.dygraph.Conv2D):
         if not in_dygraph_mode():
             _logger.error("NOT support static graph")

+        self.cur_config = {
+            'kernel_size': kernel_size,
+            'expand_ratio': expand_ratio,
+            'channel': channel
+        }
         in_nc = int(input.shape[1])
         assert (
             expand_ratio == None or channel == None

@@ -313,7 +318,11 @@ class SuperConv2D(fluid.dygraph.Conv2D):
                                         out_nc)
         weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)

-        padding = convert_to_list(get_same_padding(ks), 2)
+        if kernel_size != None or 'kernel_size' in self.candidate_config.keys():
+            padding = convert_to_list(get_same_padding(ks), 2)
+        else:
+            padding = self._padding

         if self._l_type == 'conv2d':
             attrs = ('strides', self._stride, 'paddings', padding, 'dilations',

@@ -488,7 +497,6 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
                  use_cudnn=True,
                  act=None,
                  dtype='float32'):
-        ### NOTE: padding always is 0, add padding in forward because of kernel size is uncertain
         super(SuperConv2DTranspose, self).__init__(
             num_channels, num_filters, filter_size, output_size, padding,
             stride, dilation, groups, param_attr, bias_attr, use_cudnn, act,

@@ -507,7 +515,7 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
         self.channel = candidate_config[
             'channel'] if 'channel' in candidate_config else None
-        self.base_channel = None
+        self.base_channel = self._num_filters
         if self.expand_ratio:
             self.base_channel = int(self._num_filters / max(self.expand_ratio))

@@ -572,6 +580,11 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
         if not in_dygraph_mode():
             _logger.error("NOT support static graph")

+        self.cur_config = {
+            'kernel_size': kernel_size,
+            'expand_ratio': expand_ratio,
+            'channel': channel
+        }
         in_nc = int(input.shape[1])
         assert (
             expand_ratio == None or channel == None

@@ -590,7 +603,10 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
                                         out_nc)
         weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)

-        padding = convert_to_list(get_same_padding(ks), 2)
+        if kernel_size != None or 'kernel_size' in self.candidate_config.keys():
+            padding = convert_to_list(get_same_padding(ks), 2)
+        else:
+            padding = self._padding

         op = getattr(core.ops, self._op_type)
         out = op(input, weight, 'output_size', self._output_size, 'strides',

@@ -701,7 +717,7 @@ class SuperSeparableConv2D(fluid.dygraph.Layer):
         self.conv.extend([norm_layer(num_channels * scale_factor)])

         self.conv.extend([
-            Conv2D(
+            fluid.dygraph.nn.Conv2D(
                 num_channels=num_channels * scale_factor,
                 num_filters=num_filters,
                 filter_size=1,

@@ -713,14 +729,16 @@ class SuperSeparableConv2D(fluid.dygraph.Layer):
         self.candidate_config = candidate_config
         self.expand_ratio = candidate_config[
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
-        self.base_output_dim = None
+        self.base_output_dim = self.conv[0]._num_filters
         if self.expand_ratio != None:
-            self.base_output_dim = int(self.output_dim / max(self.expand_ratio))
+            self.base_output_dim = int(self.conv[0]._num_filters /
+                                       max(self.expand_ratio))

     def forward(self, input, expand_ratio=None, channel=None):
         if not in_dygraph_mode():
             _logger.error("NOT support static graph")

+        self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
         in_nc = int(input.shape[1])
         assert (
             expand_ratio == None or channel == None

@@ -809,7 +827,7 @@ class SuperLinear(fluid.dygraph.Linear):
         self.candidate_config = candidate_config
         self.expand_ratio = candidate_config[
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
-        self.base_output_dim = None
+        self.base_output_dim = self.output_dim
         if self.expand_ratio != None:
             self.base_output_dim = int(self.output_dim / max(self.expand_ratio))

@@ -817,8 +835,9 @@ class SuperLinear(fluid.dygraph.Linear):
         if not in_dygraph_mode():
             _logger.error("NOT support static graph")

+        self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
         ### weight: (Cin, Cout)
-        in_nc = int(input.shape[1])
+        in_nc = int(input.shape[-1])
         assert (
             expand_ratio == None or channel == None
         ), "expand_ratio and channel CANNOT be NOT None at the same time."

@@ -927,3 +946,77 @@ class SuperInstanceNorm(fluid.dygraph.InstanceNorm):
         out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon',
                                            self._epsilon)
         return out
+
+
+class SuperLayerNorm(fluid.dygraph.LayerNorm):
+    def __init__(self,
+                 normalized_shape,
+                 candidate_config={},
+                 scale=True,
+                 shift=True,
+                 epsilon=1e-05,
+                 param_attr=None,
+                 bias_attr=None,
+                 act=None,
+                 dtype='float32'):
+        super(SuperLayerNorm, self).__init__(normalized_shape, scale, shift,
+                                             epsilon, param_attr, bias_attr,
+                                             act, dtype)
+
+    def forward(self, input):
+        if not in_dygraph_mode():
+            _logger.error("NOT support static graph")
+
+        input_shape = list(input.shape)
+        input_ndim = len(input_shape)
+        normalized_ndim = len(self._normalized_shape)
+        self._begin_norm_axis = input_ndim - normalized_ndim
+        ### TODO(ceci3): fix if normalized_shape is not a single number
+        feature_dim = int(input.shape[-1])
+        weight = self.weight[:feature_dim]
+        bias = self.bias[:feature_dim]
+        pre_act, _, _ = core.ops.layer_norm(input, weight, bias, 'epsilon',
+                                            self._epsilon, 'begin_norm_axis',
+                                            self._begin_norm_axis)
+        return dygraph_utils._append_activation_in_dygraph(
+            pre_act, act=self._act)
+
+
+class SuperEmbedding(fluid.dygraph.Embedding):
+    def __init__(self,
+                 size,
+                 candidate_config={},
+                 is_sparse=False,
+                 is_distributed=False,
+                 padding_idx=None,
+                 param_attr=None,
+                 dtype='float32'):
+        super(SuperEmbedding, self).__init__(size, is_sparse, is_distributed,
+                                             padding_idx, param_attr, dtype)
+        self.candidate_config = candidate_config
+        self.expand_ratio = candidate_config[
+            'expand_ratio'] if 'expand_ratio' in candidate_config else None
+        self.base_output_dim = self._size[-1]
+        if self.expand_ratio != None:
+            self.base_output_dim = int(self._size[-1] / max(self.expand_ratio))
+
+    def forward(self, input, expand_ratio=None, channel=None):
+        if not in_dygraph_mode():
+            _logger.error("NOT support static graph")
+
+        assert (
+            expand_ratio == None or channel == None
+        ), "expand_ratio and channel CANNOT be NOT None at the same time."
+
+        if expand_ratio != None:
+            out_nc = int(expand_ratio * self.base_output_dim)
+        elif channel != None:
+            out_nc = int(channel)
+        else:
+            out_nc = self._size[-1]
+
+        weight = self.weight[:, :out_nc]
+        return core.ops.lookup_table_v2(
+            weight, input, 'is_sparse', self._is_sparse, 'is_distributed',
+            self._is_distributed, 'remote_prefetch', self._remote_prefetch,
+            'padding_idx', self._padding_idx)
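Block now accepts a fixed flag: a fixed block keeps its super layer but is left out of the searched candidate set (see the `if not sublayer.fixed` and `block.fixed == False` handling in the ofa.py diff below). A short sketch of the new argument, mirroring the updated ModelConv test case in this commit (the import path simply follows the file location shown above):

    from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D

    # Wrap a super layer but exclude it from the elastic search space.
    # Arguments mirror the updated ModelConv test case in this commit.
    sep_conv = Block(
        SuperSeparableConv2D(
            3, 6, 1, padding=1, candidate_config={'channel': (3, 6)}),
        fixed=True)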
paddleslim/nas/ofa/ofa.py

@@ -16,7 +16,7 @@ import logging
 import numpy as np
 from collections import namedtuple
 import paddle
-import paddle.nn as nn
+# import paddle.nn as nn
 import paddle.fluid as fluid
 from paddle.fluid.dygraph import Conv2D
 from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm

@@ -28,9 +28,8 @@ _logger = get_logger(__name__, level=logging.INFO)
 __all__ = ['OFA', 'RunConfig', 'DistillConfig']

 RunConfig = namedtuple('RunConfig', [
-    'train_batch_size', 'eval_batch_size', 'n_epochs', 'save_frequency',
-    'eval_frequency', 'init_learning_rate', 'total_images', 'elastic_depth',
-    'dynamic_batch_size'
+    'train_batch_size', 'n_epochs', 'save_frequency', 'eval_frequency',
+    'init_learning_rate', 'total_images', 'elastic_depth', 'dynamic_batch_size'
 ])
 RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)

@@ -53,20 +52,26 @@ class OFABase(fluid.dygraph.Layer):
         for name, sublayer in self.model.named_sublayers():
             if isinstance(sublayer, BaseBlock):
                 sublayer.set_supernet(self)
-                layers[sublayer.key] = sublayer.candidate_config
-                for k in sublayer.candidate_config.keys():
-                    elastic_task.add(k)
+                if not sublayer.fixed:
+                    layers[sublayer.key] = sublayer.candidate_config
+                    for k in sublayer.candidate_config.keys():
+                        elastic_task.add(k)
         return layers, elastic_task

     def forward(self, *inputs, **kwargs):
         raise NotImplementedError

     # NOTE: config means set forward config for layers, used in distill.
     def layers_forward(self, block, *inputs, **kwargs):
         if getattr(self, 'current_config', None) != None:
-            assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(block.key)
-            config = self.current_config[block.key]
+            ### if block is fixed, donnot join key into candidate
+            ### concrete config as parameter in kwargs
+            if block.fixed == False:
+                assert block.key in self.current_config, 'DONNT have {} layer in config.'.format(block.key)
+                config = self.current_config[block.key]
+            else:
+                config = dict()
+                config.update(kwargs)
         else:
             config = dict()
         logging.debug(self.model, config)

@@ -81,7 +86,7 @@ class OFABase(fluid.dygraph.Layer):
 class OFA(OFABase):
     def __init__(self,
                  model,
-                 run_config,
+                 run_config=None,
                  net_config=None,
                  distill_config=None,
                  elastic_order=None,

@@ -92,7 +97,6 @@ class OFA(OFABase):
         self.distill_config = distill_config
         self.elastic_order = elastic_order
         self.train_full = train_full
-        self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
         self.iter = 0
         self.dynamic_iter = 0
         self.manual_set_task = False

@@ -100,18 +104,16 @@ class OFA(OFABase):
         self._add_teacher = False
         self.netAs_param = []

-        for idx in range(len(run_config.n_epochs)):
-            assert isinstance(
-                run_config.init_learning_rate[idx],
-                list), "each candidate in init_learning_rate must be list"
-            assert isinstance(run_config.n_epochs[idx],
-                              list), "each candidate in n_epochs must be list"
-
         ### if elastic_order is none, use default order
         if self.elastic_order is not None:
             assert isinstance(self.elastic_order, list), 'elastic_order must be a list'

         if getattr(self.run_config, 'elastic_depth', None) != None:
             depth_list = list(set(self.run_config.elastic_depth))
             depth_list.sort()
             self.layers['depth'] = depth_list

         if self.elastic_order is None:
             self.elastic_order = []
             # zero, elastic resulotion, write in demo

@@ -133,16 +135,26 @@ class OFA(OFABase):
             if 'channel' in self._elastic_task and 'width' not in self.elastic_order:
                 self.elastic_order.append('width')

-        assert len(self.run_config.n_epochs) == len(self.elastic_order)
-        assert len(self.run_config.n_epochs) == len(
-            self.run_config.dynamic_batch_size)
-        assert len(self.run_config.n_epochs) == len(
-            self.run_config.init_learning_rate)
+        if getattr(self.run_config, 'n_epochs', None) != None:
+            assert len(self.run_config.n_epochs) == len(self.elastic_order)
+            for idx in range(len(run_config.n_epochs)):
+                assert isinstance(run_config.n_epochs[idx],
+                                  list), "each candidate in n_epochs must be list"
+            if self.run_config.dynamic_batch_size != None:
+                assert len(self.run_config.n_epochs) == len(
+                    self.run_config.dynamic_batch_size)
+            if self.run_config.init_learning_rate != None:
+                assert len(self.run_config.n_epochs) == len(
+                    self.run_config.init_learning_rate)
+                for idx in range(len(run_config.n_epochs)):
+                    assert isinstance(
+                        run_config.init_learning_rate[idx],
+                        list), "each candidate in init_learning_rate must be list"

         ### ================= add distill prepare ======================
-        if self.distill_config != None and (
-                self.distill_config.lambda_distill != None and
-                self.distill_config.lambda_distill > 0):
+        if self.distill_config != None:
             self._add_teacher = True
             self._prepare_distill()

@@ -153,9 +165,10 @@ class OFA(OFABase):
         if self.distill_config.teacher_model == None:
             logging.error(
-                'If you want to add distill, please input class of teacher model'
+                'If you want to add distill, please input instance of teacher model'
             )

+        ### instance model by user can input super-param easily.
         assert isinstance(self.distill_config.teacher_model,
                           paddle.fluid.dygraph.Layer)

@@ -171,7 +184,7 @@ class OFA(OFABase):
         # add hook if mapping layers is not None
         # if mapping layer is None, return the output of the teacher model,
         # if mapping layer is NOT None, add hook and compute distill loss about mapping layers.
-        mapping_layers = self.distill_config.mapping_layers
+        mapping_layers = getattr(self.distill_config, 'mapping_layers', None)
         if mapping_layers != None:
             self.netAs = []
             for name, sublayer in self.model.named_sublayers():

@@ -199,9 +212,16 @@ class OFA(OFABase):
     def _compute_epochs(self):
         if getattr(self, 'epoch', None) == None:
+            assert self.run_config.total_images is not None, \
+                "if not use set_epoch() to set epoch, please set total_images in run_config."
+            assert self.run_config.train_batch_size is not None, \
+                "if not use set_epoch() to set epoch, please set train_batch_size in run_config."
+            assert self.run_config.n_epochs is not None, \
+                "if not use set_epoch() to set epoch, please set n_epochs in run_config."
+            self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
             epoch = self.iter // self.iter_per_epochs
         else:
-            epoch = self.epochs
+            epoch = self.epoch
         return epoch

     def _sample_from_nestdict(self, cands, sample_type, task, phase):

@@ -284,6 +304,9 @@ class OFA(OFABase):
     def export(self, config):
         pass

+    def set_net_config(self, net_config):
+        self.net_config = net_config
+
     def forward(self, *inputs, **kwargs):
         # ===================== teacher process =====================
         teacher_output = None

@@ -293,11 +316,12 @@ class OFA(OFABase):
         # ============================================================

         # ==================== student process =====================
-        self.dynamic_iter += 1
-        if self.dynamic_iter == self.run_config.dynamic_batch_size[self.task_idx]:
-            self.iter += 1
-            self.dynamic_iter = 0
+        if getattr(self.run_config, 'dynamic_batch_size', None) != None:
+            self.dynamic_iter += 1
+            if self.dynamic_iter == self.run_config.dynamic_batch_size[self.task_idx]:
+                self.iter += 1
+                self.dynamic_iter = 0

         if self.net_config == None:
             if self.train_full == True:

@@ -314,6 +338,6 @@ class OFA(OFABase):
         _logger.debug("Current config is {}".format(self.current_config))
         if 'depth' in self.current_config:
-            kwargs['depth'] = int(self.current_config['depth'])
+            kwargs['depth'] = self.current_config['depth']

         return self.model.forward(*inputs, **kwargs), teacher_output
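run_config is now optional and a set_net_config() helper has been added, so an OFA wrapper can be driven by an explicit sub-network configuration instead of the progressive training schedule. A minimal sketch mirroring the new TestOFACase3 in the test diff below (ModelLinear2 is the test model defined there):

    from paddleslim.nas.ofa import OFA

    model = ModelLinear2()  # supernet built with expand_ratio=None, see tests/test_ofa.py
    ofa_model = OFA(model)  # no RunConfig required after this change
    ofa_model.set_net_config({'expand_ratio': None})  # pin the sub-network to run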
tests/test_ofa.py

@@ -17,7 +17,6 @@ sys.path.append("../")
 import numpy as np
 import unittest
 import paddle
-from static_case import StaticCase
 import paddle.fluid as fluid
 import paddle.fluid.dygraph.nn as nn
 from paddle.nn import ReLU

@@ -35,13 +34,16 @@ class ModelConv(fluid.dygraph.Layer):
                 channel=((4, 8, 12), (8, 12, 16), (8, 12, 16),
                          (8, 12, 16))) as ofa_super:
             models = []
-            models += [nn.Conv2D(3, 4, 3)]
+            models += [nn.Conv2D(3, 4, 3, padding=1)]
             models += [nn.InstanceNorm(4)]
             models += [ReLU()]
             models += [nn.Conv2D(4, 4, 3, groups=4)]
             models += [nn.InstanceNorm(4)]
             models += [ReLU()]
-            models += [nn.Conv2DTranspose(4, 4, 3, groups=4, use_cudnn=True)]
+            models += [
+                nn.Conv2DTranspose(
+                    4, 4, 3, groups=4, padding=1, use_cudnn=True)
+            ]
             models += [nn.BatchNorm(4)]
             models += [ReLU()]
             models += [nn.Conv2D(4, 3, 3)]

@@ -51,7 +53,8 @@ class ModelConv(fluid.dygraph.Layer):
         models += [
             Block(
                 SuperSeparableConv2D(
-                    3, 6, 1, candidate_config={'channel': (3, 6)}))
+                    3, 6, 1, padding=1, candidate_config={'channel': (3, 6)}),
+                fixed=True)
         ]
         with supernet(
                 kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:

@@ -92,15 +95,37 @@ class ModelLinear(fluid.dygraph.Layer):
             models = []
             with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
                 models1 = []
+                models1 += [nn.Embedding(size=(64, 64))]
                 models1 += [nn.Linear(64, 128)]
                 models1 += [nn.LayerNorm(128)]
                 models1 += [nn.Linear(128, 256)]
                 models1 = ofa_super.convert(models1)
             models += models1
         self.models = paddle.nn.Sequential(*models)

     def forward(self, inputs, depth=None):
         if depth != None:
             assert isinstance(depth, int)
             assert depth < len(self.models)
         else:
             depth = len(self.models)
         for idx in range(depth):
             layer = self.models[idx]
             inputs = layer(inputs)
         return inputs

-        with supernet(channel=((64, 128, 256), (64, 128, 256))) as ofa_super:
-            models1 = []
-            models1 += [nn.Linear(256, 128)]
+
+class ModelLinear1(fluid.dygraph.Layer):
+    def __init__(self):
+        super(ModelLinear1, self).__init__()
+        models = []
+        with supernet(
+                channel=((64, 128, 256), (64, 128, 256),
+                         (64, 128, 256))) as ofa_super:
+            models1 = []
+            models1 += [nn.Embedding(size=(64, 64))]
+            models1 += [nn.Linear(64, 128)]
+            models1 += [nn.LayerNorm(128)]
+            models1 += [nn.Linear(128, 256)]
             models1 = ofa_super.convert(models1)

@@ -120,7 +145,35 @@ class ModelLinear(fluid.dygraph.Layer):
         return inputs


-class TestOFA(StaticCase):
+class ModelLinear2(fluid.dygraph.Layer):
+    def __init__(self):
+        super(ModelLinear2, self).__init__()
+        models = []
+        with supernet(expand_ratio=None) as ofa_super:
+            models1 = []
+            models1 += [nn.Embedding(size=(64, 64))]
+            models1 += [nn.Linear(64, 128)]
+            models1 += [nn.LayerNorm(128)]
+            models1 += [nn.Linear(128, 256)]
+            models1 = ofa_super.convert(models1)
+        models += models1
+        self.models = paddle.nn.Sequential(*models)
+
+    def forward(self, inputs, depth=None):
+        if depth != None:
+            assert isinstance(depth, int)
+            assert depth < len(self.models)
+        else:
+            depth = len(self.models)
+        for idx in range(depth):
+            layer = self.models[idx]
+            inputs = layer(inputs)
+        return inputs
+
+
+class TestOFA(unittest.TestCase):
     def setUp(self):
         fluid.enable_dygraph()
         self.init_model_and_data()

@@ -137,7 +190,6 @@ class TestOFA(StaticCase):
     def init_config(self):
         default_run_config = {
             'train_batch_size': 1,
-            'eval_batch_size': 1,
             'n_epochs': [[1], [2, 3], [4, 5]],
             'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
             'dynamic_batch_size': [1, 1, 1],

@@ -152,11 +204,13 @@ class TestOFA(StaticCase):
             'mapping_layers': ['models.0.fn']
         }
         self.distill_config = DistillConfig(**default_distill_config)
+        self.elastic_order = ['kernel_size', 'width', 'depth']

     def test_ofa(self):
         ofa_model = OFA(self.model,
                         self.run_config,
-                        distill_config=self.distill_config)
+                        distill_config=self.distill_config,
+                        elastic_order=self.elastic_order)

         start_epoch = 0
         for idx in range(len(self.run_config.n_epochs)):

@@ -169,6 +223,8 @@ class TestOFA(StaticCase):
                     ofa_model.parameters() + ofa_model.netAs_param))
                 for epoch_id in range(start_epoch,
                                       self.run_config.n_epochs[idx][ph_idx]):
+                    if epoch_id == 0:
+                        ofa_model.set_epoch(epoch_id)
                     for model_no in range(
                             self.run_config.dynamic_batch_size[idx]):
                         output, _ = ofa_model(self.data)

@@ -191,14 +247,13 @@ class TestOFACase1(TestOFA):
     def init_model_and_data(self):
         self.model = ModelLinear()
         self.teacher_model = ModelLinear()
-        data_np = np.random.random((3, 64)).astype(np.float32)
+        data_np = np.random.random((3, 64)).astype(np.int64)
         self.data = fluid.dygraph.to_variable(data_np)

     def init_config(self):
         default_run_config = {
             'train_batch_size': 1,
-            'eval_batch_size': 1,
             'n_epochs': [[2, 5]],
             'init_learning_rate': [[0.003, 0.001]],
             'dynamic_batch_size': [1],

@@ -211,6 +266,23 @@ class TestOFACase1(TestOFA):
             'teacher_model': self.teacher_model,
         }
         self.distill_config = DistillConfig(**default_distill_config)
+        self.elastic_order = None
+
+
+class TestOFACase2(TestOFACase1):
+    def init_model_and_data(self):
+        self.model = ModelLinear1()
+        self.teacher_model = ModelLinear1()
+        data_np = np.random.random((3, 64)).astype(np.int64)
+        self.data = fluid.dygraph.to_variable(data_np)
+
+
+class TestOFACase3(unittest.TestCase):
+    def test_ofa(self):
+        self.model = ModelLinear2()
+        ofa_model = OFA(self.model)
+        ofa_model.set_net_config({'expand_ratio': None})


 if __name__ == '__main__':