PaddlePaddle / PaddleSlim
Commit 40e0684a
Authored on Dec 08, 2020 by ceci3; committed via GitHub on Dec 08, 2020.
Refine ofa (#527)
* support 2.0
Parent: f3b898c8
Showing 7 changed files with 1426 additions and 203 deletions (+1426, -203).
paddleslim/nas/ofa/__init__.py        +7    -1
paddleslim/nas/ofa/convert_super.py   +301  -145
paddleslim/nas/ofa/layers_new.py      +966  -0
paddleslim/nas/ofa/ofa.py             +16   -10
paddleslim/nas/ofa/utils/utils.py     +10   -0
tests/test_convert_supernet.py        +33   -0
tests/test_ofa.py                     +93   -47
paddleslim/nas/ofa/__init__.py (view file @ 40e0684a)
@@ -14,4 +14,10 @@
 from .ofa import OFA, RunConfig, DistillConfig
 from .convert_super import supernet
-from .layers import *
+from .utils.utils import get_paddle_version
+
+pd_ver = get_paddle_version()
+if pd_ver == 185:
+    from .layers import *
+else:
+    from .layers_new import *
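The dispatch above makes downstream imports version-agnostic: `get_paddle_version()` reports 185 on Paddle 1.8.x, so the package backs its public names with `layers` there and with `layers_new` on Paddle 2.0. A minimal sketch of what user code sees (module paths are taken from this diff; nothing beyond them is assumed):

    # Same imports work on Paddle 1.8 and 2.0; the backing layer module is chosen internally.
    from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig, supernet
    from paddleslim.nas.ofa.utils.utils import get_paddle_version

    pd_ver = get_paddle_version()  # 185 on Paddle 1.8.x, per the check above
    print('backed by', 'layers' if pd_ver == 185 else 'layers_new')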
paddleslim/nas/ofa/convert_super.py (view file @ 40e0684a)
@@ -15,11 +15,23 @@
 import inspect
 import decorator
 import logging
+import paddle
 import numbers
-from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm, Embedding
-from .layers import *
-import paddle
 from ...common import get_logger
+from .utils.utils import get_paddle_version
+
+pd_ver = get_paddle_version()
+if pd_ver == 185:
+    import paddle.fluid.dygraph.nn as nn
+    from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
+    from .layers import *
+    from . import layers
+    Layer = fluid.dygraph.nn.Layer
+else:
+    import paddle.nn as nn
+    from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
+    from .layers_new import *
+    from . import layers_new as layers
+    Layer = paddle.nn.Layer

 _logger = get_logger(__name__, level=logging.INFO)
@@ -28,19 +40,25 @@ __all__ = ['supernet']

 WEIGHT_LAYER = ['conv', 'linear', 'embedding']


 ### TODO: add decorator
 class Convert:
     def __init__(self, context):
         self.context = context

-    def convert(self, model):
+    def convert(self, network):
         # search the first and last weight layer, don't change out channel of the last weight layer
         # don't change in channel of the first weight layer
+        model = []
+        if isinstance(network, Layer):
+            for name, sublayer in network.named_sublayers():
+                model.append(sublayer)
+        else:
+            model = network
+
         first_weight_layer_idx = -1
         last_weight_layer_idx = -1
         weight_layer_count = 0
         # NOTE: pre_channel store for shortcut module
-        pre_channel = 0
+        pre_channel = None
+        cur_channel = None
         for idx, layer in enumerate(model):
             cls_name = layer.__class__.__name__.lower()
@@ -61,50 +79,68 @@ class Convert:
                 key = attr_dict['_full_name']
-                new_attr_name = [
-                    '_stride', '_dilation', '_groups', '_param_attr', '_bias_attr',
-                    '_use_cudnn', '_act', '_dtype', '_padding'
-                ]
-                new_attr_dict = dict()
+                new_attr_name = ['stride', 'padding', 'dilation', 'groups', 'bias_attr']
+                if pd_ver == 185:
+                    new_attr_name += ['param_attr', 'use_cudnn', 'act', 'dtype']
+                else:
+                    new_attr_name += ['weight_attr', 'data_format', 'padding_mode']
+
+                new_attr_dict = dict.fromkeys(new_attr_name, None)
                 new_attr_dict['candidate_config'] = dict()
+                if pd_ver == 185:
+                    new_attr_dict['num_channels'] = None
+                    new_attr_dict['num_filters'] = None
+                    new_attr_dict['filter_size'] = None
+                else:
+                    new_attr_dict['in_channels'] = None
+                    new_attr_dict['out_channels'] = None
+                    new_attr_dict['kernel_size'] = None
+
                 self.kernel_size = getattr(self.context, 'kernel_size', None)
                 if self.kernel_size != None:
                     new_attr_dict['transform_kernel'] = True

                 # if the kernel_size of conv is 1, don't change it.
                 #if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1:
-                if self.kernel_size and int(attr_dict['_filter_size']) != 1:
-                    new_attr_dict['filter_size'] = max(self.kernel_size)
+                fks = '_filter_size' if '_filter_size' in attr_dict.keys() else '_kernel_size'
+                ks = list(attr_dict[fks]) if isinstance(attr_dict[fks], numbers.Integral) else attr_dict[fks]
+                if self.kernel_size and int(ks[0]) != 1:
+                    new_attr_dict['transform_kernel'] = True
+                    new_attr_dict[fks[1:]] = max(self.kernel_size)
                     new_attr_dict['candidate_config'].update({'kernel_size': self.kernel_size})
                 else:
-                    new_attr_dict['filter_size'] = attr_dict['_filter_size']
+                    new_attr_dict[fks[1:]] = attr_dict[fks]
+
+                in_key = '_num_channels' if '_num_channels' in attr_dict.keys() else '_in_channels'
+                out_key = '_num_filters' if '_num_filters' in attr_dict.keys() else '_out_channels'

                 if self.context.expand:
                     ### first super convolution
                     if idx == first_weight_layer_idx:
-                        new_attr_dict['num_channels'] = attr_dict['_num_channels']
+                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                     else:
-                        new_attr_dict['num_channels'] = self.context.expand * attr_dict['_num_channels']
+                        new_attr_dict[in_key[1:]] = int(self.context.expand * attr_dict[in_key])
                     ### last super convolution
                     if idx == last_weight_layer_idx:
-                        new_attr_dict['num_filters'] = attr_dict['_num_filters']
+                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                     else:
-                        new_attr_dict['num_filters'] = self.context.expand * attr_dict['_num_filters']
+                        new_attr_dict[out_key[1:]] = int(self.context.expand * attr_dict[out_key])
                     new_attr_dict['candidate_config'].update({'expand_ratio': self.context.expand_ratio})
                 elif self.context.channel:
                     if attr_dict['_groups'] != None and (
-                            int(attr_dict['_groups']) == int(attr_dict['_num_channels'])):
+                            int(attr_dict['_groups']) == int(attr_dict[in_key])):
                         ### depthwise conv, if conv is depthwise, use pre channel as cur_channel
                         _logger.warn("If convolution is a depthwise conv, output channel change" \
@@ -115,25 +151,27 @@ class Convert:
                     cur_channel = self.context.channel[0]
                     self.context.channel = self.context.channel[1:]
                     if idx == first_weight_layer_idx:
-                        new_attr_dict['num_channels'] = attr_dict['_num_channels']
+                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                     else:
-                        new_attr_dict['num_channels'] = max(pre_channel)
+                        new_attr_dict[in_key[1:]] = max(pre_channel)

                     if idx == last_weight_layer_idx:
-                        new_attr_dict['num_filters'] = attr_dict['_num_filters']
+                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                     else:
-                        new_attr_dict['num_filters'] = max(cur_channel)
+                        new_attr_dict[out_key[1:]] = max(cur_channel)
                     new_attr_dict['candidate_config'].update({'channel': cur_channel})
                     pre_channel = cur_channel
                 else:
-                    new_attr_dict['num_filters'] = attr_dict['_num_filters']
-                    new_attr_dict['num_channels'] = attr_dict['_num_channels']
+                    new_attr_dict[in_key[1:]] = attr_dict[in_key]
+                    new_attr_dict[out_key[1:]] = attr_dict[out_key]

                 for attr in new_attr_name:
-                    new_attr_dict[attr[1:]] = attr_dict[attr]
+                    if attr == 'weight_attr':
+                        new_attr_dict[attr] = attr_dict['_param_attr']
+                    else:
+                        new_attr_dict[attr] = attr_dict['_' + attr]

                 del layer
@@ -141,17 +179,15 @@ class Convert:
                         '_groups']) == 1:
                     ### standard conv
                     layer = Block(SuperConv2D(**new_attr_dict), key=key)
-                elif int(attr_dict['_groups']) == int(attr_dict['_num_channels']):
+                elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
                     # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
                     # channel in candidate_config = in_channel_list
                     if 'channel' in new_attr_dict['candidate_config']:
-                        new_attr_dict['num_channels'] = max(cur_channel)
-                        new_attr_dict['num_filters'] = new_attr_dict['num_channels']
+                        new_attr_dict[in_key[1:]] = max(cur_channel)
+                        new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
                         new_attr_dict['candidate_config']['channel'] = cur_channel
-                    new_attr_dict['groups'] = new_attr_dict['num_channels']
+                    new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
                     layer = Block(SuperDepthwiseConv2D(**new_attr_dict), key=key)
                 else:
@@ -159,7 +195,8 @@ class Convert:
                     layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
                 model[idx] = layer
-            elif isinstance(layer, BatchNorm) and (
+            elif isinstance(layer, getattr(nn, 'BatchNorm2D', nn.BatchNorm)) and (
                     getattr(self.context, 'expand', None) != None or
                     getattr(self.context, 'channel', None) != None):
                 # num_features in BatchNorm don't change after last weight operators
@@ -167,26 +204,41 @@ class Convert:
                     continue
                 attr_dict = layer.__dict__
-                new_attr_name = [
-                    '_param_attr', '_bias_attr', '_act', '_dtype', '_in_place',
-                    '_data_layout', '_momentum', '_epsilon', '_is_test',
-                    '_use_global_stats', '_trainable_statistics'
-                ]
-                new_attr_dict = dict()
+                new_attr_name = ['momentum', 'epsilon', 'bias_attr']
+                if pd_ver == 185:
+                    new_attr_name += [
+                        'param_attr', 'act', 'dtype', 'in_place', 'data_layout',
+                        'is_test', 'use_global_stats', 'trainable_statistics'
+                    ]
+                else:
+                    new_attr_name += ['weight_attr', 'data_format', 'name']
+
+                new_attr_dict = dict.fromkeys(new_attr_name, None)
+                if pd_ver == 185:
+                    new_attr_dict['num_channels'] = None
+                else:
+                    new_attr_dict['num_features'] = None
+                new_key = 'num_channels' if 'num_channels' in new_attr_dict.keys() else 'num_features'
                 if self.context.expand:
-                    new_attr_dict['num_channels'] = self.context.expand * int(
-                        layer._parameters['weight'].shape[0])
+                    new_attr_dict[new_key] = int(self.context.expand *
+                                                 layer._parameters['weight'].shape[0])
                 elif self.context.channel:
-                    new_attr_dict['num_channels'] = max(cur_channel)
+                    new_attr_dict[new_key] = max(cur_channel)
                 else:
-                    new_attr_dict['num_channels'] = attr_dict['_num_channels']
+                    new_attr_dict[new_key] = attr_dict['_num_channels'] if '_num_channels' in attr_dict.keys() else attr_dict['_num_features']

                 for attr in new_attr_name:
-                    new_attr_dict[attr[1:]] = attr_dict[attr]
+                    new_attr_dict[attr] = attr_dict['_' + attr]

                 del layer, attr_dict

-                layer = SuperBatchNorm(**new_attr_dict)
+                layer = getattr(layers, 'SuperBatchNorm', SuperBatchNorm2D)(**new_attr_dict)
                 model[idx] = layer

             ### assume output_size = None, filter_size != None
@@ -196,52 +248,72 @@ class Convert:
                 key = attr_dict['_full_name']
-                new_attr_name = [
-                    '_stride', '_dilation', '_groups', '_param_attr', '_padding',
-                    '_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
-                ]
-                assert attr_dict['_filter_size'] != None, "Conv2DTranspose only support filter size != None now"
-                new_attr_dict = dict()
+                new_attr_name = ['stride', 'padding', 'dilation', 'groups', 'bias_attr']
+                assert getattr(attr_dict, '_filter_size', '_kernel_size') != None, "Conv2DTranspose only support kernel size != None now"
+                if pd_ver == 185:
+                    new_attr_name += ['output_size', 'param_attr', 'use_cudnn', 'act', 'dtype']
+                else:
+                    new_attr_name += ['output_padding', 'weight_attr', 'data_format']
+
+                new_attr_dict = dict.fromkeys(new_attr_name, None)
                 new_attr_dict['candidate_config'] = dict()
+                if pd_ver == 185:
+                    new_attr_dict['num_channels'] = None
+                    new_attr_dict['num_filters'] = None
+                    new_attr_dict['filter_size'] = None
+                else:
+                    new_attr_dict['in_channels'] = None
+                    new_attr_dict['out_channels'] = None
+                    new_attr_dict['kernel_size'] = None
+
                 self.kernel_size = getattr(self.context, 'kernel_size', None)
                 if self.kernel_size != None:
                     new_attr_dict['transform_kernel'] = True

                 # if the kernel_size of conv transpose is 1, don't change it.
-                if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1:
-                    new_attr_dict['filter_size'] = max(self.kernel_size)
+                fks = '_filter_size' if '_filter_size' in attr_dict.keys() else '_kernel_size'
+                ks = list(attr_dict[fks]) if isinstance(attr_dict[fks], numbers.Integral) else attr_dict[fks]
+                if self.kernel_size and int(ks[0]) != 1:
+                    new_attr_dict['transform_kernel'] = True
+                    new_attr_dict[fks[1:]] = max(self.kernel_size)
                     new_attr_dict['candidate_config'].update({'kernel_size': self.kernel_size})
                 else:
-                    new_attr_dict['filter_size'] = attr_dict['_filter_size']
+                    new_attr_dict[fks[1:]] = attr_dict[fks]
+
+                in_key = '_num_channels' if '_num_channels' in attr_dict.keys() else '_in_channels'
+                out_key = '_num_filters' if '_num_filters' in attr_dict.keys() else '_out_channels'

                 if self.context.expand:
                     ### first super convolution transpose
                     if idx == first_weight_layer_idx:
-                        new_attr_dict['num_channels'] = attr_dict['_num_channels']
+                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                     else:
-                        new_attr_dict['num_channels'] = self.context.expand * attr_dict['_num_channels']
+                        new_attr_dict[in_key[1:]] = int(self.context.expand * attr_dict[in_key])
                     ### last super convolution transpose
                     if idx == last_weight_layer_idx:
-                        new_attr_dict['num_filters'] = attr_dict['_num_filters']
+                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                     else:
-                        new_attr_dict['num_filters'] = self.context.expand * attr_dict['_num_filters']
+                        new_attr_dict[out_key[1:]] = int(self.context.expand * attr_dict[out_key])
                     new_attr_dict['candidate_config'].update({'expand_ratio': self.context.expand_ratio})
                 elif self.context.channel:
                     if attr_dict['_groups'] != None and (
-                            int(attr_dict['_groups']) == int(attr_dict['_num_channels'])):
+                            int(attr_dict['_groups']) == int(attr_dict[in_key])):
                         ### depthwise conv_transpose
                         _logger.warn("If convolution is a depthwise conv_transpose, output channel " \
@@ -252,29 +324,33 @@ class Convert:
                     cur_channel = self.context.channel[0]
                     self.context.channel = self.context.channel[1:]
                     if idx == first_weight_layer_idx:
-                        new_attr_dict['num_channels'] = attr_dict['_num_channels']
+                        new_attr_dict[in_key[1:]] = attr_dict[in_key]
                     else:
-                        new_attr_dict['num_channels'] = max(pre_channel)
+                        new_attr_dict[in_key[1:]] = max(pre_channel)

                     if idx == last_weight_layer_idx:
-                        new_attr_dict['num_filters'] = attr_dict['_num_filters']
+                        new_attr_dict[out_key[1:]] = attr_dict[out_key]
                     else:
-                        new_attr_dict['num_filters'] = max(cur_channel)
+                        new_attr_dict[out_key[1:]] = max(cur_channel)
                     new_attr_dict['candidate_config'].update({'channel': cur_channel})
                     pre_channel = cur_channel
                 else:
-                    new_attr_dict['num_filters'] = attr_dict['_num_filters']
-                    new_attr_dict['num_channels'] = attr_dict['_num_channels']
+                    new_attr_dict[in_key[1:]] = attr_dict[in_key]
+                    new_attr_dict[out_key[1:]] = attr_dict[out_key]

                 for attr in new_attr_name:
-                    new_attr_dict[attr[1:]] = attr_dict[attr]
+                    if attr == 'weight_attr':
+                        new_attr_dict[attr] = attr_dict['_param_attr']
+                    elif attr == 'output_padding':
+                        new_attr_dict[attr] = attr_dict[attr]
+                    else:
+                        new_attr_dict[attr] = attr_dict['_' + attr]

                 del layer

-                if new_attr_dict['output_size'] == []:
+                if getattr(new_attr_dict, 'output_size', None) == []:
                     new_attr_dict['output_size'] = None
                 if attr_dict['_groups'] == None or int(attr_dict[
@@ -282,17 +358,15 @@ class Convert:
                     ### standard conv_transpose
                     layer = Block(SuperConv2DTranspose(**new_attr_dict), key=key)
-                elif int(attr_dict['_groups']) == int(attr_dict['_num_channels']):
+                elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
                     # if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
                     # channel in candidate_config = in_channel_list
                     if 'channel' in new_attr_dict['candidate_config']:
-                        new_attr_dict['num_channels'] = max(cur_channel)
-                        new_attr_dict['num_filters'] = new_attr_dict['num_channels']
+                        new_attr_dict[in_key[1:]] = max(cur_channel)
+                        new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
                         new_attr_dict['candidate_config']['channel'] = cur_channel
-                    new_attr_dict['groups'] = new_attr_dict['num_channels']
+                    new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
                     layer = Block(SuperDepthwiseConv2DTranspose(**new_attr_dict), key=key)
                 else:
@@ -306,25 +380,39 @@ class Convert:
                     getattr(self.context, 'channel', None) != None):
                 attr_dict = layer.__dict__
                 key = attr_dict['_full_name']
                 ### TODO(paddle): add _param_attr and _bias_attr as private variable of Linear
                 #new_attr_name = ['_act', '_dtype', '_param_attr', '_bias_attr']
-                new_attr_name = ['_act', '_dtype']
+                if pd_ver == 185:
+                    new_attr_name = ['param_attr', 'bias_attr', 'act', 'dtype']
+                else:
+                    new_attr_name = ['weight_attr', 'bias_attr']

                 in_nc, out_nc = layer._parameters['weight'].shape

-                new_attr_dict = dict()
+                new_attr_dict = dict.fromkeys(new_attr_name, None)
                 new_attr_dict['candidate_config'] = dict()
+                if pd_ver == 185:
+                    new_attr_dict['input_dim'] = None
+                    new_attr_dict['output_dim'] = None
+                else:
+                    new_attr_dict['in_features'] = None
+                    new_attr_dict['out_features'] = None
+
+                in_key = '_input_dim' if '_input_dim' in attr_dict.keys() else '_in_features'
+                out_key = '_output_dim' if '_output_dim' in attr_dict.keys() else '_out_features'
+                attr_dict[in_key] = in_nc
+                attr_dict[out_key] = out_nc
                 if self.context.expand:
                     if idx == first_weight_layer_idx:
-                        new_attr_dict['input_dim'] = int(in_nc)
+                        new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                     else:
-                        new_attr_dict['input_dim'] = self.context.expand * int(in_nc)
+                        new_attr_dict[in_key[1:]] = int(self.context.expand * attr_dict[in_key])

                     if idx == last_weight_layer_idx:
-                        new_attr_dict['output_dim'] = int(out_nc)
+                        new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
                     else:
-                        new_attr_dict['output_dim'] = self.context.expand * int(out_nc)
+                        new_attr_dict[out_key[1:]] = int(self.context.expand * attr_dict[out_key])
                     new_attr_dict['candidate_config'].update({'expand_ratio': self.context.expand_ratio})
@@ -332,31 +420,34 @@ class Convert:
                     cur_channel = self.context.channel[0]
                     self.context.channel = self.context.channel[1:]
                     if idx == first_weight_layer_idx:
-                        new_attr_dict['input_dim'] = int(in_nc)
+                        new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
                     else:
-                        new_attr_dict['input_dim'] = max(pre_channel)
+                        new_attr_dict[in_key[1:]] = max(pre_channel)

                     if idx == last_weight_layer_idx:
-                        new_attr_dict['output_dim'] = int(out_nc)
+                        new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
                     else:
-                        new_attr_dict['output_dim'] = max(cur_channel)
+                        new_attr_dict[out_key[1:]] = max(cur_channel)
                     new_attr_dict['candidate_config'].update({'channel': cur_channel})
                     pre_channel = cur_channel
                 else:
-                    new_attr_dict['input_dim'] = int(in_nc)
-                    new_attr_dict['output_dim'] = int(out_nc)
+                    new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
+                    new_attr_dict[out_key[1:]] = int(attr_dict[out_key])

                 for attr in new_attr_name:
-                    new_attr_dict[attr[1:]] = attr_dict[attr]
+                    new_attr_dict[attr] = attr_dict['_' + attr]

                 del layer, attr_dict

                 layer = Block(SuperLinear(**new_attr_dict), key=key)
                 model[idx] = layer
-            elif isinstance(layer, InstanceNorm) and (
+            elif isinstance(layer, getattr(nn, 'InstanceNorm2D', paddle.fluid.dygraph.nn.InstanceNorm)) and (
                     getattr(self.context, 'expand', None) != None or
                     getattr(self.context, 'channel', None) != None):
                 # num_features in InstanceNorm don't change after last weight operators
@@ -364,24 +455,38 @@ class Convert:
                     continue
                 attr_dict = layer.__dict__
-                new_attr_name = ['_param_attr', '_bias_attr', '_dtype', '_epsilon']
-                new_attr_dict = dict()
+                if pd_ver == 185:
+                    new_attr_name = ['bias_attr', 'epsilon', 'param_attr', 'dtype']
+                else:
+                    new_attr_name = ['bias_attr', 'epsilon', 'weight_attr']
+                new_attr_dict = dict.fromkeys(new_attr_name, None)
+                if pd_ver == 185:
+                    new_attr_dict['num_channels'] = None
+                else:
+                    new_attr_dict['num_features'] = None
+                new_key = '_num_channels' if '_num_channels' in new_attr_dict.keys() else '_num_features'
+                ### 10 is a default channel in the case of weight_attr=False, in this condition, num of channels if useless, so give it arbitrarily.
+                attr_dict[new_key] = layer._parameters['scale'].shape[0] if len(layer._parameters) != 0 else 10
                 if self.context.expand:
-                    new_attr_dict['num_channels'] = self.context.expand * int(
-                        layer._parameters['scale'].shape[0])
+                    new_attr_dict[new_key[1:]] = int(self.context.expand * attr_dict[new_key])
                 elif self.context.channel:
-                    new_attr_dict['num_channels'] = max(cur_channel)
+                    new_attr_dict[new_key[1:]] = max(cur_channel)
                 else:
-                    new_attr_dict['num_channels'] = attr_dict['_num_channels']
+                    new_attr_dict[new_key[1:]] = attr_dict[new_key]

                 for attr in new_attr_name:
-                    new_attr_dict[attr[1:]] = attr_dict[attr]
+                    new_attr_dict[attr] = attr_dict['_' + attr]

                 del layer, attr_dict

-                layer = SuperInstanceNorm(**new_attr_dict)
+                layer = getattr(layers, 'SuperInstanceNorm2D', 'SuperInstanceNorm')(**new_attr_dict)
                 model[idx] = layer
             elif isinstance(layer, LayerNorm) and (
@@ -392,15 +497,19 @@ class Convert:
                     continue
                 attr_dict = layer.__dict__
-                new_attr_name = [
-                    '_scale', '_shift', '_param_attr', '_bias_attr', '_act',
-                    '_dtype', '_epsilon'
-                ]
-                new_attr_dict = dict()
+                new_attr_name = ['epsilon', 'bias_attr']
+                if pd_ver == 185:
+                    new_attr_name += ['scale', 'shift', 'param_attr', 'act', 'dtype']
+                else:
+                    new_attr_name += ['weight_attr']
+                new_attr_dict = dict.fromkeys(new_attr_name, None)
+                new_attr_dict['normalized_shape'] = None
                 if self.context.expand:
-                    new_attr_dict['normalized_shape'] = self.context.expand * int(
-                        attr_dict['_normalized_shape'][0])
+                    new_attr_dict['normalized_shape'] = int(self.context.expand * attr_dict['_normalized_shape'][0])
                 elif self.context.channel:
                     new_attr_dict['normalized_shape'] = max(cur_channel)
                 else:
@@ -408,7 +517,7 @@ class Convert:
                         '_normalized_shape']
                 for attr in new_attr_name:
-                    new_attr_dict[attr[1:]] = attr_dict[attr]
+                    new_attr_dict[attr] = attr_dict['_' + attr]

                 del layer, attr_dict

                 layer = SuperLayerNorm(**new_attr_dict)
@@ -419,18 +528,32 @@ class Convert:
                     getattr(self.context, 'channel', None) != None):
                 attr_dict = layer.__dict__
                 key = attr_dict['_full_name']
-                new_attr_name = [
-                    '_is_sparse', '_is_distributed', '_padding_idx', '_param_attr', '_dtype'
-                ]
-                new_attr_dict = dict()
+                new_attr_name = ['padding_idx', ]
+                if pd_ver == 185:
+                    new_attr_name += ['size', 'is_sparse', 'is_distributed', 'param_attr', 'dtype']
+                else:
+                    new_attr_name += ['num_embeddings', 'embedding_dim', 'sparse', 'weight_attr', 'name']
+
+                new_attr_dict = dict.fromkeys(new_attr_name, None)
                 new_attr_dict['candidate_config'] = dict()
                 bef_size = attr_dict['_size']
                 if self.context.expand:
-                    new_attr_dict['size'] = [bef_size[0], self.context.expand * bef_size[1]]
+                    if pd_ver == 185:
+                        new_attr_dict['size'] = [bef_size[0], int(self.context.expand * bef_size[1])]
+                    else:
+                        new_attr_dict['num_embeddings'] = attr_dict['_num_embeddings']
+                        new_attr_dict['embedding_dim'] = int(self.context.expand * attr_dict['_embedding_dim'])
                     new_attr_dict['candidate_config'].update({'expand_ratio': self.context.expand_ratio})
@@ -438,23 +561,52 @@ class Convert:
                 elif self.context.channel:
                     cur_channel = self.context.channel[0]
                     self.context.channel = self.context.channel[1:]
+                    if pd_ver == 185:
                         new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
+                    else:
+                        new_attr_dict['num_embeddings'] = attr_dict['_num_embeddings']
+                        new_attr_dict['embedding_dim'] = max(cur_channel)
                     new_attr_dict['candidate_config'].update({'channel': cur_channel})
                     pre_channel = cur_channel
                 else:
-                    new_attr_dict['size'] = bef_size
+                    if pf_ver == 185:
+                        new_attr_dict['size'] = bef_size
+                    else:
+                        new_attr_dict['num_embeddings'] = attr_dict['_num_embeddings']
+                        new_attr_dict['embedding_dim'] = attr_dict['_embedding_dim']

                 for attr in new_attr_name:
-                    new_attr_dict[attr[1:]] = attr_dict[attr]
+                    new_attr_dict[attr] = attr_dict['_' + attr]

                 del layer, attr_dict

                 layer = Block(SuperEmbedding(**new_attr_dict), key=key)
                 model[idx] = layer

-        return model
+        def split_prefix(net, name_list):
+            if len(name_list) > 1:
+                net = split_prefix(getattr(net, name_list[0]), name_list[1:])
+            elif len(name_list) == 1:
+                net = getattr(net, name_list[0])
+            else:
+                raise NotImplementedError("name error")
+            return net
+
+        if isinstance(network, Layer):
+            for idx, (name, sublayer) in enumerate(network.named_sublayers()):
+                if len(name.split('.')) > 1:
+                    net = split_prefix(network, name.split('.')[:-1])
+                else:
+                    net = network
+                setattr(net, name.split('.')[-1], model[idx])
+        return network


 class supernet:
@@ -474,12 +626,16 @@ class supernet:
                 self.expand = max(self.expand_ratio)
             elif isinstance(self.expand_ratio, int):
                 self.expand = self.expand_ratio
         if 'channel' not in kwargs.keys():
             self.channel = None

     def __enter__(self):
         return Convert(self)

     def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
+        self.expand = None
+        self.channel = None
+        self.kernel_size = None


 #def ofa_supernet(kernel_size, expand_ratio):
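With this change `Convert.convert` accepts a whole `paddle.nn.Layer` (not only a list of sublayers) and writes the converted super layers back into the network via `split_prefix`/`setattr`. A hedged usage sketch of the context-manager API defined above; the MobileNet import is an illustrative assumption and not part of this commit:

    # Minimal sketch: wrap an existing dygraph model into an OFA supernet (Paddle 2.0 assumed).
    import paddle
    from paddle.vision.models import mobilenet_v1
    from paddleslim.nas.ofa import supernet

    model = mobilenet_v1()
    # supernet(...) carries the search space; __enter__ returns a Convert bound to it.
    with supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4]) as ofa_super:
        model = ofa_super.convert(model)  # sublayers are replaced by Super* layers in place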
paddleslim/nas/ofa/layers_new.py (new file, 0 → 100644, view file @ 40e0684a)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import logging

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.fluid.core as core
from ...common import get_logger
from .utils.utils import compute_start_end, get_same_padding, convert_to_list

__all__ = [
    'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
    'SuperBatchNorm2D', 'SuperLinear', 'SuperInstanceNorm2D', 'Block',
    'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
    'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
]

_logger = get_logger(__name__, level=logging.INFO)

### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock

_cnt = 0


def counter():
    global _cnt
    _cnt += 1
    return _cnt
class BaseBlock(paddle.nn.Layer):
    def __init__(self, key=None):
        super(BaseBlock, self).__init__()
        if key is not None:
            self._key = str(key)
        else:
            self._key = self.__class__.__name__ + str(counter())

    # set SuperNet class
    def set_supernet(self, supernet):
        self.__dict__['supernet'] = supernet

    @property
    def key(self):
        return self._key


class Block(BaseBlock):
"""
Model is composed of nest blocks.
Parameters:
fn(Layer): instance of super layers, such as: SuperConv2D(3, 5, 3).
key(str, optional): key of this layer, one-to-one correspondence between key and candidate config. Default: None.
"""
    def __init__(self, fn, fixed=False, key=None):
        super(Block, self).__init__(key)
        self.fn = fn
        self.fixed = fixed
        self.candidate_config = self.fn.candidate_config

    def forward(self, *inputs, **kwargs):
        out = self.supernet.layers_forward(self, *inputs, **kwargs)
        return out


class SuperConv2D(nn.Conv2D):
"""
This interface is used to construct a callable object of the ``SuperConv2D`` class.
The difference between ```SuperConv2D``` and ```Conv2D``` is: ```SuperConv2D``` need
to feed a config dictionary with the format of {'channel', num_of_channel} represents
the channels of the outputs, used to change the first dimension of weight and bias,
only train the first channels of the weight and bias.
Note: the channel in config need to less than first defined.
The super convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
the feature map, H is the height of the feature map, and W is the width of the feature map.
Filter's shape is [MCHW] , where M is the number of output feature map,
C is the number of input feature map, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input feature map divided by the groups.
Please refer to UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_
for more details.
If bias attribution and activation type are provided, bias is added to the
output of the convolution, and the corresponding activation function is
applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \\sigma (W \\ast X + b)
Where:
* :math:`X`: Input value, a ``Tensor`` with NCHW format.
* :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Parameters:
num_channels(int): The number of channels in the input image.
num_filters(int): The number of filter. It is as same as the output
feature map.
filter_size (int or tuple): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
candidate_config(dict, optional): Dictionary descripts candidate config of this layer,
such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, means the kernel size of
this layer can be choose from (3, 5, 7), the key of candidate_config
only can be 'kernel_size', 'channel' and 'expand_ratio', 'channel' and 'expand_ratio'
CANNOT be set at the same time. Default: None.
transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter
to a small filter. Default: False.
stride (int or tuple, optional): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: 1.
padding (int or tuple, optional): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: 0.
dilation (int or tuple, optional): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: 1.
groups (int, optional): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: 1.
param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str, optional): Activation type, if it is set to None, activation is not appended.
Default: None.
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Attribute:
**weight** (Parameter): the learnable weights of filter of this layer.
**bias** (Parameter or None): the learnable bias of this layer.
Returns:
None
Raises:
ValueError: if ``use_cudnn`` is not a bool value.
Examples:
.. code-block:: python
import paddle
from paddleslim.nas.ofa.layers import SuperConv2D
import numpy as np
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
super_conv2d = SuperConv2D(3, 10, 3)
config = {'channel': 5}
data = paddle.to_variable(data)
conv = super_conv2d(data, config)
"""
### NOTE: filter_size, num_channels and num_filters must be the max of candidate to define a largest network.
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 candidate_config={},
                 transform_kernel=False,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCHW'):
        super(SuperConv2D, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode,
            dilation=dilation,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format)

        self.candidate_config = candidate_config
        if len(candidate_config.items()) != 0:
            for k, v in candidate_config.items():
                candidate_config[k] = list(set(v))

        self.ks_set = candidate_config[
            'kernel_size'] if 'kernel_size' in candidate_config else None

        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.channel = candidate_config[
            'channel'] if 'channel' in candidate_config else None
        self.base_channel = self._out_channels
        if self.expand_ratio != None:
            self.base_channel = int(self._out_channels / max(self.expand_ratio))

        self.transform_kernel = transform_kernel
        if self.ks_set != None:
            self.ks_set.sort()
        if self.transform_kernel != False:
            scale_param = dict()
            ### create parameter to transform kernel
            for i in range(len(self.ks_set) - 1):
                ks_small = self.ks_set[i]
                ks_large = self.ks_set[i + 1]
                param_name = '%dto%d_matrix' % (ks_large, ks_small)
                ks_t = ks_small**2
                scale_param[param_name] = self.create_parameter(
                    attr=paddle.ParamAttr(
                        name=self._full_name + param_name,
                        initializer=nn.initializer.Assign(np.eye(ks_t))),
                    shape=(ks_t, ks_t),
                    dtype=self._dtype)

            for name, param in scale_param.items():
                setattr(self, name, param)

    def get_active_filter(self, in_nc, out_nc, kernel_size):
        start, end = compute_start_end(self._kernel_size[0], kernel_size)
        ### if NOT transform kernel, intercept a center filter with kernel_size from largest filter
        filters = self.weight[:out_nc, :in_nc, start:end, start:end]
        if self.transform_kernel != False and kernel_size < self._kernel_size[0]:
            ### if transform kernel, then use matrix to transform
            start_filter = self.weight[:out_nc, :in_nc, :, :]
            for i in range(len(self.ks_set) - 1, 0, -1):
                src_ks = self.ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self.ks_set[i - 1]
                start, end = compute_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = paddle.reshape(
                    _input_filter,
                    shape=[(_input_filter.shape[0] * _input_filter.shape[1]), -1])
                _input_filter = paddle.matmul(
                    _input_filter,
                    self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)), False,
                    False)
                _input_filter = paddle.reshape(
                    _input_filter,
                    shape=[filters.shape[0], filters.shape[1], target_ks, target_ks])
                start_filter = _input_filter
            filters = start_filter
        return filters

    def get_groups_in_out_nc(self, in_nc, out_nc):
        ### standard conv
        return self._groups, in_nc, out_nc

    def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
        self.cur_config = {
            'kernel_size': kernel_size,
            'expand_ratio': expand_ratio,
            'channel': channel
        }
        in_nc = int(input.shape[1])
        assert (
            expand_ratio == None or channel == None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio != None:
            out_nc = int(expand_ratio * self.base_channel)
        elif channel != None:
            out_nc = int(channel)
        else:
            out_nc = self._out_channels
        ks = int(self._kernel_size[0]) if kernel_size == None else int(kernel_size)

        groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, out_nc)

        weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
        if kernel_size != None or 'kernel_size' in self.candidate_config.keys():
            padding = convert_to_list(get_same_padding(ks), 2)
        else:
            padding = self._padding

        if self.bias is not None:
            bias = self.bias[:out_nc]
        else:
            bias = self.bias

        out = F.conv2d(
            input,
            weight,
            bias=bias,
            stride=self._stride,
            padding=padding,
            dilation=self._dilation,
            groups=self._groups,
            data_format=self._data_format)
        return out


class SuperGroupConv2D(SuperConv2D):
    def get_groups_in_out_nc(self, in_nc, out_nc):
        ### groups convolution
        ### conv: weight: (Cout, Cin/G, Kh, Kw)
        groups = self._groups
        in_nc = int(in_nc // groups)
        return groups, in_nc, out_nc


class SuperDepthwiseConv2D(SuperConv2D):
    ### depthwise convolution
    def get_groups_in_out_nc(self, in_nc, out_nc):
        if in_nc != out_nc:
            _logger.debug(
                "input channel and output channel in depthwise conv is different, change output channel to input channel! origin channel:(in_nc {}, out_nc {}): ".
                format(in_nc, out_nc))
        groups = in_nc
        out_nc = in_nc
        return groups, in_nc, out_nc


class SuperConv2DTranspose(nn.Conv2DTranspose):
"""
This interface is used to construct a callable object of the ``SuperConv2DTranspose``
class.
The difference between ```SuperConv2DTranspose``` and ```Conv2DTranspose``` is:
```SuperConv2DTranspose``` need to feed a config dictionary with the format of
{'channel', num_of_channel} represents the channels of the outputs, used to change
the first dimension of weight and bias, only train the first channels of the weight
and bias.
Note: the channel in config need to less than first defined.
The super convolution2D transpose layer calculates the output based on the input,
filter, and dilations, strides, paddings. Input and output
are in NCHW format. Where N is batch size, C is the number of feature map,
H is the height of the feature map, and W is the width of the feature map.
Filter's shape is [MCHW] , where M is the number of input feature map,
C is the number of output feature map, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input feature map divided by the groups.
If bias attribution and activation type are provided, bias is added to
the output of the convolution, and the corresponding activation function
is applied to the final result.
The details of convolution transpose layer, please refer to the following explanation and references
`conv2dtranspose <http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf>`_ .
For each input :math:`X`, the equation is:
.. math::
Out = \sigma (W \\ast X + b)
Where:
* :math:`X`: Input value, a ``Tensor`` with NCHW format.
* :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1
\\\\
W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1
\\\\
H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] )
\\\\
W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
Parameters:
num_channels(int): The number of channels in the input image.
num_filters(int): The number of the filter. It is as same as the output
feature map.
filter_size(int or tuple): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
candidate_config(dict, optional): Dictionary descripts candidate config of this layer,
such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, means the kernel size of
this layer can be choose from (3, 5, 7), the key of candidate_config
only can be 'kernel_size', 'channel' and 'expand_ratio', 'channel' and 'expand_ratio'
CANNOT be set at the same time. Default: None.
transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter
to a small filter. Default: False.
output_size(int or tuple, optional): The output image size. If output size is a
tuple, it must contain two integers, (image_H, image_W). None if use
filter_size, padding, and stride to calculate output_size.
if output_size and filter_size are specified at the same time, They
should follow the formula above. Default: None.
padding(int or tuple, optional): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: 0.
stride(int or tuple, optional): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: 1.
dilation(int or tuple, optional): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: 1.
groups(int, optional): The groups number of the Conv2d transpose layer. Inspired by
grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
when group=2, the first half of the filters is only connected to the
first half of the input channels, while the second half of the
filters is only connected to the second half of the input channels.
Default: 1.
param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d_transpose.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str, optional): Activation type, if it is set to None, activation is not appended.
Default: None.
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Attribute:
**weight** (Parameter): the learnable weights of filters of this layer.
**bias** (Parameter or None): the learnable bias of this layer.
Returns:
None
Examples:
.. code-block:: python
import paddle
import numpy as np
from paddleslim.nas.ofa.layers import SuperConv2DTranspose
data = np.random.random((3, 32, 32, 5)).astype('float32')
config = {'channel': 5}
super_convtranspose = SuperConv2DTranspose(num_channels=32, num_filters=10, filter_size=3)
ret = super_convtranspose(paddle.to_variable(data), config)
"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 candidate_config={},
                 transform_kernel=False,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 data_format="NCHW"):
        super(SuperConv2DTranspose, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            output_padding=output_padding,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format)
        self.candidate_config = candidate_config
        if len(self.candidate_config.items()) != 0:
            for k, v in candidate_config.items():
                candidate_config[k] = list(set(v))
        self.ks_set = candidate_config[
            'kernel_size'] if 'kernel_size' in candidate_config else None
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.channel = candidate_config[
            'channel'] if 'channel' in candidate_config else None
        self.base_channel = self._out_channels
        if self.expand_ratio:
            self.base_channel = int(self._out_channels / max(self.expand_ratio))

        self.transform_kernel = transform_kernel
        if self.ks_set != None:
            self.ks_set.sort()
        if self.transform_kernel != False:
            scale_param = dict()
            ### create parameter to transform kernel
            for i in range(len(self.ks_set) - 1):
                ks_small = self.ks_set[i]
                ks_large = self.ks_set[i + 1]
                param_name = '%dto%d_matrix' % (ks_large, ks_small)
                ks_t = ks_small**2
                scale_param[param_name] = self.create_parameter(
                    attr=paddle.ParamAttr(
                        name=self._full_name + param_name,
                        initializer=nn.initializer.Assign(np.eye(ks_t))),
                    shape=(ks_t, ks_t),
                    dtype=self._dtype)

            for name, param in scale_param.items():
                setattr(self, name, param)

    def get_active_filter(self, in_nc, out_nc, kernel_size):
        start, end = compute_start_end(self._kernel_size[0], kernel_size)
        filters = self.weight[:in_nc, :out_nc, start:end, start:end]
        if self.transform_kernel != False and kernel_size < self._kernel_size[0]:
            start_filter = self.weight[:in_nc, :out_nc, :, :]
            for i in range(len(self.ks_set) - 1, 0, -1):
                src_ks = self.ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self.ks_set[i - 1]
                start, end = compute_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = paddle.reshape(
                    _input_filter,
                    shape=[(_input_filter.shape[0] * _input_filter.shape[1]), -1])
                _input_filter = paddle.matmul(
                    _input_filter,
                    self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)), False,
                    False)
                _input_filter = paddle.reshape(
                    _input_filter,
                    shape=[filters.shape[0], filters.shape[1], target_ks, target_ks])
                start_filter = _input_filter
            filters = start_filter
        return filters

    def get_groups_in_out_nc(self, in_nc, out_nc):
        ### standard conv
        return self._groups, in_nc, out_nc

    def forward(self,
                input,
                output_size=None,
                kernel_size=None,
                expand_ratio=None,
                channel=None):
        self.cur_config = {
            'kernel_size': kernel_size,
            'expand_ratio': expand_ratio,
            'channel': channel
        }
        in_nc = int(input.shape[1])
        assert (
            expand_ratio == None or channel == None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio != None:
            out_nc = int(expand_ratio * self.base_channel)
        elif channel != None:
            out_nc = int(channel)
        else:
            out_nc = self._out_channels

        ks = int(self._kernel_size[0]) if kernel_size == None else int(kernel_size)

        groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, out_nc)

        weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)

        if kernel_size != None or 'kernel_size' in self.candidate_config.keys():
            padding = convert_to_list(get_same_padding(ks), 2)
        else:
            padding = self._padding

        if output_size is None:
            output_padding = self.output_padding
        else:
            output_padding = 0

        if self.bias is not None:
            bias = self.bias[:out_nc]
        else:
            bias = self.bias

        out = F.conv2d_transpose(
            input,
            weight,
            bias=bias,
            padding=padding,
            output_padding=output_padding,
            stride=self._stride,
            dilation=self._dilation,
            groups=self._groups,
            output_size=output_size,
            data_format=self._data_format)
        return out


class SuperGroupConv2DTranspose(SuperConv2DTranspose):
    def get_groups_in_out_nc(self, in_nc, out_nc):
        ### groups convolution
        ### groups conv transpose: weight: (Cin, Cout/G, Kh, Kw)
        groups = self._groups
        out_nc = int(out_nc // groups)
        return groups, in_nc, out_nc


class SuperDepthwiseConv2DTranspose(SuperConv2DTranspose):
    def get_groups_in_out_nc(self, in_nc, out_nc):
        if in_nc != out_nc:
            _logger.debug(
                "input channel and output channel in depthwise conv transpose is different, change output channel to input channel! origin channel:(in_nc {}, out_nc {}): ".
                format(in_nc, out_nc))
        groups = in_nc
        out_nc = in_nc
        return groups, in_nc, out_nc


### NOTE: only search channel, write for GAN-compression, maybe change to SuperDepthwiseConv and SuperConv after.
class SuperSeparableConv2D(nn.Layer):
"""
This interface is used to construct a callable object of the ``SuperSeparableConv2D``
class.
The difference between ```SuperSeparableConv2D``` and ```SeparableConv2D``` is:
```SuperSeparableConv2D``` need to feed a config dictionary with the format of
{'channel', num_of_channel} represents the channels of the first conv's outputs and
the second conv's inputs, used to change the first dimension of weight and bias,
only train the first channels of the weight and bias.
The architecture of super separable convolution2D op is [Conv2D, norm layer(may be BatchNorm2D
or InstanceNorm2D), Conv2D]. The first conv is depthwise conv, the filter number is input channel
multiply scale_factor, the group is equal to the number of input channel. The second conv
is standard conv, which filter size and stride size are 1.
Parameters:
num_channels(int): The number of channels in the input image.
num_filters(int): The number of the second conv's filter. It is as same as the output
feature map.
filter_size(int or tuple): The first conv's filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
padding(int or tuple, optional): The first conv's padding size. If padding is a tuple,
it must contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: 0.
stride(int or tuple, optional): The first conv's stride size. If stride is a tuple,
it must contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: 1.
dilation(int or tuple, optional): The first conv's dilation size. If dilation is a tuple,
it must contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: 1.
norm_layer(class): The normalization layer between two convolution. Default: InstanceNorm2D.
bias_attr (ParamAttr or bool, optional): The attribute for the bias of convolution.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, convolution
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
scale_factor(float): The scale factor of the first conv's output channel. Default: 1.
Returns:
None
"""
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 candidate_config={},
                 stride=1,
                 padding=0,
                 dilation=1,
                 norm_layer=nn.InstanceNorm2D,
                 bias_attr=None,
                 scale_factor=1):
        super(SuperSeparableConv2D, self).__init__()
        self.conv = nn.LayerList([
            nn.Conv2D(
                in_channels=in_channels,
                out_channels=in_channels * scale_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=in_channels,
                bias_attr=bias_attr)
        ])

        self.conv.extend([norm_layer(in_channels * scale_factor)])

        self.conv.extend([
            nn.Conv2D(
                in_channels=in_channels * scale_factor,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                bias_attr=bias_attr)
        ])

        self.candidate_config = candidate_config
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self.conv[0]._out_channels
        if self.expand_ratio != None:
            self.base_output_dim = int(self.conv[0]._out_channels /
                                       max(self.expand_ratio))

    def forward(self, input, expand_ratio=None, channel=None):
        self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
        in_nc = int(input.shape[1])
        assert (
            expand_ratio == None or channel == None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio != None:
            out_nc = int(expand_ratio * self.base_output_dim)
        elif channel != None:
            out_nc = int(channel)
        else:
            out_nc = self.conv[0]._out_channels

        weight = self.conv[0].weight[:in_nc]
        ###  conv1
        if self.conv[0].bias is not None:
            bias = self.conv[0].bias[:in_nc]
        else:
            bias = self.conv[0].bias

        conv0_out = F.conv2d(
            input,
            weight,
            bias,
            stride=self.conv[0]._stride,
            padding=self.conv[0]._padding,
            dilation=self.conv[0]._dilation,
            groups=in_nc,
            data_format=self.conv[0]._data_format)

        norm_out = self.conv[1](conv0_out)

        weight = self.conv[2].weight[:out_nc, :in_nc, :, :]

        if self.conv[2].bias is not None:
            bias = self.conv[2].bias[:out_nc]
        else:
            bias = self.conv[2].bias

        conv1_out = F.conv2d(
            norm_out,
            weight,
            bias,
            stride=self.conv[2]._stride,
            padding=self.conv[2]._padding,
            dilation=self.conv[2]._dilation,
            groups=self.conv[2]._groups,
            data_format=self.conv[2]._data_format)
        return conv1_out


class SuperLinear(nn.Linear):
    """
    """
    def __init__(self,
                 in_features,
                 out_features,
                 candidate_config={},
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super(SuperLinear, self).__init__(in_features, out_features,
                                          weight_attr, bias_attr, name)
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._in_features = in_features
        self._out_features = out_features
        self.candidate_config = candidate_config
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self._out_features
        if self.expand_ratio != None:
            self.base_output_dim = int(self._out_features / max(self.expand_ratio))

    def forward(self, input, expand_ratio=None, channel=None):
        self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
        ### weight: (Cin, Cout)
        in_nc = int(input.shape[-1])
        assert (
            expand_ratio == None or channel == None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio != None:
            out_nc = int(expand_ratio * self.base_output_dim)
        elif channel != None:
            out_nc = int(channel)
        else:
            out_nc = self._out_features

        weight = self.weight[:in_nc, :out_nc]
        if self._bias_attr != False:
            bias = self.bias[:out_nc]
        else:
            bias = self.bias

        out = F.linear(x=input, weight=weight, bias=bias, name=self.name)
        return out


class SuperBatchNorm2D(nn.BatchNorm2D):
    """
    add comment
    """
    def __init__(self,
                 num_features,
                 momentum=0.9,
                 epsilon=1e-05,
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCHW',
                 name=None):
        super(SuperBatchNorm2D, self).__init__(num_features, momentum, epsilon,
                                               weight_attr, bias_attr,
                                               data_format, name)

    def forward(self, input):
        self._check_data_format(self._data_format)
        self._check_input_dim(input)

        feature_dim = int(input.shape[1])

        weight = self.weight[:feature_dim]
        bias = self.bias[:feature_dim]
        mean = self._mean[:feature_dim]
        variance = self._variance[:feature_dim]

        return F.batch_norm(
            input,
            mean,
            variance,
            weight=weight,
            bias=bias,
            training=self.training,
            momentum=self._momentum,
            epsilon=self._epsilon,
            data_format=self._data_format)


class SuperInstanceNorm2D(nn.InstanceNorm2D):
    """
    """
def
__init__
(
self
,
num_features
,
epsilon
=
1e-05
,
momentum
=
0.9
,
weight_attr
=
None
,
bias_attr
=
None
,
data_format
=
'NCHW'
,
name
=
None
):
super
(
SuperInstanceNorm2D
,
self
).
__init__
(
num_features
,
epsilon
,
momentum
,
weight_attr
,
bias_attr
,
data_format
,
name
)
def
forward
(
self
,
input
):
self
.
_check_input_dim
(
input
)
feature_dim
=
int
(
input
.
shape
[
1
])
if
self
.
_weight_attr
==
False
and
self
.
_bias_attr
==
False
:
scale
=
None
bias
=
None
else
:
scale
=
self
.
scale
[:
feature_dim
]
bias
=
self
.
bias
[:
feature_dim
]
return
F
.
instance_norm
(
input
,
scale
,
bias
,
eps
=
self
.
_epsilon
)
class
SuperLayerNorm
(
nn
.
LayerNorm
):
def
__init__
(
self
,
normalized_shape
,
epsilon
=
1e-05
,
weight_attr
=
None
,
bias_attr
=
None
,
name
=
None
):
super
(
SuperLayerNorm
,
self
).
__init__
(
normalized_shape
,
epsilon
,
weight_attr
,
bias_attr
,
name
)
def
forward
(
self
,
input
):
### TODO(ceci3): fix if normalized_shape is not a single number
input_ndim
=
len
(
list
(
input
.
shape
))
normalized_ndim
=
len
(
self
.
_normalized_shape
)
begin_norm_axis
=
input_ndim
-
normalized_ndim
feature_dim
=
int
(
input
.
shape
[
-
1
])
if
self
.
_weight_attr
!=
False
:
weight
=
self
.
weight
[:
feature_dim
]
else
:
weight
=
None
if
self
.
_bias_attr
!=
False
:
bias
=
self
.
bias
[:
feature_dim
]
else
:
bias
=
None
out
,
_
,
_
=
core
.
ops
.
layer_norm
(
input
,
weight
,
bias
,
'epsilon'
,
self
.
_epsilon
,
'begin_norm_axis'
,
begin_norm_axis
)
return
out
class SuperEmbedding(nn.Embedding):
    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 candidate_config={},
                 padding_idx=None,
                 sparse=False,
                 weight_attr=None,
                 name=None):
        super(SuperEmbedding, self).__init__(num_embeddings, embedding_dim,
                                             padding_idx, sparse, weight_attr,
                                             name)
        self.candidate_config = candidate_config
        self.expand_ratio = candidate_config[
            'expand_ratio'] if 'expand_ratio' in candidate_config else None
        self.base_output_dim = self._embedding_dim
        if self.expand_ratio != None:
            self.base_output_dim = int(self._embedding_dim /
                                       max(self.expand_ratio))

    def forward(self, input, expand_ratio=None, channel=None):
        assert (
            expand_ratio == None or channel == None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio != None:
            out_nc = int(expand_ratio * self.base_output_dim)
        elif channel != None:
            out_nc = int(channel)
        else:
            out_nc = self._embedding_dim

        weight = self.weight[:, :out_nc]
        return F.embedding(
            input,
            weight=weight,
            padding_idx=self._padding_idx,
            sparse=self._sparse,
            name=self._name)
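SuperEmbedding applies the same sub-network selection to the embedding dimension. As a hedged illustration (dimensions chosen arbitrarily, Paddle 2.x assumed):

    import numpy as np
    import paddle
    from paddleslim.nas.ofa.layers_new import SuperEmbedding

    emb = SuperEmbedding(64, 64, candidate_config={'expand_ratio': (1, 2, 4)})
    ids = paddle.to_tensor(np.random.randint(0, 64, (3, 10)).astype('int64'))

    print(emb(ids).shape)                   # full width: [3, 10, 64]
    print(emb(ids, expand_ratio=2).shape)   # 64 / max(1, 2, 4) * 2 = 32: [3, 10, 32]
    print(emb(ids, channel=16).shape)       # explicit width: [3, 10, 16]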
paddleslim/nas/ofa/ofa.py  View file @ 40e0684a

...
@@ -16,10 +16,15 @@ import logging
 import numpy as np
 from collections import namedtuple

 import paddle
-#import paddle.nn as nn
 import paddle.fluid as fluid
-from paddle.fluid.dygraph import Conv2D
-from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm
+from .utils.utils import get_paddle_version
+pd_ver = get_paddle_version()
+if pd_ver == 185:
+    from .layers import BaseBlock, SuperConv2D
+    Layer = paddle.fluid.dygraph.Layer
+else:
+    from .layers_new import BaseBlock, SuperConv2D
+    Layer = paddle.nn.Layer
 from .utils.utils import search_idx
 from ...common import get_logger
...
@@ -40,7 +45,7 @@ DistillConfig = namedtuple('DistillConfig', [
 DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields)


-class OFABase(fluid.dygraph.Layer):
+class OFABase(Layer):
     def __init__(self, model):
         super(OFABase, self).__init__()
         self.model = model
...
@@ -169,8 +174,7 @@ class OFA(OFABase):
             )
             ### instance model by user can input super-param easily.
-            assert isinstance(self.distill_config.teacher_model,
-                              paddle.fluid.dygraph.Layer)
+            assert isinstance(self.distill_config.teacher_model, Layer)

             # load teacher parameter
             if self.distill_config.teacher_model_path != None:
...
@@ -190,9 +194,10 @@ class OFA(OFABase):
             for name, sublayer in self.model.named_sublayers():
                 if name in mapping_layers:
                     netA = SuperConv2D(
-                        sublayer._num_filters,
-                        sublayer._num_filters,
-                        filter_size=1)
+                        getattr(sublayer, '_num_filters', sublayer._out_channels),
+                        getattr(sublayer, '_num_filters', sublayer._out_channels),
+                        1)
                     self.netAs_param.extend(netA.parameters())
                     self.netAs.append(netA)
...
@@ -288,7 +293,8 @@ class OFA(OFABase):
             n = self.distill_config.mapping_layers[i]
             Tact = self.Tacts[n]
             Sact = self.Sacts[n]
-            Sact = netA(Sact, channel=netA._num_filters)
+            Sact = netA(
+                Sact, channel=getattr(netA, '_num_filters', netA._out_channels))
             if self.distill_config.distill_fn == None:
                 loss = fluid.layers.mse_loss(Sact, Tact)
             else:
...
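The recurring getattr(sublayer, '_num_filters', sublayer._out_channels) pattern is what keeps this distillation path working on both APIs: the 1.8-era dygraph Conv2D stores its filter count in _num_filters, while the 2.0 nn.Conv2D (and the SuperConv2D built on it) exposes _out_channels. The same idea as a standalone helper (the function name is invented here for illustration):

    def _filter_count(conv_layer):
        # Prefer the Paddle 1.8 attribute, fall back to the Paddle 2.0 one.
        return getattr(conv_layer, '_num_filters', conv_layer._out_channels)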
paddleslim/nas/ofa/utils/utils.py  View file @ 40e0684a

...
@@ -44,3 +44,13 @@ def search_idx(num, sorted_nestlist):
             return idx, phase_idx
     assert num > max_num
     return len(sorted_nestlist) - 1, max_idx
+
+
+def get_paddle_version():
+    import paddle
+    pd_ver = 185
+    if hasattr(paddle, 'nn'):
+        if hasattr(paddle.nn, 'Conv1D'):  ### judge 2.0 alpha
+            pd_ver = 200
+
+    return pd_ver
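This is the helper that convert_super.py and ofa.py call at import time to decide between the 1.8-style and 2.0-style Super layer implementations, along the lines of:

    from paddleslim.nas.ofa.utils.utils import get_paddle_version

    pd_ver = get_paddle_version()
    if pd_ver == 185:
        from paddleslim.nas.ofa import layers as super_layers      # Paddle 1.8.x
    else:
        from paddleslim.nas.ofa import layers_new as super_layers  # Paddle 2.0+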
tests/test_convert_supernet.py 0 → 100644  View file @ 40e0684a
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet


class TestConvertSuper(unittest.TestCase):
    def setUp(self):
        self.model = mobilenet_v1()

    def test_convert(self):
        sp_net_config = supernet(kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
        sp_model = Convert(sp_net_config).convert(self.model)
        assert len(sp_model.sublayers()) == 151


if __name__ == '__main__':
    unittest.main()
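The test above exercises the Convert(supernet(...)).convert(model) entry point on a whole network; the updated test_ofa.py below builds its toy models with the other style, a supernet(...) context manager whose convert is applied to a plain list of layers. A minimal sketch of that second style (layer sizes are arbitrary):

    import paddle.nn as nn
    from paddle.nn import ReLU
    from paddleslim.nas.ofa.convert_super import supernet

    with supernet(kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
        layers = []
        layers += [nn.Conv2D(3, 4, 3, padding=1)]
        layers += [nn.BatchNorm2D(4)]
        layers += [ReLU()]
        layers = ofa_super.convert(layers)

    model = nn.Sequential(*layers)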
tests/test_ofa.py  View file @ 40e0684a

...
@@ -17,16 +17,15 @@ sys.path.append("../")
 import numpy as np
 import unittest
 import paddle
-import paddle.fluid as fluid
-import paddle.fluid.dygraph.nn as nn
+import paddle.nn as nn
 from paddle.nn import ReLU
 from paddleslim.nas import ofa
 from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
 from paddleslim.nas.ofa.convert_super import supernet
-from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
+from paddleslim.nas.ofa.layers_new import Block, SuperSeparableConv2D


-class ModelConv(fluid.dygraph.Layer):
+class ModelConv(nn.Layer):
     def __init__(self):
         super(ModelConv, self).__init__()
         with supernet(
...
@@ -35,16 +34,13 @@ class ModelConv(fluid.dygraph.Layer):
                 (8, 12, 16))) as ofa_super:
             models = []
             models += [nn.Conv2D(3, 4, 3, padding=1)]
-            models += [nn.InstanceNorm(4)]
+            models += [nn.InstanceNorm2D(4)]
             models += [ReLU()]
             models += [nn.Conv2D(4, 4, 3, groups=4)]
-            models += [nn.InstanceNorm(4)]
+            models += [nn.InstanceNorm2D(4)]
             models += [ReLU()]
-            models += [
-                nn.Conv2DTranspose(
-                    4, 4, 3, groups=4, padding=1, use_cudnn=True)
-            ]
-            models += [nn.BatchNorm(4)]
+            models += [nn.Conv2DTranspose(4, 4, 3, groups=4, padding=1)]
+            models += [nn.BatchNorm2D(4)]
             models += [ReLU()]
             models += [nn.Conv2D(4, 3, 3)]
             models += [ReLU()]
...
@@ -60,21 +56,23 @@ class ModelConv(fluid.dygraph.Layer):
                 kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
             models1 = []
             models1 += [nn.Conv2D(6, 4, 3)]
-            models1 += [nn.BatchNorm(4)]
+            models1 += [nn.BatchNorm2D(4)]
             models1 += [ReLU()]
             models1 += [nn.Conv2D(4, 4, 3, groups=2)]
-            models1 += [nn.InstanceNorm(4)]
+            models1 += [nn.InstanceNorm2D(4)]
             models1 += [ReLU()]
             models1 += [nn.Conv2DTranspose(4, 4, 3, groups=2)]
-            models1 += [nn.BatchNorm(4)]
+            models1 += [nn.BatchNorm2D(4)]
             models1 += [ReLU()]
             models1 += [nn.Conv2DTranspose(4, 4, 3)]
-            models1 += [nn.BatchNorm(4)]
+            models1 += [nn.BatchNorm2D(4)]
             models1 += [ReLU()]
+            models1 += [nn.Conv2DTranspose(4, 4, 1)]
+            models1 += [nn.BatchNorm2D(4)]
+            models1 += [ReLU()]
             models1 = ofa_super.convert(models1)
         models += models1

         self.models = paddle.nn.Sequential(*models)

     def forward(self, inputs, depth=None):
...
@@ -89,16 +87,61 @@ class ModelConv(fluid.dygraph.Layer):
         return inputs


-class ModelLinear(fluid.dygraph.Layer):
+class ModelConv2(nn.Layer):
+    def __init__(self):
+        super(ModelConv2, self).__init__()
+        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
+            models = []
+            models += [nn.Conv2DTranspose(4, 4, 3)]
+            models += [nn.BatchNorm2D(4)]
+            models += [ReLU()]
+            models += [nn.Conv2D(4, 4, 3)]
+            models += [nn.BatchNorm2D(4)]
+            models += [ReLU()]
+            models = ofa_super.convert(models)
+
+        with supernet(channel=((4, 6, 8), (4, 6, 8))) as ofa_super:
+            models1 = []
+            models1 += [nn.Conv2DTranspose(4, 4, 3)]
+            models1 += [nn.BatchNorm2D(4)]
+            models1 += [ReLU()]
+            models1 += [nn.Conv2DTranspose(4, 4, 3)]
+            models1 += [nn.BatchNorm2D(4)]
+            models1 += [ReLU()]
+            models1 = ofa_super.convert(models1)
+        models += models1
+
+        with supernet(kernel_size=(3, 5, 7)) as ofa_super:
+            models2 = []
+            models2 += [nn.Conv2D(4, 4, 3)]
+            models2 += [nn.BatchNorm2D(4)]
+            models2 += [ReLU()]
+            models2 += [nn.Conv2DTranspose(4, 4, 3)]
+            models2 += [nn.BatchNorm2D(4)]
+            models2 += [ReLU()]
+            models2 += [nn.Conv2D(4, 4, 3)]
+            models2 += [nn.BatchNorm2D(4)]
+            models2 += [ReLU()]
+            models2 = ofa_super.convert(models2)
+        models += models2
+
+        self.models = paddle.nn.Sequential(*models)
+
+
+class ModelLinear(nn.Layer):
     def __init__(self):
         super(ModelLinear, self).__init__()
+        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
+            models = []
+            models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
+            models += [nn.Linear(64, 128)]
+            models += [nn.LayerNorm(128)]
+            models += [nn.Linear(128, 256)]
+            models = ofa_super.convert(models)
-        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
-            models1 = []
-            models1 += [nn.Embedding(size=(64, 64))]
-            models1 += [nn.Linear(64, 128)]
-            models1 += [nn.LayerNorm(128)]
-            models1 += [nn.Linear(128, 256)]
-            models1 += [nn.Linear(256, 256)]
-            models1 = ofa_super.convert(models1)
         models += models1
...
@@ -116,17 +159,21 @@ class ModelLinear(fluid.dygraph.Layer):
         return inputs


-class ModelLinear1(fluid.dygraph.Layer):
+class ModelLinear1(nn.Layer):
     def __init__(self):
         super(ModelLinear1, self).__init__()
-        models = []
+        with supernet(
+                channel=((64, 128, 256), (64, 128, 256),
+                         (64, 128, 256))) as ofa_super:
+            models = []
+            models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
+            models += [nn.Linear(64, 128)]
+            models += [nn.LayerNorm(128)]
+            models += [nn.Linear(128, 256)]
+            models = ofa_super.convert(models)
-        with supernet(channel=((64, 128, 256), )) as ofa_super:
-            models1 = []
-            models1 += [nn.Embedding(size=(64, 64))]
-            models1 += [nn.Linear(64, 128)]
-            models1 += [nn.LayerNorm(128)]
-            models1 += [nn.Linear(128, 256)]
-            models1 += [nn.Linear(256, 256)]
-            models1 = ofa_super.convert(models1)
         models += models1
...
@@ -145,20 +192,16 @@ class ModelLinear1(fluid.dygraph.Layer):
         return inputs


-class ModelLinear2(fluid.dygraph.Layer):
+class ModelLinear2(nn.Layer):
     def __init__(self):
         super(ModelLinear2, self).__init__()
-        models = []
         with supernet(expand_ratio=None) as ofa_super:
-            models1 = []
-            models1 += [nn.Embedding(size=(64, 64))]
-            models1 += [nn.Linear(64, 128)]
-            models1 += [nn.LayerNorm(128)]
-            models1 += [nn.Linear(128, 256)]
-            models1 = ofa_super.convert(models1)
-        models += models1
+            models = []
+            models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
+            models += [nn.Linear(64, 128)]
+            models += [nn.LayerNorm(128)]
+            models += [nn.Linear(128, 256)]
+            models = ofa_super.convert(models)
         self.models = paddle.nn.Sequential(*models)

     def forward(self, inputs, depth=None):
...
@@ -175,7 +218,6 @@ class ModelLinear2(fluid.dygraph.Layer):

 class TestOFA(unittest.TestCase):
     def setUp(self):
-        fluid.enable_dygraph()
         self.init_model_and_data()
         self.init_config()
...
@@ -185,7 +227,7 @@ class TestOFA(unittest.TestCase):
         data_np = np.random.random((1, 3, 10, 10)).astype(np.float32)
         label_np = np.random.random((1)).astype(np.float32)

-        self.data = fluid.dygraph.to_variable(data_np)
+        self.data = paddle.to_tensor(data_np)

     def init_config(self):
         default_run_config = {
...
@@ -217,10 +259,9 @@ class TestOFA(unittest.TestCase):
             cur_idx = self.run_config.n_epochs[idx]
             for ph_idx in range(len(cur_idx)):
                 cur_lr = self.run_config.init_learning_rate[idx][ph_idx]
-                adam = fluid.optimizer.Adam(
+                adam = paddle.optimizer.Adam(
                     learning_rate=cur_lr,
-                    parameter_list=(
-                        ofa_model.parameters() + ofa_model.netAs_param))
+                    parameters=(ofa_model.parameters() + ofa_model.netAs_param))
                 for epoch_id in range(start_epoch,
                                       self.run_config.n_epochs[idx][ph_idx]):
                     if epoch_id == 0:
...
@@ -228,7 +269,7 @@ class TestOFA(unittest.TestCase):
                     for model_no in range(
                             self.run_config.dynamic_batch_size[idx]):
                         output, _ = ofa_model(self.data)
-                        loss = fluid.layers.reduce_mean(output)
+                        loss = paddle.mean(output)
                         if self.distill_config.mapping_layers != None:
                             dis_loss = ofa_model.calc_distill_loss()
                             loss += dis_loss
...
@@ -249,7 +290,7 @@ class TestOFACase1(TestOFA):
         self.teacher_model = ModelLinear()
         data_np = np.random.random((3, 64)).astype(np.int64)

-        self.data = fluid.dygraph.to_variable(data_np)
+        self.data = paddle.to_tensor(data_np)

     def init_config(self):
         default_run_config = {
...
@@ -275,7 +316,7 @@ class TestOFACase2(TestOFACase1):
         self.teacher_model = ModelLinear1()
         data_np = np.random.random((3, 64)).astype(np.int64)

-        self.data = fluid.dygraph.to_variable(data_np)
+        self.data = paddle.to_tensor(data_np)


 class TestOFACase3(unittest.TestCase):
...
@@ -285,5 +326,10 @@ class TestOFACase3(unittest.TestCase):
         ofa_model.set_net_config({'expand_ratio': None})


+class TestOFACase3(unittest.TestCase):
+    def test_ofa(self):
+        self.model = ModelConv2()
+

 if __name__ == '__main__':
     unittest.main()
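Put together, the training step these tests run through reduces to roughly the following, assuming ofa_model, data and distill_config have been set up the way TestOFA does above; the optimizer update calls follow the Paddle 2.0 API and are illustrative rather than copied from the test:

    import paddle

    adam = paddle.optimizer.Adam(
        learning_rate=0.001,
        parameters=ofa_model.parameters() + ofa_model.netAs_param)

    output, _ = ofa_model(data)
    loss = paddle.mean(output)
    if distill_config.mapping_layers != None:
        loss += ofa_model.calc_distill_loss()
    loss.backward()
    adam.step()
    adam.clear_grad()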