Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
cd94b6b7
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cd94b6b7
编写于
12月 14, 2020
作者:
C
ceci3
提交者:
GitHub
12月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix when the task is mnli (#540)
* fix * fix when mnli
上级
186960cc
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
20 addition
and
15 deletion
+20
-15
paddleslim/nas/ofa/utils/nlp_utils.py
paddleslim/nas/ofa/utils/nlp_utils.py
+20
-15
未找到文件。
paddleslim/nas/ofa/utils/nlp_utils.py
浏览文件 @
cd94b6b7
...
...
@@ -66,21 +66,26 @@ def compute_neuron_head_importance(task_name,
for
w
in
intermediate_weight
:
neuron_importance
.
append
(
np
.
zeros
(
shape
=
[
w
.
shape
[
1
]],
dtype
=
'float32'
))
for
batch
in
data_loader
:
input_ids
,
segment_ids
,
labels
=
batch
logits
=
model
(
input_ids
,
segment_ids
,
attention_mask
=
[
None
,
head_mask
])
loss
=
loss_fct
(
logits
,
labels
)
loss
.
backward
()
head_importance
+=
paddle
.
abs
(
paddle
.
to_tensor
(
head_mask
.
gradient
()))
for
w1
,
b1
,
w2
,
current_importance
in
zip
(
intermediate_weight
,
intermediate_bias
,
output_weight
,
neuron_importance
):
current_importance
+=
np
.
abs
(
(
np
.
sum
(
w1
.
numpy
()
*
w1
.
gradient
(),
axis
=
0
)
+
b1
.
numpy
()
*
b1
.
gradient
()))
current_importance
+=
np
.
abs
(
np
.
sum
(
w2
.
numpy
()
*
w2
.
gradient
(),
axis
=
1
))
if
task_name
.
lower
()
!=
'mnli'
:
data_loader
=
(
data_loader
,
)
for
data
in
data_loader
:
for
batch
in
data
:
input_ids
,
segment_ids
,
labels
=
batch
logits
=
model
(
input_ids
,
segment_ids
,
attention_mask
=
[
None
,
head_mask
])
loss
=
loss_fct
(
logits
,
labels
)
loss
.
backward
()
head_importance
+=
paddle
.
abs
(
paddle
.
to_tensor
(
head_mask
.
gradient
()))
for
w1
,
b1
,
w2
,
current_importance
in
zip
(
intermediate_weight
,
intermediate_bias
,
output_weight
,
neuron_importance
):
current_importance
+=
np
.
abs
(
(
np
.
sum
(
w1
.
numpy
()
*
w1
.
gradient
(),
axis
=
0
)
+
b1
.
numpy
()
*
b1
.
gradient
()))
current_importance
+=
np
.
abs
(
np
.
sum
(
w2
.
numpy
()
*
w2
.
gradient
(),
axis
=
1
))
return
head_importance
,
neuron_importance
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录