Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
6c4a0850
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6c4a0850
编写于
4月 28, 2020
作者:
Y
yelrose
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixed graph_wrapper for dtype inference
上级
03cb3621
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
5 addition
and
3 deletion
+5
-3
pgl/graph_wrapper.py
pgl/graph_wrapper.py
+5
-3
未找到文件。
pgl/graph_wrapper.py
浏览文件 @
6c4a0850
...
@@ -40,7 +40,6 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
...
@@ -40,7 +40,6 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
num_edges
):
num_edges
):
"""Recv message from given msg to dst nodes.
"""Recv message from given msg to dst nodes.
"""
"""
empty_msg_flag
=
fluid
.
layers
.
cast
(
num_edges
>
0
,
dtype
=
"float32"
)
if
reduce_function
==
"sum"
:
if
reduce_function
==
"sum"
:
if
isinstance
(
msg
,
dict
):
if
isinstance
(
msg
,
dict
):
raise
TypeError
(
"The message for build-in function"
raise
TypeError
(
"The message for build-in function"
...
@@ -49,8 +48,9 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
...
@@ -49,8 +48,9 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
try
:
try
:
out_dim
=
msg
.
shape
[
-
1
]
out_dim
=
msg
.
shape
[
-
1
]
init_output
=
fluid
.
layers
.
fill_constant
(
init_output
=
fluid
.
layers
.
fill_constant
(
shape
=
[
num_nodes
,
out_dim
],
value
=
0
,
dtype
=
"float32"
)
shape
=
[
num_nodes
,
out_dim
],
value
=
0
,
dtype
=
msg
.
dtype
)
init_output
.
stop_gradient
=
False
init_output
.
stop_gradient
=
False
empty_msg_flag
=
fluid
.
layers
.
cast
(
num_edges
>
0
,
dtype
=
msg
.
dtype
)
msg
=
msg
*
empty_msg_flag
msg
=
msg
*
empty_msg_flag
output
=
paddle_helper
.
scatter_add
(
init_output
,
dst
,
msg
)
output
=
paddle_helper
.
scatter_add
(
init_output
,
dst
,
msg
)
return
output
return
output
...
@@ -66,10 +66,12 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
...
@@ -66,10 +66,12 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
bucketed_msg
=
op
.
nested_lod_reset
(
msg
,
bucketing_index
)
bucketed_msg
=
op
.
nested_lod_reset
(
msg
,
bucketing_index
)
output
=
reduce_function
(
bucketed_msg
)
output
=
reduce_function
(
bucketed_msg
)
output_dim
=
output
.
shape
[
-
1
]
output_dim
=
output
.
shape
[
-
1
]
empty_msg_flag
=
fluid
.
layers
.
cast
(
num_edges
>
0
,
dtype
=
output
.
dtype
)
output
=
output
*
empty_msg_flag
output
=
output
*
empty_msg_flag
init_output
=
fluid
.
layers
.
fill_constant
(
init_output
=
fluid
.
layers
.
fill_constant
(
shape
=
[
num_nodes
,
output_dim
],
value
=
0
,
dtype
=
"float32"
)
shape
=
[
num_nodes
,
output_dim
],
value
=
0
,
dtype
=
output
.
dtype
)
init_output
.
stop_gradient
=
True
init_output
.
stop_gradient
=
True
final_output
=
fluid
.
layers
.
scatter
(
init_output
,
uniq_dst
,
output
)
final_output
=
fluid
.
layers
.
scatter
(
init_output
,
uniq_dst
,
output
)
return
final_output
return
final_output
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录