PaddlePaddle / PaddleSlim
Commit e274d7fb (unverified)
Authored by Chang Xu on Jun 30, 2021; committed via GitHub on Jun 30, 2021.
new_export_func (#824)
* new_export_func
* add_test
* remove_kernel_prune
* add_test_&_update_clear_ss
Parent: da0f5a6d

Showing 6 changed files with 220 additions and 179 deletions.
paddleslim/core/dygraph.py            +2    -1
paddleslim/nas/ofa/get_sub_model.py   +1    -118
paddleslim/nas/ofa/layers.py          +19   -6
paddleslim/nas/ofa/ofa.py             +116  -39
tests/test_ofa.py                     +34   -14
tests/test_ofa_v2.py                  +48   -1
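Before the per-file diffs, a rough sketch of the export workflow this commit moves to, pieced together from the updated `export` docstring and the new tests (`TestExportCase2`, `TestInputDict`) further down. The toy Sequential model, the shapes, and the import paths are illustrative assumptions, not part of the diff, and the snippet has not been verified against this exact revision:

import paddle
import paddle.nn as nn
from paddleslim.nas.ofa import OFA
from paddleslim.nas.ofa.convert_super import Convert, supernet

# A plain dygraph model (illustrative stand-in for the test models).
model = nn.Sequential(
    nn.Conv2D(3, 12, 3, padding=1), nn.ReLU(), nn.Conv2D(12, 12, 3, padding=1))

# Convert it into a supernet and wrap it with OFA.
sp_model = Convert(supernet(expand_ratio=[0.5, 1.0])).convert(model)
ofa_model = OFA(sp_model)

# Run one forward pass so each super layer records the shape it actually used.
ofa_model.set_epoch(0)
ofa_model.set_task('expand_ratio')
out, _ = ofa_model(paddle.randn([1, 3, 32, 32]))

# Sample a sub-network config and export it as a standalone dygraph model.
config = ofa_model._sample_config(task='expand_ratio', sample_type='smallest')
sub_model = ofa_model.export(
    config, input_shapes=[[1, 3, 32, 32]], input_dtypes=['float32'])
sub_model(paddle.randn([1, 3, 32, 32]))

Internally, the new `export` builds random inputs from `input_shapes`/`input_dtypes`, runs one forward pass so every super layer records a `prune_dim` in its `cur_config`, and then slices the weights accordingly (see `_get_model_pruned_weight` in paddleslim/nas/ofa/ofa.py below).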
paddleslim/core/dygraph.py
@@ -64,6 +64,7 @@ def extract_vars(inputs):
                    f"Variable is excepted, but get an element with type({type(_value)}) from inputs whose type is dict. And the key of element is {_key}."
                )
    elif isinstance(inputs, (tuple, list)):
        for _value in inputs:
            vars.extend(extract_vars(_value))
    if len(vars) == 0:

@@ -99,7 +100,6 @@ def dygraph2program(layer,
                    extract_outputs_fn=None,
                    dtypes=None):
    assert isinstance(layer, Layer)
    extract_inputs_fn = extract_inputs_fn if extract_inputs_fn is not None else extract_vars
    extract_outputs_fn = extract_outputs_fn if extract_outputs_fn is not None else extract_vars
    tracer = _dygraph_tracer()._get_program_desc_tracer()

@@ -116,6 +116,7 @@ def dygraph2program(layer,
    else:
        inputs = to_variables(inputs)
    input_var_list = extract_inputs_fn(inputs)
    original_outputs = layer(*inputs)
    # 'original_outputs' may be dict, so we should convert it to list of varibles.
    # And should not create new varibles in 'extract_vars'.
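The `dtypes` keyword touched in these hunks is what OFA relies on later in this commit to trace the model with typed random inputs. A minimal, hedged sketch of that call pattern, with a toy layer and shapes that are assumptions rather than part of the diff:

import paddle.nn as nn
from paddleslim.core import GraphWrapper, dygraph2program

# Illustrative toy network.
net = nn.Sequential(nn.Conv2D(3, 8, 3), nn.ReLU(), nn.Conv2D(8, 8, 3))

# Trace the dygraph Layer into a static program, then wrap it so graph-level
# passes (e.g. check_search_space in ofa.py below) can walk its ops.
program = dygraph2program(net, inputs=[[1, 3, 32, 32]], dtypes=['float32'])
graph = GraphWrapper(program)
print(len(graph.ops()))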
paddleslim/nas/ofa/get_sub_model.py
@@ -17,7 +17,7 @@ import paddle
 from paddle.fluid import core
 from .layers_base import BaseBlock

-__all__ = ['get_prune_params_config', 'prune_params', 'check_search_space']
+__all__ = ['check_search_space']

 WEIGHT_OP = [
     'conv2d', 'linear', 'embedding', 'conv2d_transpose', 'depthwise_conv2d'
@@ -28,63 +28,6 @@ CONV_TYPES = [
 ]


-def get_prune_params_config(graph, origin_model_config):
-    """ Convert config of search space to parameters' prune config.
-    """
-    param_config = {}
-    precedor = None
-    for op in graph.ops():
-        ### TODO(ceci3):
-        ### 1. fix config when this op is concat by graph.pre_ops(op)
-        ### 2. add kernel_size in config
-        for inp in op.all_inputs():
-            n_ops = graph.next_ops(op)
-            if inp._var.name in origin_model_config.keys():
-                if 'expand_ratio' in origin_model_config[inp._var.name] or 'channel' in origin_model_config[inp._var.name]:
-                    key = 'channel' if 'channel' in origin_model_config[inp._var.name] else 'expand_ratio'
-                    tmp = origin_model_config[inp._var.name][key]
-                    if len(inp._var.shape) > 1:
-                        if inp._var.name in param_config.keys():
-                            param_config[inp._var.name].append(tmp)
-                        ### first op
-                        else:
-                            param_config[inp._var.name] = [precedor, tmp]
-                    else:
-                        param_config[inp._var.name] = [tmp]
-                    precedor = tmp
-                else:
-                    precedor = None
-            for n_op in n_ops:
-                for next_inp in n_op.all_inputs():
-                    if next_inp._var.persistable == True:
-                        if next_inp._var.name in origin_model_config.keys():
-                            if 'expand_ratio' in origin_model_config[next_inp._var.name] or 'channel' in origin_model_config[next_inp._var.name]:
-                                key = 'channel' if 'channel' in origin_model_config[next_inp._var.name] else 'expand_ratio'
-                                tmp = origin_model_config[next_inp._var.name][key]
-                                pre = tmp if precedor is None else precedor
-                                if len(next_inp._var.shape) > 1:
-                                    param_config[next_inp._var.name] = [pre]
-                                else:
-                                    param_config[next_inp._var.name] = [tmp]
-                        else:
-                            if len(next_inp._var.shape) > 1 and precedor != None:
-                                param_config[next_inp._var.name] = [precedor, None]
-                            else:
-                                param_config[next_inp._var.name] = [precedor]
-
-    return param_config


 def get_actual_shape(transform, channel):
     if transform == None:
         channel = int(channel)

@@ -96,66 +39,6 @@ def get_actual_shape(transform, channel):
     return channel


-def prune_params(model, param_config, super_model_sd=None):
-    """ Prune parameters according to the config.
-
-    Parameters:
-        model(paddle.nn.Layer): instance of model.
-        param_config(dict): prune config of each weight.
-        super_model_sd(dict, optional): parameters come from supernet. If super_model_sd is not None, transfer parameters from this dict to model; otherwise, prune model from itself.
-    """
-    for l_name, sublayer in model.named_sublayers():
-        if isinstance(sublayer, BaseBlock):
-            continue
-        for p_name, param in sublayer.named_parameters(include_sublayers=False):
-            t_value = param.value().get_tensor()
-            value = np.array(t_value).astype("float32")
-
-            if super_model_sd != None:
-                name = l_name + '.' + p_name
-                super_t_value = super_model_sd[name].value().get_tensor()
-                super_value = np.array(super_t_value).astype("float32")
-                super_model_sd.pop(name)
-
-            if param.name in param_config.keys():
-                if len(param_config[param.name]) > 1:
-                    in_exp = param_config[param.name][0]
-                    out_exp = param_config[param.name][1]
-                    if sublayer.__class__.__name__.lower() in CONV_TYPES:
-                        in_chn = get_actual_shape(in_exp, value.shape[1])
-                        out_chn = get_actual_shape(out_exp, value.shape[0])
-                        prune_value = super_value[:out_chn, :in_chn, ...] \
-                            if super_model_sd != None else value[:out_chn, :in_chn, ...]
-                    else:
-                        in_chn = get_actual_shape(in_exp, value.shape[0])
-                        out_chn = get_actual_shape(out_exp, value.shape[1])
-                        prune_value = super_value[:in_chn, :out_chn, ...] \
-                            if super_model_sd != None else value[:in_chn, :out_chn, ...]
-                else:
-                    out_chn = get_actual_shape(param_config[param.name][0], value.shape[0])
-                    prune_value = super_value[:out_chn, ...] \
-                        if super_model_sd != None else value[:out_chn, ...]
-            else:
-                prune_value = super_value if super_model_sd != None else value
-
-            p = t_value._place()
-            if p.is_cpu_place():
-                place = core.CPUPlace()
-            elif p.is_cuda_pinned_place():
-                place = core.CUDAPinnedPlace()
-            else:
-                place = core.CUDAPlace(p.gpu_device_id())
-            t_value.set(prune_value, place)
-            if param.trainable:
-                param.clear_gradient()
-
-    ### initialize param which not in sublayers, such as create persistable inputs by create_parameters
-    if super_model_sd != None and len(super_model_sd) != 0:
-        for k, v in super_model_sd.items():
-            setattr(model, k, v)


 def _is_depthwise(op):
     """Check if this op is depthwise conv. Only Cin == Cout == groups be consider as depthwise conv.
     The shape of input and the shape of output in depthwise conv must be same in superlayer,
paddleslim/nas/ofa/layers.py
@@ -177,6 +177,7 @@ class SuperConv2D(nn.Conv2D):
            data_format=data_format)
        self.candidate_config = candidate_config
+        self.cur_config = None
        if len(candidate_config.items()) != 0:
            for k, v in candidate_config.items():
                candidate_config[k] = list(set(v))

@@ -314,7 +315,7 @@ class SuperConv2D(nn.Conv2D):
                bias = self.bias[:weight_out_nc]
            else:
                bias = self.bias
+            self.cur_config['prune_dim'] = list(weight.shape)
        out = F.conv2d(
            input,
            weight,

@@ -482,6 +483,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
            data_format=data_format)
        self.candidate_config = candidate_config
+        self.cur_config = None
        if len(self.candidate_config.items()) != 0:
            for k, v in candidate_config.items():
                candidate_config[k] = list(set(v))

@@ -620,7 +622,7 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
                bias = self.bias[:weight_out_nc]
            else:
                bias = self.bias
+            self.cur_config['prune_dim'] = list(weight.shape)
        out = F.conv2d_transpose(
            input,
            weight,

@@ -733,6 +735,7 @@ class SuperSeparableConv2D(nn.Layer):
        ])
        self.candidate_config = candidate_config
+        self.cur_config = None
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self.conv[0]._out_channels

@@ -784,7 +787,7 @@ class SuperSeparableConv2D(nn.Layer):
                bias = self.conv[2].bias[:out_nc]
            else:
                bias = self.conv[2].bias
+            self.cur_config['prune_dim'] = list(weight.shape)
        conv1_out = F.conv2d(
            norm_out,
            weight,

@@ -864,6 +867,7 @@ class SuperLinear(nn.Linear):
        self._in_features = in_features
        self._out_features = out_features
        self.candidate_config = candidate_config
+        self.cur_config = None
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self._out_features

@@ -896,7 +900,7 @@ class SuperLinear(nn.Linear):
                bias = self.bias[:out_nc]
            else:
                bias = self.bias
+            self.cur_config['prune_dim'] = list(weight.shape)
        out = F.linear(x=input, weight=weight, bias=bias, name=self.name)
        return out

@@ -945,6 +949,7 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
        super(SuperBatchNorm2D, self).__init__(
            num_features, momentum, epsilon, weight_attr, bias_attr,
            data_format, use_global_stats, name)
+        self.cur_config = None

    def forward(self, input):
        self._check_data_format(self._data_format)

@@ -956,7 +961,7 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
        bias = self.bias[:feature_dim]
        mean = self._mean[:feature_dim]
        variance = self._variance[:feature_dim]
+        self.cur_config = {'prune_dim': feature_dim}
        return F.batch_norm(
            input,
            mean,

@@ -982,6 +987,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
        super(SuperSyncBatchNorm, self).__init__(
            num_features, momentum, epsilon, weight_attr, bias_attr,
            data_format, name)
+        self.cur_config = None

    def forward(self, input):

@@ -995,6 +1001,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
        mean_out = mean
        # variance and variance out share the same memory
        variance_out = variance
+        self.cur_config = {'prune_dim': feature_dim}
        attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                 "is_test", not self.training, "data_layout", self._data_format,

@@ -1049,6 +1056,7 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
        super(SuperInstanceNorm2D, self).__init__(
            num_features, epsilon, momentum, weight_attr, bias_attr,
            data_format, name)
+        self.cur_config = None

    def forward(self, input):
        self._check_input_dim(input)

@@ -1060,7 +1068,7 @@ class SuperInstanceNorm2D(nn.InstanceNorm2D):
        else:
            scale = self.scale[:feature_dim]
            bias = self.bias[:feature_dim]
+        self.cur_config = {'prune_dim': feature_dim}
        return F.instance_norm(input, scale, bias, eps=self._epsilon)

@@ -1112,6 +1120,7 @@ class SuperLayerNorm(nn.LayerNorm):
                 name=None):
        super(SuperLayerNorm, self).__init__(normalized_shape, epsilon,
                                             weight_attr, bias_attr, name)
+        self.cur_config = None

    def forward(self, input):
        ### TODO(ceci3): fix if normalized_shape is not a single number

@@ -1127,6 +1136,8 @@ class SuperLayerNorm(nn.LayerNorm):
            bias = self.bias[:feature_dim]
        else:
            bias = None
+        self.cur_config = {'prune_dim': feature_dim}
        out, _, _ = core.ops.layer_norm(input, weight, bias, 'epsilon',
                                        self._epsilon, 'begin_norm_axis',
                                        begin_norm_axis)

@@ -1191,6 +1202,7 @@ class SuperEmbedding(nn.Embedding):
                                             padding_idx, sparse, weight_attr,
                                             name)
        self.candidate_config = candidate_config
+        self.cur_config = None
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self._embedding_dim

@@ -1216,6 +1228,7 @@ class SuperEmbedding(nn.Embedding):
            out_nc = self._embedding_dim
        weight = self.weight[:, :out_nc]
+        self.cur_config = {'prune_dim': list(weight.shape)}
        return F.embedding(
            input,
            weight=weight,
paddleslim/nas/ofa/ofa.py
@@ -27,11 +27,14 @@ else:
    from .layers import SuperConv2D, SuperLinear
    Layer = paddle.nn.Layer
    DataParallel = paddle.DataParallel

-from .layers_base import BaseBlock
+from .layers_base import BaseBlock, Block
from .utils.utils import search_idx
from ...common import get_logger
from ...core import GraphWrapper, dygraph2program
-from .get_sub_model import get_prune_params_config, prune_params, check_search_space, broadcast_search_space
+from .get_sub_model import check_search_space, broadcast_search_space
+from paddle.fluid import core
+from paddle.fluid.framework import Variable
+import numbers

_logger = get_logger(__name__, level=logging.INFO)
@@ -459,35 +462,41 @@ class OFA(OFABase):
    def search(self, eval_func, condition):
        pass

    def _export_sub_model_config(self, origin_model, config, input_shapes,
                                 input_dtypes):
        param2name = {}
        for name, sublayer in origin_model.named_sublayers():
            for param in sublayer.parameters(include_sublayers=False):
                if name.split('.')[-1] == 'fn':
                    ### if sublayer is Block, the name of the param.name has 'fn', the config always donnot have 'fn'
                    param2name[param.name] = name[:-3]
                else:
                    param2name[param.name] = name

    def _get_model_pruned_weight(self):
        program = dygraph2program(
            origin_model, inputs=input_shapes, dtypes=input_dtypes)
        graph = GraphWrapper(program)
        pruned_param = {}
        for l_name, sublayer in self.model.named_sublayers():
        same_config, _ = check_search_space(graph)
        if same_config != None:
            broadcast_search_space(same_config, param2name, config)
            if getattr(sublayer, 'cur_config', None) == None:
                continue
        origin_model_config = {}
        for name, sublayer in origin_model.named_sublayers():
            if isinstance(sublayer, BaseBlock):
                sublayer = sublayer.fn
            for param in sublayer.parameters(include_sublayers=False):
                if name in config.keys():
                    origin_model_config[param.name] = config[name]
            assert 'prune_dim' in sublayer.cur_config, 'The laycer {} do not have prune_dim in cur_config.'.format(l_name)
            prune_shape = sublayer.cur_config['prune_dim']
            for p_name, param in sublayer.named_parameters(include_sublayers=False):
                origin_param = param.value().get_tensor()
                param = np.array(origin_param).astype("float32")
                name = l_name + '.' + p_name
                if isinstance(prune_shape, list):
-        param_prune_config = get_prune_params_config(graph, origin_model_config)
-        return param_prune_config
                    if len(param.shape) == 4:
                        pruned_param[name] = param[:prune_shape[0], :prune_shape[1], :, :]
                    elif len(param.shape) == 2:
                        pruned_param[name] = param[:prune_shape[0], :prune_shape[1]]
                    else:
                        if isinstance(sublayer, SuperLinear):
                            pruned_param[name] = param[:prune_shape[1]]
                        else:
                            pruned_param[name] = param[:prune_shape[0]]
                else:
                    pruned_param[name] = param[:prune_shape]
        return pruned_param

    def export(self,
               config,
@@ -510,17 +519,72 @@ class OFA(OFABase):
              config = {'conv2d_0': {'expand_ratio': 2}, 'conv2d_1': {'expand_ratio': 2}}
              origin_model = ofa_model.export(origin_model, config, input_shapes=[1, 3, 28, 28], input_dtypes=['float32'])
        """
-        super_sd = None
        self.set_net_config(config)
        self.model.eval()

+        def build_input(input_size, dtypes):
+            if isinstance(input_size, list) and all(
+                    isinstance(i, numbers.Number) for i in input_size):
+                if isinstance(dtypes, list):
+                    dtype = dtypes[0]
+                else:
+                    dtype = dtypes
+                return paddle.cast(paddle.rand(list(input_size)), dtype)
+            if isinstance(input_size, dict):
+                inputs = {}
+                if isinstance(dtypes, list):
+                    dtype = dtypes[0]
+                else:
+                    dtype = dtypes
+                for key, value in input_size.items():
+                    inputs[key] = paddle.cast(paddle.rand(list(value)), dtype)
+                return inputs
+            if isinstance(input_size, list):
+                return [
+                    build_input(i, dtype)
+                    for i, dtype in zip(input_size, dtypes)
+                ]
+
+        data = build_input(input_shapes, input_dtypes)
+
+        if isinstance(data, list):
+            self.forward(*data)
+        else:
+            self.forward(data)
+
+        super_model_state_dict = None
        if load_weights_from_supernet and origin_model != None:
-            super_sd = remove_model_fn(origin_model, self.model.state_dict())
+            super_model_state_dict = remove_model_fn(origin_model,
+                                                     self.model.state_dict())

        if origin_model == None:
            origin_model = self.model
        origin_model = origin_model._layers if isinstance(
            origin_model, DataParallel) else origin_model

        param_config = self._export_sub_model_config(origin_model, config,
                                                     input_shapes, input_dtypes)
-        prune_params(origin_model, param_config, super_sd)
+        _logger.info("Start to get pruned params, please wait...")
+        pruned_param = self._get_model_pruned_weight()
+        pruned_state_dict = remove_model_fn(origin_model, pruned_param)
+        _logger.info("Start to get pruned model, please wait...")
+        for l_name, sublayer in origin_model.named_sublayers():
+            for p_name, param in sublayer.named_parameters(include_sublayers=False):
+                name = l_name + '.' + p_name
+                t_value = param.value().get_tensor()
+                if name in pruned_state_dict:
+                    p = t_value._place()
+                    if p.is_cpu_place():
+                        place = core.CPUPlace()
+                    elif p.is_cuda_pinned_place():
+                        place = core.CUDAPinnedPlace()
+                    else:
+                        place = core.CUDAPlace(p.gpu_device_id())
+                    t_value.set(pruned_state_dict[name], place)
+
+        if super_model_state_dict != None and len(super_model_state_dict) != 0:
+            for k, v in super_model_state_dict.items():
+                setattr(origin_model, k, v)

        return origin_model

    @property
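The `build_input` helper added above accepts a flat shape, a dict of shapes, or a list mixing both, paired with per-input dtypes. A standalone re-rendering of that helper's behavior for illustration; the concrete shapes and dtype pairing below are just an example (mirroring `TestInputDict` in tests/test_ofa_v2.py):

import numbers
import paddle

def build_input(input_size, dtypes):
    # A flat shape such as [1, 3, 32, 32] becomes one random tensor.
    if isinstance(input_size, list) and all(
            isinstance(i, numbers.Number) for i in input_size):
        dtype = dtypes[0] if isinstance(dtypes, list) else dtypes
        return paddle.cast(paddle.rand(list(input_size)), dtype)
    # A dict of shapes becomes a dict of random tensors sharing one dtype.
    if isinstance(input_size, dict):
        dtype = dtypes[0] if isinstance(dtypes, list) else dtypes
        return {
            key: paddle.cast(paddle.rand(list(value)), dtype)
            for key, value in input_size.items()
        }
    # A list mixing both forms recurses element-wise, pairing shapes with dtypes.
    if isinstance(input_size, list):
        return [build_input(i, dtype) for i, dtype in zip(input_size, dtypes)]

# e.g. one positional image input plus a dict input named 'data'.
data = build_input([[2, 3, 32, 32], {'data': [2, 12, 32, 32]}],
                   ['float32', 'float32'])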
@@ -566,11 +630,26 @@ class OFA(OFABase):
        input_shapes = []
        input_dtypes = []
        for n in inputs:
-            input_shapes.append(n.shape)
-            input_dtypes.append(n.numpy().dtype)
-        for n, v in kwargs.items():
-            input_shapes.append(v.shape)
-            input_dtypes.append(v.numpy().dtype)
+            if isinstance(n, Variable):
+                input_shapes.append(n)
+                input_dtypes.append(n.numpy().dtype)
+        for key, val in kwargs.items():
+            if isinstance(val, Variable):
+                input_shapes.append(val)
+                input_dtypes.append(val.numpy().dtype)
+            elif isinstance(val, dict):
+                input_shape = {}
+                input_dtype = {}
+                for k, v in val.items():
+                    input_shape[k] = v
+                    input_dtype[k] = v.numpy().dtype
+                input_shapes.append(input_shape)
+                input_dtypes.append(input_dtype)
+            else:
+                _logger.error(
+                    "Cannot figure out the type of inputs! Right now, the type of inputs can be only Variable or dict."
+                )

        ### find shortcut block using static model
        model_to_traverse = self.model._layers if isinstance(

@@ -674,11 +753,9 @@ class OFA(OFABase):
        _logger.debug("Current config is {}".format(self.current_config))
        if 'depth' in self.current_config:
            kwargs['depth'] = self.current_config['depth']

        if self._broadcast:
            broadcast_search_space(self._same_ss, self._param2key,
                                   self.current_config)

        student_output = self.model.forward(*inputs, **kwargs)

        if self._add_teacher:
tests/test_ofa.py
@@ -142,10 +142,19 @@ class ModelConv2(nn.Layer):
class ModelLinear(nn.Layer):
    def __init__(self):
        super(ModelLinear, self).__init__()
-        with supernet(expand_ratio=(1.0, 2.0, 4.0)) as ofa_super:
+        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
            models = []
            models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
            models += [nn.Linear(64, 128)]
+            weight_attr = paddle.ParamAttr(
+                learning_rate=0.5,
+                regularizer=paddle.regularizer.L2Decay(1.0),
+                trainable=True)
+            bias_attr = paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=1.0))
+            models += [
+                nn.Linear(64, 128, weight_attr=weight_attr, bias_attr=bias_attr)
+            ]
            models += [nn.LayerNorm(128)]
            models += [nn.Linear(128, 256)]
            models = ofa_super.convert(models)
@@ -402,16 +411,7 @@ class TestExport(unittest.TestCase):
        self.ofa_model = OFA(model)

    def test_ofa(self):
-        config = {
-            'embedding_1': {
-                'expand_ratio': (2.0)
-            },
-            'linear_3': {
-                'expand_ratio': (2.0)
-            },
-            'linear_4': {},
-            'linear_5': {}
-        }
+        config = self.ofa_model._sample_config(task='expand_ratio', phase=None)
        origin_dict = {}
        for name, param in self.origin_model.named_parameters():
            origin_dict[name] = param.shape
@@ -459,9 +459,29 @@ class TestExportCase1(unittest.TestCase):
        outs, _ = self.ofa_model(self.data)
        self.config = self.ofa_model.current_config

-    def test_export_model(self):
-        self.ofa_model.export(
+    def test_export_model_linear1(self):
+        ex_model = self.ofa_model.export(
            self.config, input_shapes=[[3, 64]], input_dtypes=['int64'])
+        ex_model(self.data)
+        assert len(self.ofa_model.ofa_layers) == 4
+
+
+class TestExportCase2(unittest.TestCase):
+    def setUp(self):
+        model = ModelLinear()
+        data_np = np.random.random((3, 64)).astype(np.int64)
+        self.data = paddle.to_tensor(data_np)
+        self.ofa_model = OFA(model)
+        self.ofa_model.set_epoch(0)
+        outs, _ = self.ofa_model(self.data)
+        self.config = self.ofa_model.current_config
+
+    def test_export_model_linear2(self):
+        config = self.ofa_model._sample_config(
+            task='expand_ratio', phase=None, sample_type='smallest')
+        ex_model = self.ofa_model.export(
+            config, input_shapes=[[3, 64]], input_dtypes=['int64'])
+        ex_model(self.data)
+        assert len(self.ofa_model.ofa_layers) == 4
tests/test_ofa_v2.py
@@ -81,6 +81,25 @@ class ModelShortcut(nn.Layer):
        return z


+class ModelInputDict(nn.Layer):
+    def __init__(self):
+        super(ModelInputDict, self).__init__()
+        self.conv0 = nn.Sequential(
+            nn.Conv2D(3, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.conv1 = nn.Sequential(
+            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.conv2 = nn.Sequential(
+            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+        self.conv3 = nn.Sequential(
+            nn.Conv2D(12, 12, 1), nn.BatchNorm2D(12), nn.ReLU())
+
+    def forward(self, x, data):
+        x = self.conv1(self.conv0(x))
+        y = self.conv2(x)
+        y = y + data['data']
+        return self.conv3(y)
+
+
class TestOFAV2(unittest.TestCase):
    def setUp(self):
        model = ModelV1()
@@ -93,7 +112,6 @@ class TestOFAV2(unittest.TestCase):
        self.ofa_model.set_epoch(0)
        self.ofa_model.set_task('expand_ratio')
        out, _ = self.ofa_model(self.images)
        print(self.ofa_model.get_current_config)


class TestOFAV2Export(unittest.TestCase):
@@ -151,5 +169,34 @@ class TestShortcutSkiplayersCase2(TestShortcutSkiplayers):
        assert list(self.ofa_model._ofa_layers.keys()) == ['conv1.0', 'out.0']


+class TestInputDict(unittest.TestCase):
+    def setUp(self):
+        model = ModelInputDict()
+
+        sp_net_config = supernet(expand_ratio=[0.5, 1.0])
+        self.model = Convert(sp_net_config).convert(model)
+        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
+        self.images2 = {
+            'data': paddle.randn(shape=[2, 12, 32, 32], dtype='float32')
+        }
+        default_run_config = {'skip_layers': ['conv1.0', 'conv2.0']}
+        self.run_config = RunConfig(**default_run_config)
+
+        self.ofa_model = OFA(self.model, run_config=self.run_config)
+        self.ofa_model._clear_search_space(self.images, data=self.images2)
+
+    def test_export(self):
+        config = self.ofa_model._sample_config(
+            task="expand_ratio", sample_type="smallest")
+        self.ofa_model.export(
+            config,
+            input_shapes=[[1, 3, 32, 32], {
+                'data': [1, 12, 32, 32]
+            }],
+            input_dtypes=['float32', 'float32'])
+
+
if __name__ == '__main__':
    unittest.main()