Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
1e12c326
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
“52e8ebf246faa737a9553e1fe7df531a10234074”上不存在“develop/doc/design/model_format.html”
未验证
提交
1e12c326
编写于
12月 09, 2020
作者:
C
ceci3
提交者:
GitHub
12月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add for nlp (#529)
* for bert
上级
d9ccc901
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
515 addition
and
30 deletion
+515
-30
paddleslim/nas/ofa/convert_super.py
paddleslim/nas/ofa/convert_super.py
+2
-2
paddleslim/nas/ofa/ofa.py
paddleslim/nas/ofa/ofa.py
+38
-14
paddleslim/nas/ofa/utils/__init__.py
paddleslim/nas/ofa/utils/__init__.py
+5
-0
paddleslim/nas/ofa/utils/nlp_utils.py
paddleslim/nas/ofa/utils/nlp_utils.py
+272
-0
paddleslim/nas/ofa/utils/utils.py
paddleslim/nas/ofa/utils/utils.py
+46
-10
tests/test_ofa.py
tests/test_ofa.py
+23
-4
tests/test_ofa_utils.py
tests/test_ofa_utils.py
+129
-0
未找到文件。
paddleslim/nas/ofa/convert_super.py
浏览文件 @
1e12c326
...
@@ -25,7 +25,7 @@ if pd_ver == 185:
...
@@ -25,7 +25,7 @@ if pd_ver == 185:
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Conv2DTranspose
,
Linear
,
LayerNorm
,
Embedding
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Conv2DTranspose
,
Linear
,
LayerNorm
,
Embedding
from
.layers
import
*
from
.layers
import
*
from
.
import
layers
from
.
import
layers
Layer
=
fluid
.
dygraph
.
nn
.
Layer
Layer
=
paddle
.
fluid
.
dygraph
.
Layer
else
:
else
:
import
paddle.nn
as
nn
import
paddle.nn
as
nn
from
paddle.nn
import
Conv2D
,
Conv2DTranspose
,
Linear
,
LayerNorm
,
Embedding
from
paddle.nn
import
Conv2D
,
Conv2DTranspose
,
Linear
,
LayerNorm
,
Embedding
...
@@ -35,7 +35,7 @@ else:
...
@@ -35,7 +35,7 @@ else:
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
__all__
=
[
'supernet'
]
__all__
=
[
'supernet'
,
'Convert'
]
WEIGHT_LAYER
=
[
'conv'
,
'linear'
,
'embedding'
]
WEIGHT_LAYER
=
[
'conv'
,
'linear'
,
'embedding'
]
...
...
paddleslim/nas/ofa/ofa.py
浏览文件 @
1e12c326
...
@@ -20,10 +20,10 @@ import paddle.fluid as fluid
...
@@ -20,10 +20,10 @@ import paddle.fluid as fluid
from
.utils.utils
import
get_paddle_version
from
.utils.utils
import
get_paddle_version
pd_ver
=
get_paddle_version
()
pd_ver
=
get_paddle_version
()
if
pd_ver
==
185
:
if
pd_ver
==
185
:
from
.layers
import
BaseBlock
,
SuperConv2D
from
.layers
import
BaseBlock
,
SuperConv2D
,
SuperLinear
Layer
=
paddle
.
fluid
.
dygraph
.
Layer
Layer
=
paddle
.
fluid
.
dygraph
.
Layer
else
:
else
:
from
.layers_new
import
BaseBlock
,
SuperConv2D
from
.layers_new
import
BaseBlock
,
SuperConv2D
,
SuperLinear
Layer
=
paddle
.
nn
.
Layer
Layer
=
paddle
.
nn
.
Layer
from
.utils.utils
import
search_idx
from
.utils.utils
import
search_idx
from
...common
import
get_logger
from
...common
import
get_logger
...
@@ -40,7 +40,7 @@ RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
...
@@ -40,7 +40,7 @@ RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
DistillConfig
=
namedtuple
(
'DistillConfig'
,
[
DistillConfig
=
namedtuple
(
'DistillConfig'
,
[
'lambda_distill'
,
'teacher_model'
,
'mapping_layers'
,
'teacher_model_path'
,
'lambda_distill'
,
'teacher_model'
,
'mapping_layers'
,
'teacher_model_path'
,
'distill_fn'
'distill_fn'
,
'mapping_op'
])
])
DistillConfig
.
__new__
.
__defaults__
=
(
None
,
)
*
len
(
DistillConfig
.
_fields
)
DistillConfig
.
__new__
.
__defaults__
=
(
None
,
)
*
len
(
DistillConfig
.
_fields
)
...
@@ -193,12 +193,28 @@ class OFA(OFABase):
...
@@ -193,12 +193,28 @@ class OFA(OFABase):
self
.
netAs
=
[]
self
.
netAs
=
[]
for
name
,
sublayer
in
self
.
model
.
named_sublayers
():
for
name
,
sublayer
in
self
.
model
.
named_sublayers
():
if
name
in
mapping_layers
:
if
name
in
mapping_layers
:
netA
=
SuperConv2D
(
if
self
.
distill_config
.
mapping_op
!=
None
:
getattr
(
sublayer
,
'_num_filters'
,
if
self
.
distill_config
.
mapping_op
.
lower
()
==
'conv2d'
:
sublayer
.
_out_channels
),
netA
=
SuperConv2D
(
getattr
(
sublayer
,
'_num_filters'
,
getattr
(
sublayer
,
'_num_filters'
,
sublayer
.
_out_channels
),
1
)
sublayer
.
_out_channels
),
self
.
netAs_param
.
extend
(
netA
.
parameters
())
getattr
(
sublayer
,
'_num_filters'
,
sublayer
.
_out_channels
),
1
)
elif
self
.
distill_config
.
mapping_op
.
lower
()
==
'linear'
:
netA
=
SuperLinear
(
getattr
(
sublayer
,
'_output_dim'
,
sublayer
.
_out_features
),
getattr
(
sublayer
,
'_output_dim'
,
sublayer
.
_out_features
))
else
:
raise
NotImplementedError
(
"Not Support Op: {}"
.
format
(
self
.
distill_config
.
mapping_op
.
lower
()))
else
:
netA
=
None
if
netA
!=
None
:
self
.
netAs_param
.
extend
(
netA
.
parameters
())
self
.
netAs
.
append
(
netA
)
self
.
netAs
.
append
(
netA
)
def
get_activation
(
mem
,
name
):
def
get_activation
(
mem
,
name
):
...
@@ -289,16 +305,24 @@ class OFA(OFABase):
...
@@ -289,16 +305,24 @@ class OFA(OFABase):
losses
=
[]
losses
=
[]
assert
len
(
self
.
netAs
)
>
0
assert
len
(
self
.
netAs
)
>
0
for
i
,
netA
in
enumerate
(
self
.
netAs
):
for
i
,
netA
in
enumerate
(
self
.
netAs
):
assert
isinstance
(
netA
,
SuperConv2D
)
n
=
self
.
distill_config
.
mapping_layers
[
i
]
n
=
self
.
distill_config
.
mapping_layers
[
i
]
Tact
=
self
.
Tacts
[
n
]
Tact
=
self
.
Tacts
[
n
]
Sact
=
self
.
Sacts
[
n
]
Sact
=
self
.
Sacts
[
n
]
Sact
=
netA
(
if
isinstance
(
netA
,
SuperConv2D
):
Sact
,
channel
=
getattr
(
netA
,
'_num_filters'
,
netA
.
_out_channels
))
Sact
=
netA
(
Sact
,
channel
=
getattr
(
netA
,
'_num_filters'
,
netA
.
_out_channels
))
elif
isinstance
(
netA
,
SuperLinear
):
Sact
=
netA
(
Sact
,
channel
=
getattr
(
netA
,
'_output_dim'
,
netA
.
_out_features
))
else
:
Sact
=
Sact
if
self
.
distill_config
.
distill_fn
==
None
:
if
self
.
distill_config
.
distill_fn
==
None
:
loss
=
fluid
.
layers
.
mse_loss
(
Sact
,
Tact
)
loss
=
fluid
.
layers
.
mse_loss
(
Sact
,
Tact
.
detach
()
)
else
:
else
:
loss
=
distill_fn
(
Sact
,
Tact
)
loss
=
distill_fn
(
Sact
,
Tact
.
detach
()
)
losses
.
append
(
loss
)
losses
.
append
(
loss
)
return
sum
(
losses
)
*
self
.
distill_config
.
lambda_distill
return
sum
(
losses
)
*
self
.
distill_config
.
lambda_distill
...
...
paddleslim/nas/ofa/utils/__init__.py
浏览文件 @
1e12c326
...
@@ -13,3 +13,8 @@
...
@@ -13,3 +13,8 @@
# limitations under the License.
# limitations under the License.
from
.utils
import
*
from
.utils
import
*
from
.utils
import
get_paddle_version
pd_ver
=
get_paddle_version
()
if
pd_ver
==
200
:
from
.nlp_utils
import
*
paddleslim/nas/ofa/utils/nlp_utils.py
0 → 100644
浏览文件 @
1e12c326
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
# Names exported by `from .nlp_utils import *`; the monkey-patch helpers
# defined further down in this module are intentionally not listed.
__all__ = [
    "compute_neuron_head_importance", "reorder_head", "reorder_neuron"
]
def compute_neuron_head_importance(task_name,
                                   model,
                                   data_loader,
                                   num_layers,
                                   num_heads,
                                   loss_fct=None,
                                   intermediate_name='linear1',
                                   output_name='linear2'):
    """
    Compute the importance of multi-head attention and feed-forward neuron in each transformer layer.

    Importance is accumulated over `data_loader` as the absolute value of
    gradient-based scores: for heads, the gradient of the loss with respect to
    an all-ones head mask; for feed-forward neurons, `weight * gradient` summed
    over the non-neuron axis of the intermediate and output weights.

    Args:
        task_name(str): task name. Not used by this function.
        model(paddle.nn.Layer): the instance of transformer model. It is called as
            `model(input_ids, segment_ids, attention_mask=[None, head_mask])`.
        data_loader(DataLoader): An iterable data loader is used for evaluate. An instance of `paddle.io.Dataloader`.
            Every batch must unpack into `(input_ids, segment_ids, labels)`.
        num_layers(int): number of transformer layers.
        num_heads(int): number of heads in each multi-head attention.
        loss_fct(Loss|optional): loss function can be a `paddle.nn.Layer` instance.
            Default: None, in which case a fresh `nn.loss.CrossEntropyLoss()` is
            created for this call (instead of one shared instance built at import time).
        intermediate_name(str|optional): the name of intermediate `Linear` layer in feed-forward. Default: `linear1`.
        output_name(str|optional): the name of output `Linear` layer in feed-forward. Default: `linear2`.

    Returns:
        tuple: `(head_importance, neuron_importance)` where `head_importance` is a
        float32 tensor of shape `[num_layers, num_heads]` and `neuron_importance`
        is a list of float32 numpy arrays, one per matched intermediate weight.
    """
    # Build the default loss lazily: a default of `nn.loss.CrossEntropyLoss()`
    # in the signature would be instantiated once, when the module is imported.
    if loss_fct is None:
        loss_fct = nn.loss.CrossEntropyLoss()

    head_importance = paddle.zeros(
        shape=[num_layers, num_heads], dtype='float32')
    # The mask is all ones, so it does not change the forward result; it only
    # exists so that a gradient with respect to each head can be read back.
    head_mask = paddle.ones(shape=[num_layers, num_heads], dtype='float32')
    head_mask.stop_gradient = False

    intermediate_weight = []
    intermediate_bias = []
    output_weight = []

    # Parameters are matched by substring of their name; tensors with more
    # than one dimension are treated as weights, the others as biases.
    for name, w in model.named_parameters():
        if intermediate_name in name:
            if len(w.shape) > 1:
                intermediate_weight.append(w)
            else:
                intermediate_bias.append(w)

        if output_name in name:
            if len(w.shape) > 1:
                output_weight.append(w)

    # One accumulator per intermediate weight, sized by the weight's axis 1.
    neuron_importance = [
        np.zeros(
            shape=[w.shape[1]], dtype='float32') for w in intermediate_weight
    ]

    for batch in data_loader:
        input_ids, segment_ids, labels = batch
        logits = model(
            input_ids, segment_ids, attention_mask=[None, head_mask])
        loss = loss_fct(logits, labels)
        loss.backward()
        # NOTE(review): gradients are never cleared between batches here; if
        # backward() accumulates into existing gradients, later batches also
        # re-count earlier ones -- confirm this is intended.
        head_importance += paddle.abs(
            paddle.to_tensor(head_mask.gradient()))

        # zip() stops at the shortest list, so layers are paired positionally.
        for w1, b1, w2, current_importance in zip(
                intermediate_weight, intermediate_bias, output_weight,
                neuron_importance):
            current_importance += np.abs(
                (np.sum(w1.numpy() * w1.gradient(), axis=0) + b1.numpy() *
                 b1.gradient()))
            current_importance += np.abs(
                np.sum(w2.numpy() * w2.gradient(), axis=1))

    return head_importance, neuron_importance
def reorder_head(layer, index):
    """
    Reorder head weights according index.

    The q/k/v projections are permuted along axis 1 (bias permuted with them)
    and the output projection along axis 0 (bias left as is), so that whole
    heads move together.

    Args:
         layer(paddle.nn.Layer): the instance of `paddle.nn.MultiHeadAttention` layer.
         index(list): the sort indices of multi-head.
    """
    assert isinstance(layer, nn.MultiHeadAttention), \
        "layer in reorder_head must be the instance of `paddle.nn.MultiHeadAttention`."
    num_heads, head_dim = layer.num_heads, layer.head_dim

    # Lay 0..num_heads*head_dim-1 out as one row per head, pick the rows in
    # the requested head order, then flatten back into a per-column index.
    head_slots = paddle.reshape(
        paddle.arange(
            0, num_heads * head_dim, dtype='int64'),
        shape=[num_heads, head_dim])
    idx = paddle.reshape(
        paddle.index_select(
            head_slots, index=index, axis=0), shape=[-1])

    # `reorder_neuron` performs exactly the in-place permutation the former
    # inner helper did, including unwrapping a `.fn` attribute when present.
    for proj in (layer.q_proj, layer.k_proj, layer.v_proj):
        reorder_neuron(proj, idx, dim=1)
    reorder_neuron(layer.out_proj, idx, dim=0)
def reorder_neuron(layer, index, dim=0):
    """
    Reorder feed-forward weights according index.

    The weight is permuted in place along `dim`. The bias is permuted along its
    only axis when `dim != 0`; when `dim == 0` its values are left unchanged.
    Each parameter's `stop_gradient` flag is restored to the value it had
    before the call, so frozen parameters stay frozen.

    Args:
         layer(paddle.nn.Layer): the instance of `paddle.nn.Linear` layer. If it
             has a `fn` attribute, `layer.fn` is the layer that is modified.
         index(list): the sort indices of feed-forward.
         dim(int): select weights according to the dim.
    """
    linear = layer.fn if hasattr(layer, 'fn') else layer

    def _overwrite(param, value):
        # Write `value` into `param` with gradient tracking switched off, then
        # put the flag back the way the caller had it.
        previous_flag = param.stop_gradient
        param.stop_gradient = True
        param.set_value(value)
        param.stop_gradient = previous_flag

    # Compute both replacement tensors before touching any parameter.
    new_weight = paddle.index_select(linear.weight, index, axis=dim).detach()
    new_bias = None
    if linear.bias is not None:
        if dim == 0:
            # NOTE(review): presumably axis 0 is the input axis, which the
            # bias does not depend on -- hence a plain copy. Confirm layout.
            new_bias = paddle.assign(linear.bias).detach()
        else:
            new_bias = paddle.assign(
                paddle.index_select(
                    linear.bias, index, axis=0)).detach()

    _overwrite(linear.weight, new_weight)
    if new_bias is not None:
        _overwrite(linear.bias, new_bias)
### monkey patch for MultiHeadAttention _prepare_qkv to change num_heads.
def _prepare_qkv(self, query, key, value, cache=None):
    """
    Project the inputs and split them into heads, honouring the current
    `expand_ratio` of a supernet-wrapped query projection.

    Side effect: when `self.q_proj` has a `fn` attribute whose
    `cur_config['expand_ratio']` is not None, `self.num_heads` is overwritten
    with `int(self.num_heads * expand_ratio)` and is NOT restored here.

    Args:
        query, key, value: inputs forwarded to the projections.
        cache: None, an instance of `self.StaticCache` (its `k`/`v` are reused),
            or an instance of `self.Cache` (new `k`/`v` are concatenated to it
            along axis 2 and a new `self.Cache` is returned).

    Returns:
        tuple: `(q, k, v)` when `cache` is None, otherwise `(q, k, v, cache)`.
    """
    # The projection must run before num_heads is rescaled below.
    q = self.q_proj(query)
    if hasattr(self.q_proj, 'fn'):
        expand_ratio = self.q_proj.fn.cur_config['expand_ratio']
        if expand_ratio is not None:
            self.num_heads = int(self.num_heads * expand_ratio)
    q = paddle.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
    q = paddle.transpose(x=q, perm=[0, 2, 1, 3])

    if isinstance(cache, self.StaticCache):
        # for encoder-decoder attention in inference and has cached
        k, v = cache.k, cache.v
    else:
        k, v = self.compute_kv(key, value)

    if isinstance(cache, self.Cache):
        # for decoder self-attention in inference
        k = paddle.concat([cache.k, k], axis=2)
        v = paddle.concat([cache.v, v], axis=2)
        cache = self.Cache(k, v)

    return (q, k, v) if cache is None else (q, k, v, cache)
### monkey patch for MultiHeadAttention forward to accept head_mask
### attn_mask[0] = attn_mask, attn_mask[1] = head_mask
def _mha_forward(self, query, key, value, attn_mask=None, cache=None):
    """
    Replacement for `MultiHeadAttention.forward` that also accepts a head mask.

    Args:
        query, key, value: attention inputs; `key` and `value` fall back to
            `query` when they are None.
        attn_mask: one of
            - None: no attention mask and no head mask;
            - a list/tuple `[attn_mask, head_mask]`, either entry may be None;
            - any other object: used as the plain attention mask, no head mask.
            The attention mask is added to the scaled scores before softmax;
            the head mask multiplies the attention weights after dropout.
        cache: forwarded to `self._prepare_qkv`.

    Returns:
        The projected output alone, or a tuple that additionally holds the
        attention weights (when `self.need_weights`) and the cache (when given).
    """
    key = query if key is None else key
    value = query if value is None else value

    # Normalise the mask argument; indexing a None default would raise.
    if attn_mask is None:
        seq_mask, head_mask = None, None
    elif isinstance(attn_mask, (list, tuple)):
        seq_mask, head_mask = attn_mask[0], attn_mask[1]
    else:
        seq_mask, head_mask = attn_mask, None

    # `_prepare_qkv` may overwrite self.num_heads. Remember the exact value so
    # it can be put back even on error, rather than dividing the rounded
    # number by the ratio again (which can land on a different integer).
    full_num_heads = self.num_heads
    try:
        # compute q ,k ,v
        if cache is None:
            q, k, v = self._prepare_qkv(query, key, value, cache)
        else:
            q, k, v, cache = self._prepare_qkv(query, key, value, cache)

        # scale dot product attention
        # TODO: use paddle.matmul, however it doesn't support `alpha`
        product = paddle.fluid.layers.matmul(
            x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
        if seq_mask is not None:
            # TODO(guosheng): support bool mask
            product = product + seq_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train")
        if head_mask is not None:
            weights = weights * head_mask

        out = paddle.matmul(weights, v)

        # combine heads
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        out = paddle.reshape(
            x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        if cache is not None:
            outs.append(cache)
    finally:
        self.num_heads = full_num_heads

    return out if len(outs) == 1 else tuple(outs)
### monkey patch for TransformerEncoder forward to accept head_mask
### attn_mask[0] = attn_mask, attn_mask[1] = head_mask
def
_encoder_forward
(
self
,
src
,
src_mask
=
[
None
,
None
]):
output
=
src
if
src_mask
[
1
]
is
not
None
:
head_mask
=
src_mask
[
1
]
if
len
(
head_mask
.
shape
)
==
1
:
head_mask
=
paddle
.
unsqueeze
(
paddle
.
unsqueeze
(
paddle
.
unsqueeze
(
paddle
.
unsqueeze
(
head_mask
,
0
),
0
),
-
1
),
-
1
)
head_mask
=
paddle
.
expand
(
head_mask
,
shape
=
[
self
.
num_layers
]
+
head_mask
.
shape
[
1
:])
elif
len
(
head_mask
.
shape
)
==
2
:
head_mask
=
paddle
.
unsqueeze
(
paddle
.
unsqueeze
(
paddle
.
unsqueeze
(
head_mask
,
1
),
-
1
),
-
1
)
else
:
head_mask
=
[
None
]
*
self
.
num_layers
for
i
,
mod
in
enumerate
(
self
.
layers
):
output
=
mod
(
output
,
src_mask
=
[
src_mask
[
0
],
head_mask
[
i
]])
if
self
.
norm
is
not
None
:
output
=
self
.
norm
(
output
)
return
output
# Install the replacements defined above on the paddle.nn classes themselves.
# These assignments run when this module is imported and rebind the class
# attributes, so they apply to every instance of these classes, not only to
# models built through this package.
nn.MultiHeadAttention.forward = _mha_forward
nn.MultiHeadAttention._prepare_qkv = _prepare_qkv
nn.TransformerEncoder.forward = _encoder_forward
paddleslim/nas/ofa/utils/utils.py
浏览文件 @
1e12c326
...
@@ -12,6 +12,52 @@
...
@@ -12,6 +12,52 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
logging
import
paddle
from
....common
import
get_logger
def get_paddle_version():
    """
    Return a coarse version code for the installed paddle package.

    Returns:
        int: 200 when `paddle.nn` exists and exposes `Conv1D`, otherwise 185.
    """
    import paddle

    ### judge 2.0 alpha
    has_conv1d = hasattr(paddle, 'nn') and hasattr(paddle.nn, 'Conv1D')
    return 200 if has_conv1d else 185
# Pick the base `Layer` class once, at import time, according to the version
# code reported by get_paddle_version(); set_state_dict() checks models
# against this class.
pd_ver = get_paddle_version()
if pd_ver == 185:
    Layer = paddle.fluid.dygraph.Layer
else:
    Layer = paddle.nn.Layer

# Module-level logger, created at INFO level.
_logger = get_logger(__name__, level=logging.INFO)

# Only set_state_dict is exported by `from .utils import *`.
__all__ = ['set_state_dict']
def set_state_dict(model, state_dict):
    """
    Set state dict from origin model to supernet model.

    For every parameter of `model`, the value is looked up in `state_dict`
    first under the parameter's own name, then under the name with its
    second-to-last dotted component removed. Parameters found under neither
    key are left untouched and reported through the module logger.

    Args:
        model(paddle.nn.Layer): model after convert to supernet.
        state_dict(dict): dict with the type of {name: param} in origin model.
    """
    assert isinstance(model, Layer)
    assert isinstance(state_dict, dict)
    for name, param in model.named_parameters():
        pieces = name.split('.')
        # Same name without the second-to-last component,
        # e.g. 'a.b.c' -> 'a.c'.
        tmp_n = '.'.join(pieces[:-2] + pieces[-1:])
        for candidate in (name, tmp_n):
            if candidate in state_dict:
                param.set_value(state_dict[candidate])
                break
        else:
            _logger.info('{} is not in state_dict'.format(tmp_n))
def
compute_start_end
(
kernel_size
,
sub_kernel_size
):
def
compute_start_end
(
kernel_size
,
sub_kernel_size
):
center
=
kernel_size
//
2
center
=
kernel_size
//
2
...
@@ -44,13 +90,3 @@ def search_idx(num, sorted_nestlist):
...
@@ -44,13 +90,3 @@ def search_idx(num, sorted_nestlist):
return
idx
,
phase_idx
return
idx
,
phase_idx
assert
num
>
max_num
assert
num
>
max_num
return
len
(
sorted_nestlist
)
-
1
,
max_idx
return
len
(
sorted_nestlist
)
-
1
,
max_idx
def get_paddle_version():
    """
    Return a coarse version code for the installed paddle package.

    Returns:
        int: 200 when `paddle.nn` exists and exposes `Conv1D`, otherwise 185.
    """
    import paddle

    ### judge 2.0 alpha
    has_conv1d = hasattr(paddle, 'nn') and hasattr(paddle.nn, 'Conv1D')
    return 200 if has_conv1d else 185
tests/test_ofa.py
浏览文件 @
1e12c326
...
@@ -243,7 +243,8 @@ class TestOFA(unittest.TestCase):
...
@@ -243,7 +243,8 @@ class TestOFA(unittest.TestCase):
default_distill_config
=
{
default_distill_config
=
{
'lambda_distill'
:
0.01
,
'lambda_distill'
:
0.01
,
'teacher_model'
:
self
.
teacher_model
,
'teacher_model'
:
self
.
teacher_model
,
'mapping_layers'
:
[
'models.0.fn'
]
'mapping_layers'
:
[
'models.0.fn'
],
'mapping_op'
:
'conv2d'
}
}
self
.
distill_config
=
DistillConfig
(
**
default_distill_config
)
self
.
distill_config
=
DistillConfig
(
**
default_distill_config
)
self
.
elastic_order
=
[
'kernel_size'
,
'width'
,
'depth'
]
self
.
elastic_order
=
[
'kernel_size'
,
'width'
,
'depth'
]
...
@@ -289,7 +290,6 @@ class TestOFACase1(TestOFA):
...
@@ -289,7 +290,6 @@ class TestOFACase1(TestOFA):
self
.
model
=
ModelLinear
()
self
.
model
=
ModelLinear
()
self
.
teacher_model
=
ModelLinear
()
self
.
teacher_model
=
ModelLinear
()
data_np
=
np
.
random
.
random
((
3
,
64
)).
astype
(
np
.
int64
)
data_np
=
np
.
random
.
random
((
3
,
64
)).
astype
(
np
.
int64
)
self
.
data
=
paddle
.
to_tensor
(
data_np
)
self
.
data
=
paddle
.
to_tensor
(
data_np
)
def
init_config
(
self
):
def
init_config
(
self
):
...
@@ -305,12 +305,14 @@ class TestOFACase1(TestOFA):
...
@@ -305,12 +305,14 @@ class TestOFACase1(TestOFA):
default_distill_config
=
{
default_distill_config
=
{
'lambda_distill'
:
0.01
,
'lambda_distill'
:
0.01
,
'teacher_model'
:
self
.
teacher_model
,
'teacher_model'
:
self
.
teacher_model
,
'mapping_op'
:
'linear'
,
'mapping_layers'
:
[
'models.3.fn'
],
}
}
self
.
distill_config
=
DistillConfig
(
**
default_distill_config
)
self
.
distill_config
=
DistillConfig
(
**
default_distill_config
)
self
.
elastic_order
=
None
self
.
elastic_order
=
None
class
TestOFACase2
(
TestOFA
Case1
):
class
TestOFACase2
(
TestOFA
):
def
init_model_and_data
(
self
):
def
init_model_and_data
(
self
):
self
.
model
=
ModelLinear1
()
self
.
model
=
ModelLinear1
()
self
.
teacher_model
=
ModelLinear1
()
self
.
teacher_model
=
ModelLinear1
()
...
@@ -318,6 +320,23 @@ class TestOFACase2(TestOFACase1):
...
@@ -318,6 +320,23 @@ class TestOFACase2(TestOFACase1):
self
.
data
=
paddle
.
to_tensor
(
data_np
)
self
.
data
=
paddle
.
to_tensor
(
data_np
)
def init_config(self):
    """Build the run and distillation configs for this case.

    No `mapping_op` is passed to DistillConfig here, and the elastic order is
    left as None.
    """
    self.run_config = RunConfig(
        train_batch_size=1,
        n_epochs=[[2, 5]],
        init_learning_rate=[[0.003, 0.001]],
        dynamic_batch_size=[1],
        total_images=1)
    self.distill_config = DistillConfig(
        lambda_distill=0.01,
        teacher_model=self.teacher_model,
        mapping_layers=['models.3.fn'])
    self.elastic_order = None
class
TestOFACase3
(
unittest
.
TestCase
):
class
TestOFACase3
(
unittest
.
TestCase
):
def
test_ofa
(
self
):
def
test_ofa
(
self
):
...
@@ -326,7 +345,7 @@ class TestOFACase3(unittest.TestCase):
...
@@ -326,7 +345,7 @@ class TestOFACase3(unittest.TestCase):
ofa_model
.
set_net_config
({
'expand_ratio'
:
None
})
ofa_model
.
set_net_config
({
'expand_ratio'
:
None
})
class
TestOFACase
3
(
unittest
.
TestCase
):
class
TestOFACase
4
(
unittest
.
TestCase
):
def
test_ofa
(
self
):
def
test_ofa
(
self
):
self
.
model
=
ModelConv2
()
self
.
model
=
ModelConv2
()
...
...
tests/test_ofa_utils.py
0 → 100644
浏览文件 @
1e12c326
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
sys
.
path
.
append
(
"../"
)
import
unittest
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
from
paddle.vision.models
import
mobilenet_v1
from
paddleslim.nas.ofa.convert_super
import
Convert
,
supernet
from
paddleslim.nas.ofa.utils
import
compute_neuron_head_importance
,
reorder_head
,
reorder_neuron
,
set_state_dict
class TestComputeImportance(unittest.TestCase):
    """Run importance computation and reordering on a small 3-layer encoder."""

    def setUp(self):
        self.model = self.init_model()
        self.data_loader = self.init_data()

    def init_model(self):
        """Return a 3-layer, 12-head transformer encoder with a 3-way head."""

        class TestModel(nn.Layer):
            def __init__(self):
                super(TestModel, self).__init__()
                encoder_layer = nn.TransformerEncoderLayer(
                    312,
                    12,
                    1024,
                    dropout=0.1,
                    activation='gelu',
                    attn_dropout=0.1,
                    act_dropout=0)
                self.encoder = nn.TransformerEncoder(encoder_layer, 3)
                self.fc = nn.Linear(312, 3)

            def forward(self, input_ids, segment_ids, attention_mask=None):
                # Avoid a shared mutable default; None means "no attention
                # mask, no head mask".
                if attention_mask is None:
                    attention_mask = [None, None]
                src = input_ids + segment_ids
                out = self.encoder(src, attention_mask)
                # Classify from the first position of the second axis.
                out = self.fc(out[:, 0])
                return out

        return TestModel()

    def init_data(self):
        """Return a one-batch iterable of (input_ids, segment_ids, labels)."""
        batch_size = 16
        hidden_size = 312
        # NOTE(review): the original named this `d_model`; it is used as the
        # middle axis of the inputs, presumably the sequence length.
        seq_len = 26
        input_ids = np.random.rand(batch_size, seq_len,
                                   hidden_size).astype("float32")
        segment_ids = np.random.rand(batch_size, seq_len,
                                     hidden_size).astype("float32")
        labels = np.random.randint(0, high=3, size=(batch_size, 1))
        data = ((paddle.to_tensor(input_ids), paddle.to_tensor(segment_ids),
                 paddle.to_tensor(labels)), )
        return data

    def reorder_reorder_neuron_head(self, model, head_importance,
                                    neuron_importance):
        """Sort heads and feed-forward neurons of every layer by importance."""
        # reorder heads and ffn neurons
        for layer, current_importance in enumerate(neuron_importance):
            encoder_layer = model.encoder.layers[layer]
            # reorder heads
            idx = paddle.argsort(head_importance[layer], descending=True)
            reorder_head(encoder_layer.self_attn, idx)
            # reorder neurons
            idx = paddle.argsort(
                paddle.to_tensor(current_importance), descending=True)
            reorder_neuron(encoder_layer.linear1, idx, dim=1)
            reorder_neuron(encoder_layer.linear2, idx, dim=0)

    def test_compute(self):
        head_importance, neuron_importance = compute_neuron_head_importance(
            task_name='xnli',
            model=self.model,
            data_loader=self.data_loader,
            num_layers=3,
            num_heads=12)
        # unittest assertions report both values on failure and are not
        # stripped when Python runs with -O.
        self.assertEqual(len(head_importance), 3)
        self.assertEqual(len(neuron_importance), 3)
        self.reorder_reorder_neuron_head(self.model, head_importance,
                                         neuron_importance)
class TestComputeImportanceCase1(TestComputeImportance):
    """Forward pass with neither an attention mask nor a head mask."""

    def test_compute(self):
        for batch in self.data_loader:
            input_ids, segment_ids, _ = batch
            logits = self.model(
                input_ids, segment_ids, attention_mask=[None, None])
            # assertEqual instead of a bare assert: clearer failure output
            # and not removed under -O.
            self.assertEqual(logits.shape[1], 3)
class TestComputeImportanceCase2(TestComputeImportance):
    """Forward pass with a 1-D head mask (one entry per head)."""

    def test_compute(self):
        head_mask = paddle.ones(shape=[12], dtype='float32')
        for batch in self.data_loader:
            input_ids, segment_ids, _ = batch
            logits = self.model(
                input_ids, segment_ids, attention_mask=[None, head_mask])
            # assertEqual instead of a bare assert: clearer failure output
            # and not removed under -O.
            self.assertEqual(logits.shape[1], 3)
class TestSetStateDict(unittest.TestCase):
    """Load a plain model's parameters into its supernet conversion."""

    def setUp(self):
        self.model = mobilenet_v1()
        # Snapshot {parameter name: parameter} of the unconverted model.
        self.origin_weights = dict(self.model.named_parameters())

    def test_set_state_dict(self):
        # Only checks that conversion followed by loading runs through.
        config = supernet(expand_ratio=[0.5, 1.0])
        converter = Convert(config)
        sp_model = converter.convert(self.model)
        set_state_dict(sp_model, self.origin_weights)
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录