Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
0bd10e14
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0bd10e14
编写于
2月 14, 2020
作者:
L
liweibin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
speed up sampling
上级
570bf814
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
278 addition
and
288 deletion
+278
-288
examples/graphsage/reader.py
examples/graphsage/reader.py
+42
-59
examples/graphsage/train.py
examples/graphsage/train.py
+37
-58
examples/graphsage/train_multi.py
examples/graphsage/train_multi.py
+26
-53
examples/graphsage/train_scale.py
examples/graphsage/train_scale.py
+27
-58
pgl/graph.py
pgl/graph.py
+69
-7
pgl/graph_wrapper.py
pgl/graph_wrapper.py
+45
-40
pgl/utils/mp_reader.py
pgl/utils/mp_reader.py
+32
-13
未找到文件。
examples/graphsage/reader.py
浏览文件 @
0bd10e14
...
@@ -19,8 +19,8 @@ import pgl
...
@@ -19,8 +19,8 @@ import pgl
import
time
import
time
from
pgl.utils
import
mp_reader
from
pgl.utils
import
mp_reader
from
pgl.utils.logger
import
log
from
pgl.utils.logger
import
log
import
train
import
time
import
time
import
copy
def
node_batch_iter
(
nodes
,
node_label
,
batch_size
):
def
node_batch_iter
(
nodes
,
node_label
,
batch_size
):
...
@@ -46,12 +46,11 @@ def traverse(item):
...
@@ -46,12 +46,11 @@ def traverse(item):
yield
item
yield
item
def
flat_node_and_edge
(
nodes
,
eids
):
def
flat_node_and_edge
(
nodes
):
"""flat_node_and_edge
"""flat_node_and_edge
"""
"""
nodes
=
list
(
set
(
traverse
(
nodes
)))
nodes
=
list
(
set
(
traverse
(
nodes
)))
eids
=
list
(
set
(
traverse
(
eids
)))
return
nodes
return
nodes
,
eids
def
worker
(
batch_info
,
graph
,
graph_wrapper
,
samples
):
def
worker
(
batch_info
,
graph
,
graph_wrapper
,
samples
):
...
@@ -61,31 +60,42 @@ def worker(batch_info, graph, graph_wrapper, samples):
...
@@ -61,31 +60,42 @@ def worker(batch_info, graph, graph_wrapper, samples):
def
work
():
def
work
():
"""work
"""work
"""
"""
first
=
True
_graph_wrapper
=
copy
.
copy
(
graph_wrapper
)
_graph_wrapper
.
node_feat_tensor_dict
=
{}
for
batch_train_samples
,
batch_train_labels
in
batch_info
:
for
batch_train_samples
,
batch_train_labels
in
batch_info
:
start_nodes
=
batch_train_samples
start_nodes
=
batch_train_samples
nodes
=
start_nodes
nodes
=
start_nodes
e
id
s
=
[]
e
dge
s
=
[]
for
max_deg
in
samples
:
for
max_deg
in
samples
:
pred
,
pred_eid
=
graph
.
sample_predecessor
(
pred_nodes
=
graph
.
sample_predecessor
(
start_nodes
,
max_degree
=
max_deg
,
return_eids
=
True
)
start_nodes
,
max_degree
=
max_deg
)
for
dst_node
,
src_nodes
in
zip
(
start_nodes
,
pred_nodes
):
for
src_node
in
src_nodes
:
edges
.
append
((
src_node
,
dst_node
))
last_nodes
=
nodes
last_nodes
=
nodes
nodes
=
[
nodes
,
pred
]
nodes
=
[
nodes
,
pred_nodes
]
eids
=
[
eids
,
pred_eid
]
nodes
=
flat_node_and_edge
(
nodes
)
nodes
,
eids
=
flat_node_and_edge
(
nodes
,
eids
)
# Find new nodes
# Find new nodes
start_nodes
=
list
(
set
(
nodes
)
-
set
(
last_nodes
))
start_nodes
=
list
(
set
(
nodes
)
-
set
(
last_nodes
))
if
len
(
start_nodes
)
==
0
:
if
len
(
start_nodes
)
==
0
:
break
break
subgraph
=
graph
.
subgraph
(
nodes
=
nodes
,
eid
=
eids
)
subgraph
=
graph
.
subgraph
(
nodes
=
nodes
,
edges
=
edges
,
with_node_feat
=
False
,
with_edge_feat
=
False
)
sub_node_index
=
subgraph
.
reindex_from_parrent_nodes
(
sub_node_index
=
subgraph
.
reindex_from_parrent_nodes
(
batch_train_samples
)
batch_train_samples
)
feed_dict
=
graph_wrapper
.
to_feed
(
subgraph
)
feed_dict
=
_
graph_wrapper
.
to_feed
(
subgraph
)
feed_dict
[
"node_label"
]
=
np
.
expand_dims
(
feed_dict
[
"node_label"
]
=
np
.
expand_dims
(
np
.
array
(
np
.
array
(
batch_train_labels
,
dtype
=
"int64"
),
-
1
)
batch_train_labels
,
dtype
=
"int64"
),
-
1
)
feed_dict
[
"node_index"
]
=
sub_node_index
feed_dict
[
"node_index"
]
=
sub_node_index
feed_dict
[
"parent_node_index"
]
=
np
.
array
(
nodes
,
dtype
=
"int64"
)
yield
feed_dict
yield
feed_dict
return
work
return
work
...
@@ -97,23 +107,25 @@ def multiprocess_graph_reader(graph,
...
@@ -97,23 +107,25 @@ def multiprocess_graph_reader(graph,
node_index
,
node_index
,
batch_size
,
batch_size
,
node_label
,
node_label
,
with_parent_node_index
=
False
,
num_workers
=
4
):
num_workers
=
4
):
"""multiprocess_graph_reader
"""multiprocess_graph_reader
"""
"""
def
parse_to_subgraph
(
rd
):
def
parse_to_subgraph
(
rd
,
prefix
,
node_feat
,
_with_parent_node_index
):
"""parse_to_subgraph
"""parse_to_subgraph
"""
"""
def
work
():
def
work
():
"""work
"""work
"""
"""
last
=
time
.
time
()
for
data
in
rd
():
for
data
in
rd
():
this
=
time
.
time
()
feed_dict
=
data
feed_dict
=
data
now
=
time
.
time
()
for
key
in
node_feat
:
last
=
now
feed_dict
[
prefix
+
'/node_feat/'
+
key
]
=
node_feat
[
key
][
feed_dict
[
"parent_node_index"
]]
if
not
_with_parent_node_index
:
del
feed_dict
[
"parent_node_index"
]
yield
feed_dict
yield
feed_dict
return
work
return
work
...
@@ -129,46 +141,17 @@ def multiprocess_graph_reader(graph,
...
@@ -129,46 +141,17 @@ def multiprocess_graph_reader(graph,
reader_pool
.
append
(
reader_pool
.
append
(
worker
(
batch_info
[
block_size
*
i
:
block_size
*
(
i
+
1
)],
graph
,
worker
(
batch_info
[
block_size
*
i
:
block_size
*
(
i
+
1
)],
graph
,
graph_wrapper
,
samples
))
graph_wrapper
,
samples
))
multi_process_sample
=
mp_reader
.
multiprocess_reader
(
reader_pool
,
use_pipe
=
True
,
queue_size
=
1000
)
r
=
parse_to_subgraph
(
multi_process_sample
)
return
paddle
.
reader
.
buffered
(
r
,
1000
)
return
reader
()
def
graph_reader
(
graph
,
graph_wrapper
,
samples
,
node_index
,
batch_size
,
node_label
):
"""graph_reader"""
def
reader
():
"""reader"""
for
batch_train_samples
,
batch_train_labels
in
node_batch_iter
(
node_index
,
node_label
,
batch_size
=
batch_size
):
start_nodes
=
batch_train_samples
nodes
=
start_nodes
eids
=
[]
for
max_deg
in
samples
:
pred
,
pred_eid
=
graph
.
sample_predecessor
(
start_nodes
,
max_degree
=
max_deg
,
return_eids
=
True
)
last_nodes
=
nodes
nodes
=
[
nodes
,
pred
]
eids
=
[
eids
,
pred_eid
]
nodes
,
eids
=
flat_node_and_edge
(
nodes
,
eids
)
# Find new nodes
start_nodes
=
list
(
set
(
nodes
)
-
set
(
last_nodes
))
if
len
(
start_nodes
)
==
0
:
break
subgraph
=
graph
.
subgraph
(
nodes
=
nodes
,
eid
=
eids
)
if
len
(
reader_pool
)
==
1
:
feed_dict
=
graph_wrapper
.
to_feed
(
subgraph
)
r
=
parse_to_subgraph
(
reader_pool
[
0
],
sub_node_index
=
subgraph
.
reindex_from_parrent_nodes
(
repr
(
graph_wrapper
),
graph
.
node_feat
,
batch_train_samples
)
with_parent_node_index
)
else
:
multi_process_sample
=
mp_reader
.
multiprocess_reader
(
reader_pool
,
use_pipe
=
True
,
queue_size
=
1000
)
r
=
parse_to_subgraph
(
multi_process_sample
,
repr
(
graph_wrapper
),
graph
.
node_feat
,
with_parent_node_index
)
return
paddle
.
reader
.
buffered
(
r
,
num_workers
)
feed_dict
[
"node_label"
]
=
np
.
expand_dims
(
return
reader
()
np
.
array
(
batch_train_labels
,
dtype
=
"int64"
),
-
1
)
feed_dict
[
"node_index"
]
=
np
.
array
(
sub_node_index
,
dtype
=
"int32"
)
yield
feed_dict
return
paddle
.
reader
.
buffered
(
reader
,
1000
)
examples/graphsage/train.py
浏览文件 @
0bd10e14
...
@@ -63,10 +63,7 @@ def load_data(normalize=True, symmetry=True):
...
@@ -63,10 +63,7 @@ def load_data(normalize=True, symmetry=True):
log
.
info
(
"Feature shape %s"
%
(
repr
(
feature
.
shape
)))
log
.
info
(
"Feature shape %s"
%
(
repr
(
feature
.
shape
)))
graph
=
pgl
.
graph
.
Graph
(
graph
=
pgl
.
graph
.
Graph
(
num_nodes
=
feature
.
shape
[
0
],
num_nodes
=
feature
.
shape
[
0
],
edges
=
list
(
zip
(
src
,
dst
)))
edges
=
list
(
zip
(
src
,
dst
)),
node_feat
=
{
"index"
:
np
.
arange
(
0
,
len
(
feature
),
dtype
=
"int64"
)})
return
{
return
{
"graph"
:
graph
,
"graph"
:
graph
,
...
@@ -89,7 +86,13 @@ def build_graph_model(graph_wrapper, num_class, k_hop, graphsage_type,
...
@@ -89,7 +86,13 @@ def build_graph_model(graph_wrapper, num_class, k_hop, graphsage_type,
node_label
=
fluid
.
layers
.
data
(
node_label
=
fluid
.
layers
.
data
(
"node_label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
,
append_batch_size
=
False
)
"node_label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
,
append_batch_size
=
False
)
feature
=
fluid
.
layers
.
gather
(
feature
,
graph_wrapper
.
node_feat
[
'index'
])
parent_node_index
=
fluid
.
layers
.
data
(
"parent_node_index"
,
shape
=
[
None
],
dtype
=
"int64"
,
append_batch_size
=
False
)
feature
=
fluid
.
layers
.
gather
(
feature
,
parent_node_index
)
feature
.
stop_gradient
=
True
feature
.
stop_gradient
=
True
for
i
in
range
(
k_hop
):
for
i
in
range
(
k_hop
):
...
@@ -221,59 +224,35 @@ def main(args):
...
@@ -221,59 +224,35 @@ def main(args):
exe
.
run
(
startup_program
)
exe
.
run
(
startup_program
)
feature_init
(
place
)
feature_init
(
place
)
if
args
.
sample_workers
>
1
:
train_iter
=
reader
.
multiprocess_graph_reader
(
train_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
data
[
'graph'
],
graph_wrapper
,
graph_wrapper
,
samples
=
samples
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
with_parent_node_index
=
True
,
node_index
=
data
[
'train_index'
],
node_index
=
data
[
'train_index'
],
node_label
=
data
[
"train_label"
])
node_label
=
data
[
"train_label"
])
else
:
train_iter
=
reader
.
graph_reader
(
val_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
data
[
'graph'
],
graph_wrapper
,
graph_wrapper
,
samples
=
samples
,
samples
=
samples
,
batch_size
=
args
.
batch_size
,
num_workers
=
args
.
sample_workers
,
node_index
=
data
[
'train_index'
],
batch_size
=
args
.
batch_size
,
node_label
=
data
[
"train_label"
])
with_parent_node_index
=
True
,
node_index
=
data
[
'val_index'
],
if
args
.
sample_workers
>
1
:
node_label
=
data
[
"val_label"
])
val_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
test_iter
=
reader
.
multiprocess_graph_reader
(
graph_wrapper
,
data
[
'graph'
],
samples
=
samples
,
graph_wrapper
,
num_workers
=
args
.
sample_workers
,
samples
=
samples
,
batch_size
=
args
.
batch_size
,
num_workers
=
args
.
sample_workers
,
node_index
=
data
[
'val_index'
],
batch_size
=
args
.
batch_size
,
node_label
=
data
[
"val_label"
])
with_parent_node_index
=
True
,
else
:
node_index
=
data
[
'test_index'
],
val_iter
=
reader
.
graph_reader
(
node_label
=
data
[
"test_label"
])
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'val_index'
],
node_label
=
data
[
"val_label"
])
if
args
.
sample_workers
>
1
:
test_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'test_index'
],
node_label
=
data
[
"test_label"
])
else
:
test_iter
=
reader
.
graph_reader
(
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'test_index'
],
node_label
=
data
[
"test_label"
])
for
epoch
in
range
(
args
.
epoch
):
for
epoch
in
range
(
args
.
epoch
):
run_epoch
(
run_epoch
(
...
...
examples/graphsage/train_multi.py
浏览文件 @
0bd10e14
...
@@ -262,59 +262,32 @@ def main(args):
...
@@ -262,59 +262,32 @@ def main(args):
else
:
else
:
train_exe
=
exe
train_exe
=
exe
if
args
.
sample_workers
>
1
:
train_iter
=
reader
.
multiprocess_graph_reader
(
train_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
data
[
'graph'
],
graph_wrapper
,
graph_wrapper
,
samples
=
samples
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'train_index'
],
node_index
=
data
[
'train_index'
],
node_label
=
data
[
"train_label"
])
node_label
=
data
[
"train_label"
])
else
:
val_iter
=
reader
.
multiprocess_graph_reader
(
train_iter
=
reader
.
graph_reader
(
data
[
'graph'
],
data
[
'graph'
],
graph_wrapper
,
graph_wrapper
,
samples
=
samples
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'train_index'
],
node_index
=
data
[
'val_index'
],
node_label
=
data
[
"train_label"
])
node_label
=
data
[
"val_label"
])
if
args
.
sample_workers
>
1
:
test_iter
=
reader
.
multiprocess_graph_reader
(
val_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
data
[
'graph'
],
graph_wrapper
,
graph_wrapper
,
samples
=
samples
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'test_index'
],
node_index
=
data
[
'val_index'
],
node_label
=
data
[
"test_label"
])
node_label
=
data
[
"val_label"
])
else
:
val_iter
=
reader
.
graph_reader
(
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'val_index'
],
node_label
=
data
[
"val_label"
])
if
args
.
sample_workers
>
1
:
test_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'test_index'
],
node_label
=
data
[
"test_label"
])
else
:
test_iter
=
reader
.
graph_reader
(
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'test_index'
],
node_label
=
data
[
"test_label"
])
for
epoch
in
range
(
args
.
epoch
):
for
epoch
in
range
(
args
.
epoch
):
run_epoch
(
run_epoch
(
...
...
examples/graphsage/train_scale.py
浏览文件 @
0bd10e14
...
@@ -97,11 +97,7 @@ def load_data(normalize=True, symmetry=True, scale=1):
...
@@ -97,11 +97,7 @@ def load_data(normalize=True, symmetry=True, scale=1):
graph
=
pgl
.
graph
.
Graph
(
graph
=
pgl
.
graph
.
Graph
(
num_nodes
=
feature
.
shape
[
0
],
num_nodes
=
feature
.
shape
[
0
],
edges
=
edges
,
edges
=
edges
,
node_feat
=
{
node_feat
=
{
"feature"
:
feature
})
"index"
:
np
.
arange
(
0
,
len
(
feature
),
dtype
=
"int64"
),
"feature"
:
feature
})
return
{
return
{
"graph"
:
graph
,
"graph"
:
graph
,
...
@@ -244,59 +240,32 @@ def main(args):
...
@@ -244,59 +240,32 @@ def main(args):
test_program
=
train_program
.
clone
(
for_test
=
True
)
test_program
=
train_program
.
clone
(
for_test
=
True
)
if
args
.
sample_workers
>
1
:
train_iter
=
reader
.
multiprocess_graph_reader
(
train_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
data
[
'graph'
],
graph_wrapper
,
graph_wrapper
,
samples
=
samples
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'train_index'
],
node_index
=
data
[
'train_index'
],
node_label
=
data
[
"train_label"
])
node_label
=
data
[
"train_label"
])
else
:
val_iter
=
reader
.
multiprocess_graph_reader
(
train_iter
=
reader
.
graph_reader
(
data
[
'graph'
],
data
[
'graph'
],
graph_wrapper
,
graph_wrapper
,
samples
=
samples
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'train_index'
],
node_index
=
data
[
'val_index'
],
node_label
=
data
[
"train_label"
])
node_label
=
data
[
"val_label"
])
if
args
.
sample_workers
>
1
:
test_iter
=
reader
.
multiprocess_graph_reader
(
val_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
data
[
'graph'
],
graph_wrapper
,
graph_wrapper
,
samples
=
samples
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'test_index'
],
node_index
=
data
[
'val_index'
],
node_label
=
data
[
"test_label"
])
node_label
=
data
[
"val_label"
])
else
:
val_iter
=
reader
.
graph_reader
(
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'val_index'
],
node_label
=
data
[
"val_label"
])
if
args
.
sample_workers
>
1
:
test_iter
=
reader
.
multiprocess_graph_reader
(
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
num_workers
=
args
.
sample_workers
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'test_index'
],
node_label
=
data
[
"test_label"
])
else
:
test_iter
=
reader
.
graph_reader
(
data
[
'graph'
],
graph_wrapper
,
samples
=
samples
,
batch_size
=
args
.
batch_size
,
node_index
=
data
[
'test_index'
],
node_label
=
data
[
"test_label"
])
with
fluid
.
program_guard
(
train_program
,
startup_program
):
with
fluid
.
program_guard
(
train_program
,
startup_program
):
adam
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
args
.
lr
)
adam
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
args
.
lr
)
...
...
pgl/graph.py
浏览文件 @
0bd10e14
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
This package implement Graph structure for handling graph data.
This package implement Graph structure for handling graph data.
"""
"""
import
os
import
numpy
as
np
import
numpy
as
np
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
...
@@ -77,6 +78,15 @@ class EdgeIndex(object):
...
@@ -77,6 +78,15 @@ class EdgeIndex(object):
"""
"""
return
self
.
_sorted_u
,
self
.
_sorted_v
,
self
.
_sorted_eid
return
self
.
_sorted_u
,
self
.
_sorted_v
,
self
.
_sorted_eid
def
dump
(
self
,
path
):
if
not
os
.
path
.
exists
(
path
):
os
.
makedirs
(
path
)
np
.
save
(
path
+
'/degree.npy'
,
self
.
_degree
)
np
.
save
(
path
+
'/sorted_u.npy'
,
self
.
_sorted_u
)
np
.
save
(
path
+
'/sorted_v.npy'
,
self
.
_sorted_v
)
np
.
save
(
path
+
'/sorted_eid.npy'
,
self
.
_sorted_eid
)
np
.
save
(
path
+
'/indptr.npy'
,
self
.
_indptr
)
class
Graph
(
object
):
class
Graph
(
object
):
"""Implementation of graph structure in pgl.
"""Implementation of graph structure in pgl.
...
@@ -136,6 +146,18 @@ class Graph(object):
...
@@ -136,6 +146,18 @@ class Graph(object):
self
.
_adj_src_index
=
None
self
.
_adj_src_index
=
None
self
.
_adj_dst_index
=
None
self
.
_adj_dst_index
=
None
def
dump
(
self
,
path
):
if
not
os
.
path
.
exists
(
path
):
os
.
makedirs
(
path
)
np
.
save
(
path
+
'/num_nodes.npy'
,
self
.
_num_nodes
)
np
.
save
(
path
+
'/edges.npy'
,
self
.
_edges
)
if
self
.
_adj_src_index
:
self
.
_adj_src_index
.
dump
(
path
+
'/adj_src'
)
if
self
.
_adj_dst_index
:
self
.
_adj_dst_index
.
dump
(
path
+
'/adj_dst'
)
@
property
@
property
def
adj_src_index
(
self
):
def
adj_src_index
(
self
):
"""Return an EdgeIndex object for src.
"""Return an EdgeIndex object for src.
...
@@ -506,7 +528,13 @@ class Graph(object):
...
@@ -506,7 +528,13 @@ class Graph(object):
(
key
,
_hide_num_nodes
(
value
.
shape
),
value
.
dtype
))
(
key
,
_hide_num_nodes
(
value
.
shape
),
value
.
dtype
))
return
edge_feat_info
return
edge_feat_info
def
subgraph
(
self
,
nodes
,
eid
=
None
,
edges
=
None
):
def
subgraph
(
self
,
nodes
,
eid
=
None
,
edges
=
None
,
edge_feats
=
None
,
with_node_feat
=
True
,
with_edge_feat
=
True
):
"""Generate subgraph with nodes and edge ids.
"""Generate subgraph with nodes and edge ids.
This function will generate a :code:`pgl.graph.Subgraph` object and
This function will generate a :code:`pgl.graph.Subgraph` object and
...
@@ -521,6 +549,10 @@ class Graph(object):
...
@@ -521,6 +549,10 @@ class Graph(object):
eid (optional): Edge ids which will be included in the subgraph.
eid (optional): Edge ids which will be included in the subgraph.
edges (optional): Edge(src, dst) list which will be included in the subgraph.
edges (optional): Edge(src, dst) list which will be included in the subgraph.
with_node_feat: Whether to inherit node features from parent graph.
with_edge_feat: Whether to inherit edge features from parent graph.
Return:
Return:
A :code:`pgl.graph.Subgraph` object.
A :code:`pgl.graph.Subgraph` object.
...
@@ -543,14 +575,20 @@ class Graph(object):
...
@@ -543,14 +575,20 @@ class Graph(object):
len
(
edges
),
dtype
=
"int64"
),
edges
,
reindex
)
len
(
edges
),
dtype
=
"int64"
),
edges
,
reindex
)
sub_edge_feat
=
{}
sub_edge_feat
=
{}
for
key
,
value
in
self
.
_edge_feat
.
items
():
if
edges
is
None
:
if
eid
is
None
:
if
with_edge_feat
:
raise
ValueError
(
"Eid can not be None with edge features."
)
for
key
,
value
in
self
.
_edge_feat
.
items
():
sub_edge_feat
[
key
]
=
value
[
eid
]
if
eid
is
None
:
raise
ValueError
(
"Eid can not be None with edge features."
)
sub_edge_feat
[
key
]
=
value
[
eid
]
else
:
sub_edge_feat
=
edge_feats
sub_node_feat
=
{}
sub_node_feat
=
{}
for
key
,
value
in
self
.
_node_feat
.
items
():
if
with_node_feat
:
sub_node_feat
[
key
]
=
value
[
nodes
]
for
key
,
value
in
self
.
_node_feat
.
items
():
sub_node_feat
[
key
]
=
value
[
nodes
]
subgraph
=
SubGraph
(
subgraph
=
SubGraph
(
num_nodes
=
len
(
nodes
),
num_nodes
=
len
(
nodes
),
...
@@ -779,3 +817,27 @@ class SubGraph(Graph):
...
@@ -779,3 +817,27 @@ class SubGraph(Graph):
A list of node ids in parent graph.
A list of node ids in parent graph.
"""
"""
return
graph_kernel
.
map_nodes
(
nodes
,
self
.
_to_reindex
)
return
graph_kernel
.
map_nodes
(
nodes
,
self
.
_to_reindex
)
class
MemmapEdgeIndex
(
EdgeIndex
):
def
__init__
(
self
,
path
):
self
.
_degree
=
np
.
load
(
path
+
'/degree.npy'
,
mmap_mode
=
"r"
)
self
.
_sorted_u
=
np
.
load
(
path
+
'/sorted_u.npy'
,
mmap_mode
=
"r"
)
self
.
_sorted_v
=
np
.
load
(
path
+
'/sorted_v.npy'
,
mmap_mode
=
"r"
)
self
.
_sorted_eid
=
np
.
load
(
path
+
'/sorted_eid.npy'
,
mmap_mode
=
"r"
)
self
.
_indptr
=
np
.
load
(
path
+
'/indptr.npy'
,
mmap_mode
=
"r"
)
class
MemmapGraph
(
Graph
):
def
__init__
(
self
,
path
):
self
.
_num_nodes
=
np
.
load
(
path
+
'/num_nodes.npy'
)
self
.
_edges
=
np
.
load
(
path
+
'/edges.npy'
,
mmap_mode
=
"r"
)
if
os
.
path
.
exists
(
path
+
'/adj_src'
):
self
.
_adj_src_index
=
MemmapEdgeIndex
(
path
+
'/adj_src'
)
else
:
self
.
_adj_src_index
=
None
if
os
.
path
.
exists
(
path
+
'/adj_dst'
):
self
.
_adj_dst_index
=
MemmapEdgeIndex
(
path
+
'/adj_dst'
)
else
:
self
.
_adj_dst_index
=
None
pgl/graph_wrapper.py
浏览文件 @
0bd10e14
...
@@ -89,8 +89,8 @@ class BaseGraphWrapper(object):
...
@@ -89,8 +89,8 @@ class BaseGraphWrapper(object):
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_
node_feat_tensor_dict
=
{}
self
.
node_feat_tensor_dict
=
{}
self
.
_
edge_feat_tensor_dict
=
{}
self
.
edge_feat_tensor_dict
=
{}
self
.
_edges_src
=
None
self
.
_edges_src
=
None
self
.
_edges_dst
=
None
self
.
_edges_dst
=
None
self
.
_num_nodes
=
None
self
.
_num_nodes
=
None
...
@@ -98,6 +98,10 @@ class BaseGraphWrapper(object):
...
@@ -98,6 +98,10 @@ class BaseGraphWrapper(object):
self
.
_edge_uniq_dst
=
None
self
.
_edge_uniq_dst
=
None
self
.
_edge_uniq_dst_count
=
None
self
.
_edge_uniq_dst_count
=
None
self
.
_node_ids
=
None
self
.
_node_ids
=
None
self
.
_data_name_prefix
=
""
def
__repr__
(
self
):
return
self
.
_data_name_prefix
def
send
(
self
,
message_func
,
nfeat_list
=
None
,
efeat_list
=
None
):
def
send
(
self
,
message_func
,
nfeat_list
=
None
,
efeat_list
=
None
):
"""Send message from all src nodes to dst nodes.
"""Send message from all src nodes to dst nodes.
...
@@ -220,7 +224,7 @@ class BaseGraphWrapper(object):
...
@@ -220,7 +224,7 @@ class BaseGraphWrapper(object):
A dictionary whose keys are the feature names and the values
A dictionary whose keys are the feature names and the values
are feature tensor.
are feature tensor.
"""
"""
return
self
.
_
edge_feat_tensor_dict
return
self
.
edge_feat_tensor_dict
@
property
@
property
def
node_feat
(
self
):
def
node_feat
(
self
):
...
@@ -230,7 +234,7 @@ class BaseGraphWrapper(object):
...
@@ -230,7 +234,7 @@ class BaseGraphWrapper(object):
A dictionary whose keys are the feature names and the values
A dictionary whose keys are the feature names and the values
are feature tensor.
are feature tensor.
"""
"""
return
self
.
_
node_feat_tensor_dict
return
self
.
node_feat_tensor_dict
def
indegree
(
self
):
def
indegree
(
self
):
"""Return the indegree tensor for all nodes.
"""Return the indegree tensor for all nodes.
...
@@ -298,8 +302,8 @@ class StaticGraphWrapper(BaseGraphWrapper):
...
@@ -298,8 +302,8 @@ class StaticGraphWrapper(BaseGraphWrapper):
def
__init__
(
self
,
name
,
graph
,
place
):
def
__init__
(
self
,
name
,
graph
,
place
):
super
(
StaticGraphWrapper
,
self
).
__init__
()
super
(
StaticGraphWrapper
,
self
).
__init__
()
self
.
_data_name_prefix
=
name
self
.
_initializers
=
[]
self
.
_initializers
=
[]
self
.
__data_name_prefix
=
name
self
.
__create_graph_attr
(
graph
)
self
.
__create_graph_attr
(
graph
)
def
__create_graph_attr
(
self
,
graph
):
def
__create_graph_attr
(
self
,
graph
):
...
@@ -326,43 +330,43 @@ class StaticGraphWrapper(BaseGraphWrapper):
...
@@ -326,43 +330,43 @@ class StaticGraphWrapper(BaseGraphWrapper):
self
.
_edges_src
,
init
=
paddle_helper
.
constant
(
self
.
_edges_src
,
init
=
paddle_helper
.
constant
(
dtype
=
"int64"
,
dtype
=
"int64"
,
value
=
src
,
value
=
src
,
name
=
self
.
_
_
data_name_prefix
+
'/edges_src'
)
name
=
self
.
_data_name_prefix
+
'/edges_src'
)
self
.
_initializers
.
append
(
init
)
self
.
_initializers
.
append
(
init
)
self
.
_edges_dst
,
init
=
paddle_helper
.
constant
(
self
.
_edges_dst
,
init
=
paddle_helper
.
constant
(
dtype
=
"int64"
,
dtype
=
"int64"
,
value
=
dst
,
value
=
dst
,
name
=
self
.
_
_
data_name_prefix
+
'/edges_dst'
)
name
=
self
.
_data_name_prefix
+
'/edges_dst'
)
self
.
_initializers
.
append
(
init
)
self
.
_initializers
.
append
(
init
)
self
.
_num_nodes
,
init
=
paddle_helper
.
constant
(
self
.
_num_nodes
,
init
=
paddle_helper
.
constant
(
dtype
=
"int64"
,
dtype
=
"int64"
,
hide_batch_size
=
False
,
hide_batch_size
=
False
,
value
=
np
.
array
([
graph
.
num_nodes
]),
value
=
np
.
array
([
graph
.
num_nodes
]),
name
=
self
.
_
_
data_name_prefix
+
'/num_nodes'
)
name
=
self
.
_data_name_prefix
+
'/num_nodes'
)
self
.
_initializers
.
append
(
init
)
self
.
_initializers
.
append
(
init
)
self
.
_edge_uniq_dst
,
init
=
paddle_helper
.
constant
(
self
.
_edge_uniq_dst
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
"/uniq_dst"
,
name
=
self
.
_data_name_prefix
+
"/uniq_dst"
,
dtype
=
"int64"
,
dtype
=
"int64"
,
value
=
uniq_dst
)
value
=
uniq_dst
)
self
.
_initializers
.
append
(
init
)
self
.
_initializers
.
append
(
init
)
self
.
_edge_uniq_dst_count
,
init
=
paddle_helper
.
constant
(
self
.
_edge_uniq_dst_count
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
"/uniq_dst_count"
,
name
=
self
.
_data_name_prefix
+
"/uniq_dst_count"
,
dtype
=
"int32"
,
dtype
=
"int32"
,
value
=
uniq_dst_count
)
value
=
uniq_dst_count
)
self
.
_initializers
.
append
(
init
)
self
.
_initializers
.
append
(
init
)
node_ids_value
=
np
.
arange
(
0
,
graph
.
num_nodes
,
dtype
=
"int64"
)
node_ids_value
=
np
.
arange
(
0
,
graph
.
num_nodes
,
dtype
=
"int64"
)
self
.
_node_ids
,
init
=
paddle_helper
.
constant
(
self
.
_node_ids
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
"/node_ids"
,
name
=
self
.
_data_name_prefix
+
"/node_ids"
,
dtype
=
"int64"
,
dtype
=
"int64"
,
value
=
node_ids_value
)
value
=
node_ids_value
)
self
.
_initializers
.
append
(
init
)
self
.
_initializers
.
append
(
init
)
self
.
_indegree
,
init
=
paddle_helper
.
constant
(
self
.
_indegree
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
"/indegree"
,
name
=
self
.
_data_name_prefix
+
"/indegree"
,
dtype
=
"int64"
,
dtype
=
"int64"
,
value
=
indegree
)
value
=
indegree
)
self
.
_initializers
.
append
(
init
)
self
.
_initializers
.
append
(
init
)
...
@@ -373,9 +377,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
...
@@ -373,9 +377,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
for
node_feat_name
,
node_feat_value
in
node_feat
.
items
():
for
node_feat_name
,
node_feat_value
in
node_feat
.
items
():
node_feat_shape
=
node_feat_value
.
shape
node_feat_shape
=
node_feat_value
.
shape
node_feat_dtype
=
node_feat_value
.
dtype
node_feat_dtype
=
node_feat_value
.
dtype
self
.
_
node_feat_tensor_dict
[
self
.
node_feat_tensor_dict
[
node_feat_name
],
init
=
paddle_helper
.
constant
(
node_feat_name
],
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
'/node_feat/'
+
name
=
self
.
_data_name_prefix
+
'/node_feat/'
+
node_feat_name
,
node_feat_name
,
dtype
=
node_feat_dtype
,
dtype
=
node_feat_dtype
,
value
=
node_feat_value
)
value
=
node_feat_value
)
...
@@ -387,9 +391,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
...
@@ -387,9 +391,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
for
edge_feat_name
,
edge_feat_value
in
edge_feat
.
items
():
for
edge_feat_name
,
edge_feat_value
in
edge_feat
.
items
():
edge_feat_shape
=
edge_feat_value
.
shape
edge_feat_shape
=
edge_feat_value
.
shape
edge_feat_dtype
=
edge_feat_value
.
dtype
edge_feat_dtype
=
edge_feat_value
.
dtype
self
.
_
edge_feat_tensor_dict
[
self
.
edge_feat_tensor_dict
[
edge_feat_name
],
init
=
paddle_helper
.
constant
(
edge_feat_name
],
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
'/edge_feat/'
+
name
=
self
.
_data_name_prefix
+
'/edge_feat/'
+
edge_feat_name
,
edge_feat_name
,
dtype
=
edge_feat_dtype
,
dtype
=
edge_feat_dtype
,
value
=
edge_feat_value
)
value
=
edge_feat_value
)
...
@@ -477,8 +481,8 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -477,8 +481,8 @@ class GraphWrapper(BaseGraphWrapper):
def
__init__
(
self
,
name
,
place
,
node_feat
=
[],
edge_feat
=
[]):
def
__init__
(
self
,
name
,
place
,
node_feat
=
[],
edge_feat
=
[]):
super
(
GraphWrapper
,
self
).
__init__
()
super
(
GraphWrapper
,
self
).
__init__
()
# collect holders for PyReader
# collect holders for PyReader
self
.
_data_name_prefix
=
name
self
.
_holder_list
=
[]
self
.
_holder_list
=
[]
self
.
__data_name_prefix
=
name
self
.
_place
=
place
self
.
_place
=
place
self
.
__create_graph_attr_holders
()
self
.
__create_graph_attr_holders
()
for
node_feat_name
,
node_feat_shape
,
node_feat_dtype
in
node_feat
:
for
node_feat_name
,
node_feat_shape
,
node_feat_dtype
in
node_feat
:
...
@@ -493,43 +497,43 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -493,43 +497,43 @@ class GraphWrapper(BaseGraphWrapper):
"""Create data holders for graph attributes.
"""Create data holders for graph attributes.
"""
"""
self
.
_edges_src
=
fluid
.
layers
.
data
(
self
.
_edges_src
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/edges_src'
,
self
.
_data_name_prefix
+
'/edges_src'
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_edges_dst
=
fluid
.
layers
.
data
(
self
.
_edges_dst
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/edges_dst'
,
self
.
_data_name_prefix
+
'/edges_dst'
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_num_nodes
=
fluid
.
layers
.
data
(
self
.
_num_nodes
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/num_nodes'
,
self
.
_data_name_prefix
+
'/num_nodes'
,
shape
=
[
1
],
shape
=
[
1
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
'int64'
,
dtype
=
'int64'
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_edge_uniq_dst
=
fluid
.
layers
.
data
(
self
.
_edge_uniq_dst
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
"/uniq_dst"
,
self
.
_data_name_prefix
+
"/uniq_dst"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_edge_uniq_dst_count
=
fluid
.
layers
.
data
(
self
.
_edge_uniq_dst_count
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
"/uniq_dst_count"
,
self
.
_data_name_prefix
+
"/uniq_dst_count"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int32"
,
dtype
=
"int32"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_node_ids
=
fluid
.
layers
.
data
(
self
.
_node_ids
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
"/node_ids"
,
self
.
_data_name_prefix
+
"/node_ids"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_indegree
=
fluid
.
layers
.
data
(
self
.
_indegree
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
"/indegree"
,
self
.
_data_name_prefix
+
"/indegree"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
...
@@ -545,12 +549,12 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -545,12 +549,12 @@ class GraphWrapper(BaseGraphWrapper):
"""Create data holders for node features.
"""Create data holders for node features.
"""
"""
feat_holder
=
fluid
.
layers
.
data
(
feat_holder
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/node_feat/'
+
node_feat_name
,
self
.
_data_name_prefix
+
'/node_feat/'
+
node_feat_name
,
shape
=
node_feat_shape
,
shape
=
node_feat_shape
,
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
node_feat_dtype
,
dtype
=
node_feat_dtype
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_
node_feat_tensor_dict
[
node_feat_name
]
=
feat_holder
self
.
node_feat_tensor_dict
[
node_feat_name
]
=
feat_holder
self
.
_holder_list
.
append
(
feat_holder
)
self
.
_holder_list
.
append
(
feat_holder
)
def
__create_graph_edge_feat_holders
(
self
,
edge_feat_name
,
edge_feat_shape
,
def
__create_graph_edge_feat_holders
(
self
,
edge_feat_name
,
edge_feat_shape
,
...
@@ -558,12 +562,12 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -558,12 +562,12 @@ class GraphWrapper(BaseGraphWrapper):
"""Create edge holders for edge features.
"""Create edge holders for edge features.
"""
"""
feat_holder
=
fluid
.
layers
.
data
(
feat_holder
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/edge_feat/'
+
edge_feat_name
,
self
.
_data_name_prefix
+
'/edge_feat/'
+
edge_feat_name
,
shape
=
edge_feat_shape
,
shape
=
edge_feat_shape
,
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
edge_feat_dtype
,
dtype
=
edge_feat_dtype
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_
edge_feat_tensor_dict
[
edge_feat_name
]
=
feat_holder
self
.
edge_feat_tensor_dict
[
edge_feat_name
]
=
feat_holder
self
.
_holder_list
.
append
(
feat_holder
)
self
.
_holder_list
.
append
(
feat_holder
)
def
to_feed
(
self
,
graph
):
def
to_feed
(
self
,
graph
):
...
@@ -594,20 +598,21 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -594,20 +598,21 @@ class GraphWrapper(BaseGraphWrapper):
edge_feat
[
key
]
=
value
[
eid
]
edge_feat
[
key
]
=
value
[
eid
]
node_feat
=
graph
.
node_feat
node_feat
=
graph
.
node_feat
feed_dict
[
self
.
__data_name_prefix
+
'/edges_src'
]
=
src
feed_dict
[
self
.
_data_name_prefix
+
'/edges_src'
]
=
src
feed_dict
[
self
.
__data_name_prefix
+
'/edges_dst'
]
=
dst
feed_dict
[
self
.
_data_name_prefix
+
'/edges_dst'
]
=
dst
feed_dict
[
self
.
__data_name_prefix
+
'/num_nodes'
]
=
np
.
array
(
graph
.
num_nodes
)
feed_dict
[
self
.
_data_name_prefix
+
'/num_nodes'
]
=
np
.
array
(
feed_dict
[
self
.
__data_name_prefix
+
'/uniq_dst'
]
=
uniq_dst
graph
.
num_nodes
)
feed_dict
[
self
.
__data_name_prefix
+
'/uniq_dst_count'
]
=
uniq_dst_count
feed_dict
[
self
.
_data_name_prefix
+
'/uniq_dst'
]
=
uniq_dst
feed_dict
[
self
.
__data_name_prefix
+
'/node_ids'
]
=
graph
.
nodes
feed_dict
[
self
.
_data_name_prefix
+
'/uniq_dst_count'
]
=
uniq_dst_count
feed_dict
[
self
.
__data_name_prefix
+
'/indegree'
]
=
indegree
feed_dict
[
self
.
_data_name_prefix
+
'/node_ids'
]
=
graph
.
nodes
feed_dict
[
self
.
_data_name_prefix
+
'/indegree'
]
=
indegree
for
key
in
self
.
_node_feat_tensor_dict
:
feed_dict
[
self
.
__data_name_prefix
+
'/node_feat/'
+
for
key
in
self
.
node_feat_tensor_dict
:
feed_dict
[
self
.
_data_name_prefix
+
'/node_feat/'
+
key
]
=
node_feat
[
key
]
key
]
=
node_feat
[
key
]
for
key
in
self
.
_
edge_feat_tensor_dict
:
for
key
in
self
.
edge_feat_tensor_dict
:
feed_dict
[
self
.
_
_
data_name_prefix
+
'/edge_feat/'
+
feed_dict
[
self
.
_data_name_prefix
+
'/edge_feat/'
+
key
]
=
edge_feat
[
key
]
key
]
=
edge_feat
[
key
]
return
feed_dict
return
feed_dict
...
...
pgl/utils/mp_reader.py
浏览文件 @
0bd10e14
...
@@ -25,6 +25,8 @@ except:
...
@@ -25,6 +25,8 @@ except:
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
queue
import
Queue
import
threading
def
serialize_data
(
data
):
def
serialize_data
(
data
):
...
@@ -129,22 +131,39 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000, pipe_size=10):
...
@@ -129,22 +131,39 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000, pipe_size=10):
p
.
start
()
p
.
start
()
reader_num
=
len
(
readers
)
reader_num
=
len
(
readers
)
finish_num
=
0
conn_to_remove
=
[]
conn_to_remove
=
[]
finish_flag
=
np
.
zeros
(
len
(
conns
),
dtype
=
"int32"
)
finish_flag
=
np
.
zeros
(
len
(
conns
),
dtype
=
"int32"
)
start
=
time
.
time
()
def queue_worker(sub_conn, que):
    """Drain one child-process pipe connection into the shared queue.

    Repeatedly receives a serialized buffer from `sub_conn`, deserializes
    it, and forwards the sample to `que`. A `None` sample is the
    end-of-stream sentinel: it is still forwarded (so the consumer can
    count how many readers have finished), the connection is closed, and
    the worker exits.
    """
    while True:
        payload = deserialize_data(sub_conn.recv())
        if payload is None:
            # Forward the sentinel before closing so the consumer side
            # can account for this reader finishing.
            que.put(None)
            sub_conn.close()
            return
        que.put(payload)
thread_pool
=
[]
output_queue
=
Queue
(
maxsize
=
reader_num
)
for
i
in
range
(
reader_num
):
t
=
threading
.
Thread
(
target
=
queue_worker
,
args
=
(
conns
[
i
],
output_queue
))
t
.
daemon
=
True
t
.
start
()
thread_pool
.
append
(
t
)
finish_num
=
0
while
finish_num
<
reader_num
:
while
finish_num
<
reader_num
:
for
conn_id
,
conn
in
enumerate
(
conns
):
sample
=
output_queue
.
get
()
if
finish_flag
[
conn_id
]
>
0
:
if
sample
is
None
:
continue
finish_num
+=
1
if
conn
.
poll
(
0.01
):
else
:
buff
=
conn
.
recv
()
yield
sample
sample
=
deserialize_data
(
buff
)
if
sample
is
None
:
for
thread
in
thread_pool
:
finish_num
+=
1
thread
.
join
()
conn
.
close
()
finish_flag
[
conn_id
]
=
1
else
:
yield
sample
if
use_pipe
:
if
use_pipe
:
return
pipe_reader
return
pipe_reader
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录