Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
08da20a6
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
08da20a6
编写于
5月 18, 2020
作者:
S
suweiyue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
1. fix stop_gradient, 2. reduce_sum keep_dim
上级
e68b8b25
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
3 addition
and
3 deletion
+3
-3
examples/erniesage/models/base.py
examples/erniesage/models/base.py
+2
-2
examples/erniesage/models/erniesage_v2.py
examples/erniesage/models/erniesage_v2.py
+1
-1
未找到文件。
examples/erniesage/models/base.py
浏览文件 @
08da20a6
...
@@ -191,12 +191,12 @@ def all_gather(X):
...
@@ -191,12 +191,12 @@ def all_gather(X):
for
i
in
range
(
trainer_num
):
for
i
in
range
(
trainer_num
):
copy_X
=
X
*
1
copy_X
=
X
*
1
copy_X
=
L
.
collective
.
_broadcast
(
copy_X
,
i
,
True
)
copy_X
=
L
.
collective
.
_broadcast
(
copy_X
,
i
,
True
)
copy_X
.
stop_gradient
s
=
True
copy_X
.
stop_gradient
=
True
Xs
.
append
(
copy_X
)
Xs
.
append
(
copy_X
)
if
len
(
Xs
)
>
1
:
if
len
(
Xs
)
>
1
:
Xs
=
L
.
concat
(
Xs
,
0
)
Xs
=
L
.
concat
(
Xs
,
0
)
Xs
.
stop_gradient
s
=
True
Xs
.
stop_gradient
=
True
else
:
else
:
Xs
=
Xs
[
0
]
Xs
=
Xs
[
0
]
return
Xs
return
Xs
...
...
examples/erniesage/models/erniesage_v2.py
浏览文件 @
08da20a6
...
@@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet):
...
@@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet):
src_position_ids
=
L
.
expand
(
src_position_ids
,
[
src_batch
,
1
,
1
])
# [B, slot_seqlen * num_b, 1]
src_position_ids
=
L
.
expand
(
src_position_ids
,
[
src_batch
,
1
,
1
])
# [B, slot_seqlen * num_b, 1]
zero
=
L
.
fill_constant
([
1
],
dtype
=
'int64'
,
value
=
0
)
zero
=
L
.
fill_constant
([
1
],
dtype
=
'int64'
,
value
=
0
)
input_mask
=
L
.
cast
(
L
.
equal
(
src_ids
,
zero
),
"int32"
)
# assume pad id == 0 [B, slot_seqlen, 1]
input_mask
=
L
.
cast
(
L
.
equal
(
src_ids
,
zero
),
"int32"
)
# assume pad id == 0 [B, slot_seqlen, 1]
src_pad_len
=
L
.
reduce_sum
(
input_mask
,
1
)
# [B, 1, 1]
src_pad_len
=
L
.
reduce_sum
(
input_mask
,
1
,
keep_dim
=
True
)
# [B, 1, 1]
dst_position_ids
=
L
.
reshape
(
dst_position_ids
=
L
.
reshape
(
L
.
range
(
L
.
range
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录