Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
7166db92
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7166db92
编写于
8月 12, 2020
作者:
Y
Yelrose
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Batch GraphWrapper
上级
6933c683
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
87 addition
and
18 deletion
+87
-18
pgl/graph_wrapper.py
pgl/graph_wrapper.py
+87
-18
未找到文件。
pgl/graph_wrapper.py
浏览文件 @
7166db92
...
@@ -25,7 +25,7 @@ from pgl.utils import op
...
@@ -25,7 +25,7 @@ from pgl.utils import op
from
pgl.utils
import
paddle_helper
from
pgl.utils
import
paddle_helper
from
pgl.utils.logger
import
log
from
pgl.utils.logger
import
log
__all__
=
[
"BaseGraphWrapper"
,
"GraphWrapper"
,
"StaticGraphWrapper"
]
__all__
=
[
"BaseGraphWrapper"
,
"GraphWrapper"
,
"StaticGraphWrapper"
,
"BatchGraphWrapper"
]
def
send
(
src
,
dst
,
nfeat
,
efeat
,
message_func
):
def
send
(
src
,
dst
,
nfeat
,
efeat
,
message_func
):
"""Send message from src to dst.
"""Send message from src to dst.
...
@@ -101,7 +101,6 @@ class BaseGraphWrapper(object):
...
@@ -101,7 +101,6 @@ class BaseGraphWrapper(object):
self
.
_indegree
=
None
self
.
_indegree
=
None
self
.
_edge_uniq_dst
=
None
self
.
_edge_uniq_dst
=
None
self
.
_edge_uniq_dst_count
=
None
self
.
_edge_uniq_dst_count
=
None
self
.
_node_ids
=
None
self
.
_graph_lod
=
None
self
.
_graph_lod
=
None
self
.
_num_graph
=
None
self
.
_num_graph
=
None
self
.
_num_edges
=
None
self
.
_num_edges
=
None
...
@@ -416,13 +415,6 @@ class StaticGraphWrapper(BaseGraphWrapper):
...
@@ -416,13 +415,6 @@ class StaticGraphWrapper(BaseGraphWrapper):
value
=
graph_lod
)
value
=
graph_lod
)
self
.
_initializers
.
append
(
init
)
self
.
_initializers
.
append
(
init
)
node_ids_value
=
np
.
arange
(
0
,
graph
.
num_nodes
,
dtype
=
"int64"
)
self
.
_node_ids
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_data_name_prefix
+
"/node_ids"
,
dtype
=
"int64"
,
value
=
node_ids_value
)
self
.
_initializers
.
append
(
init
)
self
.
_indegree
,
init
=
paddle_helper
.
constant
(
self
.
_indegree
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_data_name_prefix
+
"/indegree"
,
name
=
self
.
_data_name_prefix
+
"/indegree"
,
dtype
=
"int64"
,
dtype
=
"int64"
,
...
@@ -601,12 +593,6 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -601,12 +593,6 @@ class GraphWrapper(BaseGraphWrapper):
dtype
=
"int32"
,
dtype
=
"int32"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_node_ids
=
L
.
data
(
self
.
_data_name_prefix
+
"/node_ids"
,
shape
=
[
None
],
append_batch_size
=
False
,
dtype
=
"int64"
,
stop_gradient
=
True
)
self
.
_indegree
=
L
.
data
(
self
.
_indegree
=
L
.
data
(
self
.
_data_name_prefix
+
"/indegree"
,
self
.
_data_name_prefix
+
"/indegree"
,
shape
=
[
None
],
shape
=
[
None
],
...
@@ -619,7 +605,6 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -619,7 +605,6 @@ class GraphWrapper(BaseGraphWrapper):
self
.
_num_nodes
,
self
.
_num_nodes
,
self
.
_edge_uniq_dst
,
self
.
_edge_uniq_dst
,
self
.
_edge_uniq_dst_count
,
self
.
_edge_uniq_dst_count
,
self
.
_node_ids
,
self
.
_indegree
,
self
.
_indegree
,
self
.
_graph_lod
,
self
.
_graph_lod
,
self
.
_num_graph
,
self
.
_num_graph
,
...
@@ -700,7 +685,6 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -700,7 +685,6 @@ class GraphWrapper(BaseGraphWrapper):
[
graph
.
num_nodes
],
dtype
=
"int64"
)
[
graph
.
num_nodes
],
dtype
=
"int64"
)
feed_dict
[
self
.
_data_name_prefix
+
'/uniq_dst'
]
=
uniq_dst
feed_dict
[
self
.
_data_name_prefix
+
'/uniq_dst'
]
=
uniq_dst
feed_dict
[
self
.
_data_name_prefix
+
'/uniq_dst_count'
]
=
uniq_dst_count
feed_dict
[
self
.
_data_name_prefix
+
'/uniq_dst_count'
]
=
uniq_dst_count
feed_dict
[
self
.
_data_name_prefix
+
'/node_ids'
]
=
graph
.
nodes
feed_dict
[
self
.
_data_name_prefix
+
'/indegree'
]
=
indegree
feed_dict
[
self
.
_data_name_prefix
+
'/indegree'
]
=
indegree
feed_dict
[
self
.
_data_name_prefix
+
'/graph_lod'
]
=
graph_lod
feed_dict
[
self
.
_data_name_prefix
+
'/graph_lod'
]
=
graph_lod
feed_dict
[
self
.
_data_name_prefix
+
'/num_graph'
]
=
np
.
array
(
feed_dict
[
self
.
_data_name_prefix
+
'/num_graph'
]
=
np
.
array
(
...
@@ -746,7 +730,6 @@ class DropEdgeWrapper(BaseGraphWrapper):
...
@@ -746,7 +730,6 @@ class DropEdgeWrapper(BaseGraphWrapper):
self
.
_num_nodes
=
graph_wrapper
.
num_nodes
self
.
_num_nodes
=
graph_wrapper
.
num_nodes
self
.
_graph_lod
=
graph_wrapper
.
graph_lod
self
.
_graph_lod
=
graph_wrapper
.
graph_lod
self
.
_num_graph
=
graph_wrapper
.
num_graph
self
.
_num_graph
=
graph_wrapper
.
num_graph
self
.
_node_ids
=
L
.
range
(
0
,
self
.
_num_nodes
,
step
=
1
,
dtype
=
"int32"
)
# Dropout Edges
# Dropout Edges
src
,
dst
=
graph_wrapper
.
edges
src
,
dst
=
graph_wrapper
.
edges
...
@@ -780,3 +763,89 @@ class DropEdgeWrapper(BaseGraphWrapper):
...
@@ -780,3 +763,89 @@ class DropEdgeWrapper(BaseGraphWrapper):
self
.
_edge_uniq_dst_count
=
L
.
concat
([
uniq_count
,
last
])
self
.
_edge_uniq_dst_count
=
L
.
concat
([
uniq_count
,
last
])
self
.
_edge_uniq_dst_count
.
stop_gradient
=
True
self
.
_edge_uniq_dst_count
.
stop_gradient
=
True
self
.
_indegree
=
get_degree
(
self
.
_edges_dst
,
self
.
_num_nodes
)
self
.
_indegree
=
get_degree
(
self
.
_edges_dst
,
self
.
_num_nodes
)
class BatchGraphWrapper(BaseGraphWrapper):
    """Graph wrapper built directly from user-provided tensors.

    Several graphs are merged into one batched graph, which is useful for
    data-parallel algorithms: the caller owns the data holders and passes
    them in, instead of letting the wrapper create its own feed variables.

    Args:
        num_nodes (int32 or int64): Shape [ num_graph ]. Number of nodes of
            each graph in the batch.
        num_edges (int32 or int64): Shape [ num_graph ]. Number of edges of
            each graph in the batch.
        edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ].
            Column 0 is the source node and column 1 the destination node.
            The per-graph node offset is added to both columns, so node ids
            are expected to be local to each graph.
        node_feats: A dictionary of node features. Each value should be a
            tensor with shape [ total_num_nodes_in_the_graphs, feature_size ].
        edge_feats: A dictionary of edge features. Each value should be a
            tensor with shape [ total_num_edges_in_the_graphs, feature_size ],
            with rows in the same order as ``edges``.
    """

    def __init__(self,
                 num_nodes,
                 num_edges,
                 edges,
                 node_feats=None,
                 edge_feats=None):
        super(BatchGraphWrapper, self).__init__()

        node_shift, edge_lod = self.__build_meta_data(num_nodes, num_edges)
        # Edges are re-ordered (sorted by destination) while being merged;
        # keep the permutation so edge features can follow the same order.
        edge_order = self.__build_edges(edges, node_shift, edge_lod)

        # assign node features
        if node_feats is not None:
            for key, value in node_feats.items():
                self.node_feat_tensor_dict[key] = value

        # assign edge features
        if edge_feats is not None:
            for key, value in edge_feats.items():
                # Apply the edge sort permutation; storing the tensor as-is
                # would leave feature rows misaligned with
                # self._edges_src / self._edges_dst.
                self.edge_feat_tensor_dict[key] = L.gather(
                    value, edge_order, overwrite=False)

        # other meta-data
        self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(
            self._edges_dst, dtype="int32")
        self._edge_uniq_dst.stop_gradient = True

        # Turn per-destination edge counts into offsets:
        # exclusive cumulative sum followed by the total edge count.
        last = L.reduce_sum(uniq_count, keep_dim=True)
        uniq_count = L.cumsum(uniq_count, exclusive=True)
        self._edge_uniq_dst_count = L.concat([uniq_count, last])
        self._edge_uniq_dst_count.stop_gradient = True

        self._indegree = get_degree(self._edges_dst, self._num_nodes)

    def __build_meta_data(self, num_nodes, num_edges):
        """Merge the per-graph node and edge counts into batch meta-data.

        Sets ``self._num_nodes`` and ``self._num_edges`` (totals over the
        batch), ``self._num_graph`` and ``self._graph_lod`` (node offsets of
        every graph followed by the total node count).

        Args:
            num_nodes: Per-graph node counts; flattened and cast to int32.
            num_edges: Per-graph edge counts; flattened and cast to int32.

        Returns:
            A tuple ``(node_shift, edge_lod)``. ``node_shift`` is the
            exclusive cumulative sum of ``num_nodes`` (first node id of each
            graph in the merged graph). ``edge_lod`` is the exclusive
            cumulative sum of ``num_edges`` followed by the total edge count.
        """
        num_nodes = L.reshape(num_nodes, [-1])
        num_edges = L.reshape(num_edges, [-1])
        num_nodes = paddle_helper.ensure_dtype(num_nodes, dtype="int32")
        num_edges = paddle_helper.ensure_dtype(num_edges, dtype="int32")

        num_graph = L.shape(num_nodes)[0]
        sum_num_nodes = L.reduce_sum(num_nodes)
        sum_num_edges = L.reduce_sum(num_edges)

        edge_lod = L.concat(
            [L.cumsum(
                num_edges, exclusive=True), sum_num_edges])
        node_shift = L.cumsum(num_nodes, exclusive=True)
        graph_lod = L.concat([node_shift, sum_num_nodes])

        self._num_nodes = sum_num_nodes
        self._num_edges = sum_num_edges
        self._num_graph = num_graph
        self._graph_lod = graph_lod
        return node_shift, edge_lod

    def __build_edges(self, edges, node_shift, edge_lod):
        """Merge the sub-graph edges into one edge list sorted by destination.

        Sets ``self._edges_src`` and ``self._edges_dst``.

        Args:
            edges: Tensor indexed as ``edges[:, 0]`` (source) and
                ``edges[:, 1]`` (destination).
            node_shift: Per-graph node offset, as returned by
                ``__build_meta_data``.
            edge_lod: Per-graph edge offsets, as returned by
                ``__build_meta_data``.

        Returns:
            The index tensor produced by sorting the shifted destinations.
            Gathering any per-edge tensor with it yields rows aligned with
            ``self._edges_src`` / ``self._edges_dst``.
        """
        src = edges[:, 0]
        dst = edges[:, 1]
        src = L.reshape(src, [-1])
        dst = L.reshape(dst, [-1])
        src = paddle_helper.ensure_dtype(src, dtype="int32")
        dst = paddle_helper.ensure_dtype(dst, dtype="int32")

        # preprocess edges: attach the per-graph edge segmentation to dst and
        # expand node_shift along it, so that every edge gets the node offset
        # of the graph it belongs to.
        lod_dst = L.lod_reset(dst, edge_lod)
        node_shift = L.reshape(node_shift, [-1, 1])
        node_shift = L.sequence_expand_as(node_shift, lod_dst)
        node_shift = L.reshape(node_shift, [-1])
        src = src + node_shift
        dst = dst + node_shift

        # sort edges
        self._edges_dst, index = L.argsort(dst)
        self._edges_src = L.gather(src, index, overwrite=False)
        return index
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录