PaddlePaddle / PGL
Commit d347a2bb (unverified)
Authored on Feb 17, 2020 by Huang Zhengjie; committed via GitHub on Feb 17, 2020

Merge pull request #25 from Liwb5/develop

develop PGL v1.1

Parents: c87716c4, 752b6169
Showing 30 changed files with 1576 additions and 305 deletions
examples/GATNE/Dataset.py  +1 -1
examples/GATNE/model.py  +1 -1
examples/graphsage/reader.py  +42 -59
examples/graphsage/train.py  +37 -58
examples/graphsage/train_multi.py  +26 -53
examples/graphsage/train_scale.py  +27 -58
examples/metapath2vec/Dataset.py  +11 -3
examples/metapath2vec/config.yaml  +2 -1
examples/metapath2vec/sample.py  +1 -1
ogb_examples/linkproppred/main_pgl.py  +208 -0
ogb_examples/nodeproppred/main_pgl.py  +176 -0
pgl/__init__.py  +2 -0
pgl/contrib/ogb/__init__.py  +13 -0
pgl/contrib/ogb/graphproppred/__init__.py  +14 -0
pgl/contrib/ogb/graphproppred/dataset_pgl.py  +152 -0
pgl/contrib/ogb/io/__init__.py  +2 -5
pgl/contrib/ogb/io/read_graph_pgl.py  +49 -0
pgl/contrib/ogb/linkproppred/__init__.py  +15 -0
pgl/contrib/ogb/linkproppred/dataset_pgl.py  +149 -0
pgl/contrib/ogb/nodeproppred/__init__.py  +15 -0
pgl/contrib/ogb/nodeproppred/dataset_pgl.py  +153 -0
pgl/graph.py  +69 -7
pgl/graph_wrapper.py  +45 -40
pgl/heter_graph.py  +0 -0
pgl/heter_graph_wrapper.py  +2 -2
pgl/redis_hetergraph.py  +1 -1
pgl/sample.py  +163 -2
pgl/tests/test_hetergraph.py  +92 -0
pgl/tests/test_metapath_randomwalk.py  +76 -0
pgl/utils/mp_reader.py  +32 -13
examples/GATNE/Dataset.py

@@ -21,7 +21,7 @@ import tqdm
 import numpy as np
 import logging
 import random
-from pgl.contrib import heter_graph
+from pgl import heter_graph
 import pickle as pkl
examples/GATNE/model.py

@@ -21,7 +21,7 @@ import logging
 import paddle.fluid as fluid
 import paddle.fluid.layers as fl
-from pgl.contrib import heter_graph_wrapper
+from pgl import heter_graph_wrapper


 class GATNE(object):
examples/graphsage/reader.py

@@ -19,8 +19,8 @@ import pgl
 import time
 from pgl.utils import mp_reader
 from pgl.utils.logger import log
 import train
-import time
+import copy


 def node_batch_iter(nodes, node_label, batch_size):

@@ -46,12 +46,11 @@ def traverse(item):
         yield item


-def flat_node_and_edge(nodes, eids):
+def flat_node_and_edge(nodes):
     """flat_node_and_edge
     """
     nodes = list(set(traverse(nodes)))
-    eids = list(set(traverse(eids)))
-    return nodes, eids
+    return nodes


 def worker(batch_info, graph, graph_wrapper, samples):

@@ -61,31 +60,42 @@ def worker(batch_info, graph, graph_wrapper, samples):
     def work():
         """work
         """
+        first = True
+        _graph_wrapper = copy.copy(graph_wrapper)
+        _graph_wrapper.node_feat_tensor_dict = {}
         for batch_train_samples, batch_train_labels in batch_info:
             start_nodes = batch_train_samples
             nodes = start_nodes
-            eids = []
+            edges = []
             for max_deg in samples:
-                pred, pred_eid = graph.sample_predecessor(
-                    start_nodes, max_degree=max_deg, return_eids=True)
+                pred_nodes = graph.sample_predecessor(
+                    start_nodes, max_degree=max_deg)
+                for dst_node, src_nodes in zip(start_nodes, pred_nodes):
+                    for src_node in src_nodes:
+                        edges.append((src_node, dst_node))
                 last_nodes = nodes
-                nodes = [nodes, pred]
-                eids = [eids, pred_eid]
-                nodes, eids = flat_node_and_edge(nodes, eids)
+                nodes = [nodes, pred_nodes]
+                nodes = flat_node_and_edge(nodes)
                 # Find new nodes
                 start_nodes = list(set(nodes) - set(last_nodes))
                 if len(start_nodes) == 0:
                     break
-            subgraph = graph.subgraph(nodes=nodes, eid=eids)
+            subgraph = graph.subgraph(
+                nodes=nodes,
+                edges=edges,
+                with_node_feat=False,
+                with_edge_feat=False)
             sub_node_index = subgraph.reindex_from_parrent_nodes(
                 batch_train_samples)
-            feed_dict = graph_wrapper.to_feed(subgraph)
+            feed_dict = _graph_wrapper.to_feed(subgraph)
             feed_dict["node_label"] = np.expand_dims(
                 np.array(batch_train_labels, dtype="int64"), -1)
             feed_dict["node_index"] = sub_node_index
+            feed_dict["parent_node_index"] = np.array(nodes, dtype="int64")
             yield feed_dict

     return work

@@ -97,23 +107,25 @@ def multiprocess_graph_reader(graph,
                               node_index,
                               batch_size,
                               node_label,
+                              with_parent_node_index=False,
                               num_workers=4):
     """multiprocess_graph_reader
     """

-    def parse_to_subgraph(rd):
+    def parse_to_subgraph(rd, prefix, node_feat, _with_parent_node_index):
         """parse_to_subgraph
         """

         def work():
             """work
             """
             last = time.time()
             for data in rd():
                 this = time.time()
                 feed_dict = data
                 now = time.time()
                 last = now
+                for key in node_feat:
+                    feed_dict[prefix + '/node_feat/' + key] = node_feat[key][
+                        feed_dict["parent_node_index"]]
+                if not _with_parent_node_index:
+                    del feed_dict["parent_node_index"]
                 yield feed_dict

         return work

@@ -129,46 +141,17 @@ def multiprocess_graph_reader(graph,
             reader_pool.append(
                 worker(batch_info[block_size * i:block_size * (i + 1)],
                        graph, graph_wrapper, samples))
-        multi_process_sample = mp_reader.multiprocess_reader(
-            reader_pool, use_pipe=True, queue_size=1000)
-        r = parse_to_subgraph(multi_process_sample)
-        return paddle.reader.buffered(r, 1000)
+        if len(reader_pool) == 1:
+            r = parse_to_subgraph(reader_pool[0],
+                                  repr(graph_wrapper), graph.node_feat,
+                                  with_parent_node_index)
+        else:
+            multi_process_sample = mp_reader.multiprocess_reader(
+                reader_pool, use_pipe=True, queue_size=1000)
+            r = parse_to_subgraph(multi_process_sample,
+                                  repr(graph_wrapper), graph.node_feat,
+                                  with_parent_node_index)
+        return paddle.reader.buffered(r, num_workers)

     return reader()
-
-
-def graph_reader(graph, graph_wrapper, samples, node_index, batch_size,
-                 node_label):
-    """graph_reader"""
-
-    def reader():
-        """reader"""
-        for batch_train_samples, batch_train_labels in node_batch_iter(
-                node_index, node_label, batch_size=batch_size):
-            start_nodes = batch_train_samples
-            nodes = start_nodes
-            eids = []
-            for max_deg in samples:
-                pred, pred_eid = graph.sample_predecessor(
-                    start_nodes, max_degree=max_deg, return_eids=True)
-                last_nodes = nodes
-                nodes = [nodes, pred]
-                eids = [eids, pred_eid]
-                nodes, eids = flat_node_and_edge(nodes, eids)
-                # Find new nodes
-                start_nodes = list(set(nodes) - set(last_nodes))
-                if len(start_nodes) == 0:
-                    break
-            subgraph = graph.subgraph(nodes=nodes, eid=eids)
-            feed_dict = graph_wrapper.to_feed(subgraph)
-            sub_node_index = subgraph.reindex_from_parrent_nodes(
-                batch_train_samples)
-            feed_dict["node_label"] = np.expand_dims(
-                np.array(batch_train_labels, dtype="int64"), -1)
-            feed_dict["node_index"] = np.array(sub_node_index, dtype="int32")
-            yield feed_dict
-
-    return paddle.reader.buffered(reader, 1000)
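The rewritten worker samples explicit (src, dst) edge lists and defers node-feature lookup to the consuming process via parent_node_index. A minimal sketch of that sampling pattern, assuming a pgl.graph.Graph g; sample_subgraph is a hypothetical helper, not code from this commit:

import numpy as np
import pgl

def sample_subgraph(g, start_nodes, samples):
    """Hedged sketch: k-hop predecessor sampling into a feature-less subgraph."""
    nodes, edges = list(start_nodes), []
    for max_deg in samples:
        pred_nodes = g.sample_predecessor(start_nodes, max_degree=max_deg)
        for dst, srcs in zip(start_nodes, pred_nodes):
            edges.extend((src, dst) for src in srcs)
        last_nodes = nodes
        nodes = list(set(nodes) | {n for ns in pred_nodes for n in ns})
        # only keep expanding from nodes we have not visited yet
        start_nodes = list(set(nodes) - set(last_nodes))
        if not start_nodes:
            break
    # features are gathered later from the parent graph through this index
    sub = g.subgraph(nodes=nodes, edges=edges,
                     with_node_feat=False, with_edge_feat=False)
    return sub, np.array(nodes, dtype="int64")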
examples/graphsage/train.py

@@ -63,10 +63,7 @@ def load_data(normalize=True, symmetry=True):
     log.info("Feature shape %s" % (repr(feature.shape)))
     graph = pgl.graph.Graph(
-        num_nodes=feature.shape[0],
-        edges=list(zip(src, dst)),
-        node_feat={"index": np.arange(
-            0, len(feature), dtype="int64")})
+        num_nodes=feature.shape[0], edges=list(zip(src, dst)))

     return {
         "graph": graph,

@@ -89,7 +86,13 @@ def build_graph_model(graph_wrapper, num_class, k_hop, graphsage_type,
     node_label = fluid.layers.data(
         "node_label", shape=[None, 1], dtype="int64", append_batch_size=False)

-    feature = fluid.layers.gather(feature, graph_wrapper.node_feat['index'])
+    parent_node_index = fluid.layers.data(
+        "parent_node_index",
+        shape=[None],
+        dtype="int64",
+        append_batch_size=False)
+
+    feature = fluid.layers.gather(feature, parent_node_index)
     feature.stop_gradient = True

     for i in range(k_hop):

@@ -221,57 +224,33 @@ def main(args):
     exe.run(startup_program)
     feature_init(place)

-    if args.sample_workers > 1:
-        train_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-    else:
-        train_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
+    train_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        with_parent_node_index=True,
+        node_index=data['train_index'],
+        node_label=data["train_label"])

-    if args.sample_workers > 1:
-        val_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-    else:
-        val_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
+    val_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        with_parent_node_index=True,
+        node_index=data['val_index'],
+        node_label=data["val_label"])

-    if args.sample_workers > 1:
-        test_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
-    else:
-        test_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
+    test_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        with_parent_node_index=True,
+        node_index=data['test_index'],
+        node_label=data["test_label"])
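A reader built this way is an ordinary Python generator of feed dicts, so a training loop just drains it. A hedged outline using the example's names (exe, train_program, loss), not the file's exact loop:

for feed_dict in train_iter():
    # each feed_dict carries the subgraph tensors plus node_label,
    # node_index, and the node-feature arrays keyed by the wrapper prefix
    train_loss, = exe.run(train_program, feed=feed_dict, fetch_list=[loss])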
examples/graphsage/train_multi.py

@@ -262,7 +262,6 @@ def main(args):
     else:
         train_exe = exe

-    if args.sample_workers > 1:
-        train_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
+    train_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
@@ -271,16 +270,7 @@ def main(args):
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-    else:
-        train_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['train_index'],
+        node_label=data["train_label"])

-    if args.sample_workers > 1:
-        val_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
+    val_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
@@ -289,16 +279,7 @@ def main(args):
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-    else:
-        val_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['val_index'],
+        node_label=data["val_label"])

-    if args.sample_workers > 1:
-        test_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
+    test_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
@@ -307,14 +288,6 @@ def main(args):
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
-    else:
-        test_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['test_index'],
+        node_label=data["test_label"])

     for epoch in range(args.epoch):
         run_epoch(
examples/graphsage/train_scale.py

@@ -97,11 +97,7 @@ def load_data(normalize=True, symmetry=True, scale=1):
     graph = pgl.graph.Graph(
         num_nodes=feature.shape[0],
         edges=edges,
-        node_feat={
-            "index": np.arange(
-                0, len(feature), dtype="int64"),
-            "feature": feature
-        })
+        node_feat={"feature": feature})

     return {
         "graph": graph,

@@ -244,7 +240,6 @@ def main(args):
     test_program = train_program.clone(for_test=True)

-    if args.sample_workers > 1:
-        train_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
+    train_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
@@ -253,16 +248,7 @@ def main(args):
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
-    else:
-        train_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['train_index'],
-            node_label=data["train_label"])
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['train_index'],
+        node_label=data["train_label"])

-    if args.sample_workers > 1:
-        val_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
+    val_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
@@ -271,16 +257,7 @@ def main(args):
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
-    else:
-        val_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['val_index'],
-            node_label=data["val_label"])
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['val_index'],
+        node_label=data["val_label"])

-    if args.sample_workers > 1:
-        test_iter = reader.multiprocess_graph_reader(
-            data['graph'],
-            graph_wrapper,
+    test_iter = reader.multiprocess_graph_reader(
+        data['graph'],
+        graph_wrapper,
@@ -289,14 +266,6 @@ def main(args):
-            samples=samples,
-            num_workers=args.sample_workers,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
-    else:
-        test_iter = reader.graph_reader(
-            data['graph'],
-            graph_wrapper,
-            samples=samples,
-            batch_size=args.batch_size,
-            node_index=data['test_index'],
-            node_label=data["test_label"])
+        samples=samples,
+        num_workers=args.sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['test_index'],
+        node_label=data["test_label"])

     with fluid.program_guard(train_program, startup_program):
         adam = fluid.optimizer.Adam(learning_rate=args.lr)
examples/metapath2vec/Dataset.py

@@ -23,7 +23,7 @@ import tqdm
 import time
 import logging
 import random
-from pgl.contrib import heter_graph
+from pgl import heter_graph
 import pickle as pkl

@@ -71,6 +71,10 @@ class Dataset(object):
                 if len(walk) > 1:
                     self.sentences_count += 1
                     for word in walk:
-                        self.token_count += 1
-                        word_freq[word] = word_freq.get(word, 0) + 1
+                        if int(word) >= self.config['paper_start_index']:
+                            # remove paper
+                            continue
+                        else:
+                            self.token_count += 1
+                            word_freq[word] = word_freq.get(word, 0) + 1

@@ -126,6 +130,10 @@ class Dataset(object):
         with open(filename) as reader:
             for line in reader:
                 words = line.strip().split()
+                words = [
+                    w for w in words
+                    if int(w) < self.config['paper_start_index']
+                ]
                 if len(words) > 1:
                     word_ids = [
                         self.word2id[w] for w in words if w in self.word2id
examples/metapath2vec/config.yaml

@@ -42,9 +42,10 @@ data_loader:
     walk_path: walks/*
     word2id_file: word2id.pkl
     batch_size: 32
-    win_size: 7  # default: 7
+    win_size: 5  # default: 7
     neg_num: 5
     min_count: 10
+    paper_start_index: 1697414

 model:
     type: SkipgramModel
examples/metapath2vec/sample.py

@@ -28,7 +28,7 @@ import tqdm
 import time
 import logging
 import random
-from pgl.contrib import heter_graph
+from pgl import heter_graph
 from pgl.sample import metapath_randomwalk
 from utils import *
ogb_examples/linkproppred/main_pgl.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test ogb
"""
import argparse

import pgl
import numpy as np
import paddle.fluid as fluid
from pgl.contrib.ogb.linkproppred.dataset_pgl import PglLinkPropPredDataset
from pgl.utils import paddle_helper
from ogb.linkproppred import Evaluator


def send_func(src_feat, dst_feat, edge_feat):
    """send_func"""
    return src_feat["h"]


def recv_func(feat):
    """recv_func"""
    return fluid.layers.sequence_pool(feat, pool_type="sum")


class GNNModel(object):
    """GNNModel"""

    def __init__(self, name, num_nodes, emb_dim, num_layers):
        self.num_nodes = num_nodes
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.name = name

        self.src_nodes = fluid.layers.data(
            name='src_nodes',
            shape=[None, 1],
            dtype='int64', )
        self.dst_nodes = fluid.layers.data(
            name='dst_nodes',
            shape=[None, 1],
            dtype='int64', )
        self.edge_label = fluid.layers.data(
            name='edge_label',
            shape=[None, 1],
            dtype='float32', )

    def forward(self, graph):
        """forward"""
        h = fluid.layers.create_parameter(
            shape=[self.num_nodes, self.emb_dim],
            dtype="float32",
            name=self.name + "_embedding")
        # edge_attr = fluid.layers.fc(graph.edge_feat["feat"], size=self.emb_dim)

        for layer in range(self.num_layers):
            msg = graph.send(
                send_func,
                nfeat_list=[("h", h)], )
            h = graph.recv(msg, recv_func)
            h = fluid.layers.fc(
                h,
                size=self.emb_dim,
                bias_attr=False,
                param_attr=fluid.ParamAttr(name=self.name + '_%s' % layer))
            h = h * graph.node_feat["norm"]
            bias = fluid.layers.create_parameter(
                shape=[self.emb_dim],
                dtype='float32',
                is_bias=True,
                name=self.name + '_bias_%s' % layer)
            h = fluid.layers.elementwise_add(h, bias, act="relu")

        src = fluid.layers.gather(h, self.src_nodes)
        dst = fluid.layers.gather(h, self.dst_nodes)
        edge_embed = src * dst
        pred = fluid.layers.fc(input=edge_embed,
                               size=1,
                               name=self.name + "_pred_output")
        prob = fluid.layers.sigmoid(pred)
        loss = fluid.layers.sigmoid_cross_entropy_with_logits(
            pred, self.edge_label)
        loss = fluid.layers.reduce_mean(loss)
        return pred, prob, loss


def main():
    """main
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Graph Dataset')
    parser.add_argument(
        '--epochs',
        type=int,
        default=100,
        help='number of epochs to train (default: 100)')
    parser.add_argument(
        '--dataset',
        type=str,
        default="ogbl-ppa",
        help='dataset name (default: protein protein associations)')
    args = parser.parse_args()

    #place = fluid.CUDAPlace(0)
    place = fluid.CPUPlace()  # Dataset too big to use GPU

    ### automatic dataloading and splitting
    print("loadding dataset")
    dataset = PglLinkPropPredDataset(name=args.dataset)
    splitted_edge = dataset.get_edge_split()
    print(splitted_edge['train_edge'].shape)
    print(splitted_edge['train_edge_label'].shape)

    print("building evaluator")
    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(args.dataset)

    graph_data = dataset[0]
    print("num_nodes: %d" % graph_data.num_nodes)

    train_program = fluid.Program()
    startup_program = fluid.Program()
    test_program = fluid.Program()

    # degree normalize
    indegree = graph_data.indegree()
    norm = np.zeros_like(indegree, dtype="float32")
    norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
    graph_data.node_feat["norm"] = np.expand_dims(norm, -1).astype("float32")

    with fluid.program_guard(train_program, startup_program):
        model = GNNModel(
            name="gnn",
            num_nodes=graph_data.num_nodes,
            emb_dim=64,
            num_layers=2)
        gw = pgl.graph_wrapper.GraphWrapper(
            "graph",
            place,
            node_feat=graph_data.node_feat_info(),
            edge_feat=graph_data.edge_feat_info())
        pred, prob, loss = model.forward(gw)

    val_program = train_program.clone(for_test=True)

    with fluid.program_guard(train_program, startup_program):
        adam = fluid.optimizer.Adam(
            learning_rate=1e-2,
            regularization=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0005))
        adam.minimize(loss)

    exe = fluid.Executor(place)
    exe.run(startup_program)

    feed = gw.to_feed(graph_data)
    for epoch in range(1, args.epochs + 1):
        feed['src_nodes'] = splitted_edge["train_edge"][:, 0].reshape(-1, 1)
        feed['dst_nodes'] = splitted_edge["train_edge"][:, 1].reshape(-1, 1)
        feed['edge_label'] = splitted_edge["train_edge_label"].astype(
            "float32").reshape(-1, 1)
        res_loss, y_pred = exe.run(train_program,
                                   feed=feed,
                                   fetch_list=[loss, prob])
        print("Loss %s" % res_loss[0])

        result = {}
        print("Evaluating...")
        feed['src_nodes'] = splitted_edge["valid_edge"][:, 0].reshape(-1, 1)
        feed['dst_nodes'] = splitted_edge["valid_edge"][:, 1].reshape(-1, 1)
        feed['edge_label'] = splitted_edge["valid_edge_label"].astype(
            "float32").reshape(-1, 1)
        y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
        input_dict = {
            "y_true": splitted_edge["valid_edge_label"],
            "y_pred": y_pred.reshape(-1, ),
        }
        result["valid"] = evaluator.eval(input_dict)

        feed['src_nodes'] = splitted_edge["test_edge"][:, 0].reshape(-1, 1)
        feed['dst_nodes'] = splitted_edge["test_edge"][:, 1].reshape(-1, 1)
        feed['edge_label'] = splitted_edge["test_edge_label"].astype(
            "float32").reshape(-1, 1)
        y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
        input_dict = {
            "y_true": splitted_edge["test_edge_label"],
            "y_pred": y_pred.reshape(-1, ),
        }
        result["test"] = evaluator.eval(input_dict)
        print(result)


if __name__ == "__main__":
    main()
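The d^{-1/2} scaling above is the usual GCN-style degree normalization, written to guard isolated nodes. A quick standalone check of that exact snippet on assumed toy degrees:

import numpy as np

indegree = np.array([0, 1, 4])
norm = np.zeros_like(indegree, dtype="float32")
norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
print(norm)  # [0.  1.  0.5]; isolated nodes get 0 rather than a divide-by-zero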
ogb_examples/nodeproppred/main_pgl.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test ogb
"""
import argparse

import pgl
import numpy as np
import paddle.fluid as fluid
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
from pgl.utils import paddle_helper
from ogb.nodeproppred import Evaluator


def train():
    pass


def send_func(src_feat, dst_feat, edge_feat):
    return (src_feat["h"] + edge_feat["h"]) * src_feat["norm"]


class GNNModel(object):
    def __init__(self, name, emb_dim, num_task, num_layers):
        self.num_task = num_task
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.name = name

    def forward(self, graph):
        h = fluid.layers.embedding(
            graph.node_feat["x"],
            size=(2, self.emb_dim))  # name=self.name + "_embedding")
        edge_attr = fluid.layers.fc(graph.edge_feat["feat"], size=self.emb_dim)
        for layer in range(self.num_layers):
            msg = graph.send(
                send_func,
                nfeat_list=[("h", h), ("norm", graph.node_feat["norm"])],
                efeat_list=[("h", edge_attr)])
            h = graph.recv(msg, "sum")
            h = fluid.layers.fc(
                h,
                size=self.emb_dim,
                bias_attr=False,
                param_attr=fluid.ParamAttr(name=self.name + '_%s' % layer))
            h = h * graph.node_feat["norm"]
            bias = fluid.layers.create_parameter(
                shape=[self.emb_dim],
                dtype='float32',
                is_bias=True,
                name=self.name + '_bias_%s' % layer)
            h = fluid.layers.elementwise_add(h, bias, act="relu")

        pred = fluid.layers.fc(h,
                               self.num_task,
                               act=None,
                               name=self.name + "_pred_output")
        return pred


def main():
    """main
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Graph Dataset')
    parser.add_argument(
        '--epochs',
        type=int,
        default=100,
        help='number of epochs to train (default: 100)')
    parser.add_argument(
        '--dataset',
        type=str,
        default="ogbn-proteins",
        help='dataset name (default: proteinfunc)')
    args = parser.parse_args()

    #device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    #place = fluid.CUDAPlace(0)
    place = fluid.CPUPlace()  # Dataset too big to use GPU

    ### automatic dataloading and splitting
    dataset = PglNodePropPredDataset(name=args.dataset)
    splitted_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(args.dataset)

    graph_data, label = dataset[0]

    train_program = fluid.Program()
    startup_program = fluid.Program()
    test_program = fluid.Program()

    # degree normalize
    indegree = graph_data.indegree()
    norm = np.zeros_like(indegree, dtype="float32")
    norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
    graph_data.node_feat["norm"] = np.expand_dims(norm, -1).astype("float32")
    graph_data.node_feat["x"] = np.zeros((len(indegree), 1), dtype="int64")
    graph_data.edge_feat["feat"] = graph_data.edge_feat["feat"].astype(
        "float32")

    model = GNNModel(
        name="gnn", num_task=dataset.num_tasks, emb_dim=64, num_layers=2)

    with fluid.program_guard(train_program, startup_program):
        gw = pgl.graph_wrapper.StaticGraphWrapper("graph", graph_data, place)
        pred = model.forward(gw)
        sigmoid_pred = fluid.layers.sigmoid(pred)

    val_program = train_program.clone(for_test=True)

    initializer = []
    with fluid.program_guard(train_program, startup_program):
        train_node_index, init = paddle_helper.constant(
            "train_node_index", dtype="int64", value=splitted_idx["train"])
        initializer.append(init)

        train_node_label, init = paddle_helper.constant(
            "train_node_label",
            dtype="float32",
            value=label[splitted_idx["train"]].astype("float32"))
        initializer.append(init)

        train_pred_t = fluid.layers.gather(pred, train_node_index)
        train_loss_t = fluid.layers.sigmoid_cross_entropy_with_logits(
            x=train_pred_t, label=train_node_label)
        train_loss_t = fluid.layers.reduce_sum(train_loss_t)
        train_pred_t = fluid.layers.sigmoid(train_pred_t)

        adam = fluid.optimizer.Adam(
            learning_rate=1e-2,
            regularization=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0005))
        adam.minimize(train_loss_t)

    exe = fluid.Executor(place)
    exe.run(startup_program)
    gw.initialize(place)
    for init in initializer:
        init(place)

    for epoch in range(1, args.epochs + 1):
        loss = exe.run(train_program, feed={}, fetch_list=[train_loss_t])
        print("Loss %s" % loss[0])
        print("Evaluating...")
        y_pred = exe.run(val_program, feed={}, fetch_list=[sigmoid_pred])[0]
        result = {}
        input_dict = {
            "y_true": label[splitted_idx["train"]],
            "y_pred": y_pred[splitted_idx["train"]]
        }
        result["train"] = evaluator.eval(input_dict)
        input_dict = {
            "y_true": label[splitted_idx["valid"]],
            "y_pred": y_pred[splitted_idx["valid"]]
        }
        result["valid"] = evaluator.eval(input_dict)
        input_dict = {
            "y_true": label[splitted_idx["test"]],
            "y_pred": y_pred[splitted_idx["test"]]
        }
        result["test"] = evaluator.eval(input_dict)
        print(result)


if __name__ == "__main__":
    main()
pgl/__init__.py
@@ -18,4 +18,6 @@ from pgl import layers
 from pgl import graph_wrapper
 from pgl import graph
 from pgl import data_loader
+from pgl import heter_graph
+from pgl import heter_graph_wrapper
 from pgl import contrib
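These two new exports promote the heterogeneous-graph modules to the top-level package, so downstream code drops the contrib segment from its imports, exactly as the example updates in this commit do:

# before this commit
from pgl.contrib import heter_graph, heter_graph_wrapper
# after this commit
from pgl import heter_graph, heter_graph_wrapper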
pgl/contrib/ogb/__init__.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pgl/contrib/ogb/graphproppred/__init__.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__.py"""
pgl/contrib/ogb/graphproppred/dataset_pgl.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PglGraphPropPredDataset
"""
import pandas as pd
import shutil, os
import os.path as osp
import numpy as np
from ogb.utils.url import decide_download, download_url, extract_zip
from ogb.graphproppred import make_master_file
from pgl.contrib.ogb.io.read_graph_pgl import read_csv_graph_pgl


def to_bool(value):
    """to_bool"""
    return np.array([value], dtype="bool")[0]


class PglGraphPropPredDataset(object):
    """PglGraphPropPredDataset"""

    def __init__(self, name, root="dataset"):
        self.name = name  ## original name, e.g., ogbg-mol-tox21
        self.dir_name = "_".join(
            name.split("-")
        ) + "_pgl"  ## replace hyphen with underline, e.g., ogbg_mol_tox21_dgl

        self.original_root = root
        self.root = osp.join(root, self.dir_name)

        self.meta_info = make_master_file.df
        #pd.read_csv(
        #os.path.join(os.path.dirname(__file__), "master.csv"), index_col=0)

        if not self.name in self.meta_info:
            print(self.name)
            error_mssg = "Invalid dataset name {}.\n".format(self.name)
            error_mssg += "Available datasets are as follows:\n"
            error_mssg += "\n".join(self.meta_info.keys())
            raise ValueError(error_mssg)

        self.download_name = self.meta_info[self.name][
            "download_name"]  ## name of downloaded file, e.g., tox21

        self.num_tasks = int(self.meta_info[self.name]["num tasks"])
        self.task_type = self.meta_info[self.name]["task type"]

        super(PglGraphPropPredDataset, self).__init__()

        self.pre_process()

    def pre_process(self):
        """Pre-processing"""
        processed_dir = osp.join(self.root, 'processed')
        raw_dir = osp.join(self.root, 'raw')
        pre_processed_file_path = osp.join(processed_dir,
                                           'pgl_data_processed')

        if os.path.exists(pre_processed_file_path):
            # TODO: Load Preprocessed
            pass
        else:
            ### download
            url = self.meta_info[self.name]["url"]
            if decide_download(url):
                path = download_url(url, self.original_root)
                extract_zip(path, self.original_root)
                os.unlink(path)
                # delete folder if there exists
                try:
                    shutil.rmtree(self.root)
                except:
                    pass
                shutil.move(
                    osp.join(self.original_root, self.download_name),
                    self.root)
            else:
                print("Stop download.")
                exit(-1)

            ### preprocess
            add_inverse_edge = to_bool(self.meta_info[self.name][
                "add_inverse_edge"])
            self.graphs = read_csv_graph_pgl(
                raw_dir, add_inverse_edge=add_inverse_edge)
            self.graphs = np.array(self.graphs)
            self.labels = np.array(
                pd.read_csv(
                    osp.join(raw_dir, "graph-label.csv.gz"),
                    compression="gzip",
                    header=None).values)
            # TODO: Load Graph

            ### load preprocessed files

    def get_idx_split(self):
        """Train/Valid/Test split"""
        split_type = self.meta_info[self.name]["split"]
        path = osp.join(self.root, "split", split_type)

        train_idx = pd.read_csv(
            osp.join(path, "train.csv.gz"), compression="gzip",
            header=None).values.T[0]
        valid_idx = pd.read_csv(
            osp.join(path, "valid.csv.gz"), compression="gzip",
            header=None).values.T[0]
        test_idx = pd.read_csv(
            osp.join(path, "test.csv.gz"), compression="gzip",
            header=None).values.T[0]

        return {
            "train": np.array(
                train_idx, dtype="int64"),
            "valid": np.array(
                valid_idx, dtype="int64"),
            "test": np.array(
                test_idx, dtype="int64")
        }

    def __getitem__(self, idx):
        """Get datapoint with index"""
        return self.graphs[idx], self.labels[idx]

    def __len__(self):
        """Length of the dataset
        Returns
        -------
        int
            Length of Dataset
        """
        return len(self.graphs)

    def __repr__(self):  # pragma: no cover
        return '{}({})'.format(self.__class__.__name__, len(self))


if __name__ == "__main__":
    pgl_dataset = PglGraphPropPredDataset(name="ogbg-mol-bace")
    splitted_index = pgl_dataset.get_idx_split()
    print(pgl_dataset)
    print(pgl_dataset[3:20])
    #print(pgl_dataset[splitted_index["train"]])
    #print(pgl_dataset[splitted_index["valid"]])
    #print(pgl_dataset[splitted_index["test"]])
pgl/contrib/__init__.py → pgl/contrib/ogb/io/__init__.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,8 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Generate Contrib api
-"""
-from pgl.contrib import heter_graph
-from pgl.contrib import heter_graph_wrapper
+"""__init__.py
+"""
pgl/contrib/ogb/io/read_graph_pgl.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""pgl read_csv_graph for ogb
"""
import pandas as pd
import os.path as osp
import numpy as np
import pgl
from ogb.io.read_graph_raw import read_csv_graph_raw


def read_csv_graph_pgl(raw_dir, add_inverse_edge=False):
    """Read CSV data and build PGL Graph
    """
    graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge)
    pgl_graph_list = []

    for graph in graph_list:
        edges = list(zip(graph["edge_index"][0], graph["edge_index"][1]))
        g = pgl.graph.Graph(num_nodes=graph["num_nodes"], edges=edges)
        if graph["edge_feat"] is not None:
            g.edge_feat["feat"] = graph["edge_feat"]
        if graph["node_feat"] is not None:
            g.node_feat["feat"] = graph["node_feat"]
        pgl_graph_list.append(g)

    return pgl_graph_list


if __name__ == "__main__":
    # graph_list = read_csv_graph_dgl('dataset/proteinfunc_v2/raw', add_inverse_edge = True)
    graph_list = read_csv_graph_pgl(
        'dataset/ogbn_proteins_pgl/raw', add_inverse_edge=True)
    print(graph_list)
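read_csv_graph_raw yields plain dicts, so the conversion above is mechanical. A hedged, self-contained sketch of the same construction on toy data (not OGB's files):

import numpy as np
import pgl

# hypothetical 3-node graph in the dict layout read_csv_graph_raw emits
raw = {
    "num_nodes": 3,
    "edge_index": np.array([[0, 1], [1, 2]]),  # row 0 = srcs, row 1 = dsts
    "node_feat": np.ones((3, 4), dtype="float32"),
    "edge_feat": None,
}
edges = list(zip(raw["edge_index"][0], raw["edge_index"][1]))
g = pgl.graph.Graph(num_nodes=raw["num_nodes"], edges=edges)
if raw["node_feat"] is not None:
    g.node_feat["feat"] = raw["node_feat"]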
pgl/contrib/ogb/linkproppred/__init__.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__.py
"""
pgl/contrib/ogb/linkproppred/dataset_pgl.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LinkPropPredDataset for pgl
"""
import pandas as pd
import shutil, os
import os.path as osp
import numpy as np
from ogb.utils.url import decide_download, download_url, extract_zip
from ogb.linkproppred import make_master_file
from pgl.contrib.ogb.io.read_graph_pgl import read_csv_graph_pgl


def to_bool(value):
    """to_bool"""
    return np.array([value], dtype="bool")[0]


class PglLinkPropPredDataset(object):
    """PglLinkPropPredDataset
    """

    def __init__(self, name, root="dataset"):
        self.name = name  ## original name, e.g., ogbl-ppa
        self.dir_name = "_".join(name.split(
            "-")) + "_pgl"  ## replace hyphen with underline, e.g., ogbl_ppa_pgl

        self.original_root = root
        self.root = osp.join(root, self.dir_name)

        self.meta_info = make_master_file.df
        #pd.read_csv(os.path.join(os.path.dirname(__file__), "master.csv"), index_col=0)

        if not self.name in self.meta_info:
            print(self.name)
            error_mssg = "Invalid dataset name {}.\n".format(self.name)
            error_mssg += "Available datasets are as follows:\n"
            error_mssg += "\n".join(self.meta_info.keys())
            raise ValueError(error_mssg)

        self.download_name = self.meta_info[self.name][
            "download_name"]  ## name of downloaded file, e.g., ppassoc

        self.task_type = self.meta_info[self.name]["task type"]

        super(PglLinkPropPredDataset, self).__init__()

        self.pre_process()

    def pre_process(self):
        """pre_process downlaoding data
        """
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir,
                                           'dgl_data_processed')

        if osp.exists(pre_processed_file_path):
            #TODO: Reload Preprocess files
            pass
        else:
            ### check download
            if not osp.exists(osp.join(self.root, "raw", "edge.csv.gz")):
                url = self.meta_info[self.name]["url"]
                if decide_download(url):
                    path = download_url(url, self.original_root)
                    extract_zip(path, self.original_root)
                    os.unlink(path)
                    # delete folder if there exists
                    try:
                        shutil.rmtree(self.root)
                    except:
                        pass
                    shutil.move(
                        osp.join(self.original_root, self.download_name),
                        self.root)
                else:
                    print("Stop download.")
                    exit(-1)

            raw_dir = osp.join(self.root, "raw")

            ### pre-process and save
            add_inverse_edge = to_bool(self.meta_info[self.name][
                "add_inverse_edge"])
            self.graph = read_csv_graph_pgl(
                raw_dir, add_inverse_edge=add_inverse_edge)

            #TODO: SAVE preprocess graph

    def get_edge_split(self):
        """Train/Validation/Test split
        """
        split_type = self.meta_info[self.name]["split"]
        path = osp.join(self.root, "split", split_type)

        train_idx = pd.read_csv(
            osp.join(path, "train.csv.gz"), compression="gzip",
            header=None).values
        valid_idx = pd.read_csv(
            osp.join(path, "valid.csv.gz"), compression="gzip",
            header=None).values
        test_idx = pd.read_csv(
            osp.join(path, "test.csv.gz"), compression="gzip",
            header=None).values

        if self.task_type == "link prediction":
            target_type = np.int64
        else:
            target_type = np.float32

        return {
            "train_edge": np.array(
                train_idx[:, :2], dtype="int64"),
            "train_edge_label": np.array(
                train_idx[:, 2], dtype=target_type),
            "valid_edge": np.array(
                valid_idx[:, :2], dtype="int64"),
            "valid_edge_label": np.array(
                valid_idx[:, 2], dtype=target_type),
            "test_edge": np.array(
                test_idx[:, :2], dtype="int64"),
            "test_edge_label": np.array(
                test_idx[:, 2], dtype=target_type)
        }

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph"
        return self.graph[0]

    def __len__(self):
        return 1

    def __repr__(self):  # pragma: no cover
        return '{}({})'.format(self.__class__.__name__, len(self))


if __name__ == "__main__":
    pgl_dataset = PglLinkPropPredDataset(name="ogbl-ppa")
    splitted_edge = pgl_dataset.get_edge_split()
    print(pgl_dataset[0])
    print(splitted_edge)
pgl/contrib/ogb/nodeproppred/__init__.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__.py
"""
pgl/contrib/ogb/nodeproppred/dataset_pgl.py
new file mode 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NodePropPredDataset for pgl
"""
import pandas as pd
import shutil, os
import os.path as osp
import numpy as np
from ogb.utils.url import decide_download, download_url, extract_zip
from ogb.nodeproppred import make_master_file  # create master.csv
from pgl.contrib.ogb.io.read_graph_pgl import read_csv_graph_pgl


def to_bool(value):
    """to_bool"""
    return np.array([value], dtype="bool")[0]


class PglNodePropPredDataset(object):
    """PglNodePropPredDataset
    """

    def __init__(self, name, root="dataset"):
        self.name = name  ## original name, e.g., ogbn-proteins
        self.dir_name = "_".join(
            name.split("-")
        ) + "_pgl"  ## replace hyphen with underline, e.g., ogbn_proteins_pgl

        self.original_root = root
        self.root = osp.join(root, self.dir_name)

        self.meta_info = make_master_file.df
        #pd.read_csv(
        #os.path.join(os.path.dirname(__file__), "master.csv"), index_col=0)

        if not self.name in self.meta_info:
            error_mssg = "Invalid dataset name {}.\n".format(self.name)
            error_mssg += "Available datasets are as follows:\n"
            error_mssg += "\n".join(self.meta_info.keys())
            raise ValueError(error_mssg)

        self.download_name = self.meta_info[self.name][
            "download_name"]  ## name of downloaded file, e.g., tox21

        self.num_tasks = int(self.meta_info[self.name]["num tasks"])
        self.task_type = self.meta_info[self.name]["task type"]

        super(PglNodePropPredDataset, self).__init__()

        self.pre_process()

    def pre_process(self):
        """pre_process downlaoding data
        """
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir,
                                           'pgl_data_processed')

        if osp.exists(pre_processed_file_path):
            # TODO: Reload Preprocess files
            pass
        else:
            ### check download
            if not osp.exists(osp.join(self.root, "raw", "edge.csv.gz")):
                url = self.meta_info[self.name]["url"]
                if decide_download(url):
                    path = download_url(url, self.original_root)
                    extract_zip(path, self.original_root)
                    os.unlink(path)
                    # delete folder if there exists
                    try:
                        shutil.rmtree(self.root)
                    except:
                        pass
                    shutil.move(
                        osp.join(self.original_root, self.download_name),
                        self.root)
                else:
                    print("Stop download.")
                    exit(-1)

            raw_dir = osp.join(self.root, "raw")

            ### pre-process and save
            add_inverse_edge = to_bool(self.meta_info[self.name][
                "add_inverse_edge"])
            self.graph = read_csv_graph_pgl(
                raw_dir, add_inverse_edge=add_inverse_edge)

            ### adding prediction target
            node_label = pd.read_csv(
                osp.join(raw_dir, 'node-label.csv.gz'),
                compression="gzip",
                header=None).values

            if "classification" in self.task_type:
                node_label = np.array(node_label, dtype=np.int64)
            else:
                node_label = np.array(node_label, dtype=np.float32)

            label_dict = {"labels": node_label}

            # TODO: SAVE preprocess graph

            self.labels = label_dict['labels']

    def get_idx_split(self):
        """Train/Validation/Test split
        """
        split_type = self.meta_info[self.name]["split"]
        path = osp.join(self.root, "split", split_type)

        train_idx = pd.read_csv(
            osp.join(path, "train.csv.gz"), compression="gzip",
            header=None).values.T[0]
        valid_idx = pd.read_csv(
            osp.join(path, "valid.csv.gz"), compression="gzip",
            header=None).values.T[0]
        test_idx = pd.read_csv(
            osp.join(path, "test.csv.gz"), compression="gzip",
            header=None).values.T[0]

        return {
            "train": np.array(
                train_idx, dtype="int64"),
            "valid": np.array(
                valid_idx, dtype="int64"),
            "test": np.array(
                test_idx, dtype="int64")
        }

    def __getitem__(self, idx):
        assert idx == 0, "This dataset has only one graph"
        return self.graph[idx], self.labels

    def __len__(self):
        return 1

    def __repr__(self):  # pragma: no cover
        return '{}({})'.format(self.__class__.__name__, len(self))


if __name__ == "__main__":
    pgl_dataset = PglNodePropPredDataset(name="ogbn-proteins")
    splitted_index = pgl_dataset.get_idx_split()
    print(pgl_dataset[0])
    print(splitted_index)
pgl/graph.py
@@ -15,6 +15,7 @@
 This package implement Graph structure for handling graph data.
 """

+import os
 import numpy as np
 import pickle as pkl
 import time

@@ -77,6 +78,15 @@ class EdgeIndex(object):
         """
         return self._sorted_u, self._sorted_v, self._sorted_eid

+    def dump(self, path):
+        if not os.path.exists(path):
+            os.makedirs(path)
+        np.save(path + '/degree.npy', self._degree)
+        np.save(path + '/sorted_u.npy', self._sorted_u)
+        np.save(path + '/sorted_v.npy', self._sorted_v)
+        np.save(path + '/sorted_eid.npy', self._sorted_eid)
+        np.save(path + '/indptr.npy', self._indptr)
+

 class Graph(object):
     """Implementation of graph structure in pgl.

@@ -136,6 +146,18 @@ class Graph(object):
         self._adj_src_index = None
         self._adj_dst_index = None

+    def dump(self, path):
+        if not os.path.exists(path):
+            os.makedirs(path)
+        np.save(path + '/num_nodes.npy', self._num_nodes)
+        np.save(path + '/edges.npy', self._edges)
+
+        if self._adj_src_index:
+            self._adj_src_index.dump(path + '/adj_src')
+
+        if self._adj_dst_index:
+            self._adj_dst_index.dump(path + '/adj_dst')
+
     @property
     def adj_src_index(self):
         """Return an EdgeIndex object for src.

@@ -506,7 +528,13 @@ class Graph(object):
                     (key, _hide_num_nodes(value.shape), value.dtype))
         return edge_feat_info

-    def subgraph(self, nodes, eid=None, edges=None):
+    def subgraph(self,
+                 nodes,
+                 eid=None,
+                 edges=None,
+                 edge_feats=None,
+                 with_node_feat=True,
+                 with_edge_feat=True):
         """Generate subgraph with nodes and edge ids.

         This function will generate a :code:`pgl.graph.Subgraph` object and

@@ -522,6 +550,10 @@ class Graph(object):
             edges (optional): Edge(src, dst) list which will be included in the subgraph.

+            with_node_feat: Whether to inherit node features from parent graph.
+
+            with_edge_feat: Whether to inherit edge features from parent graph.
+
         Return:
             A :code:`pgl.graph.Subgraph` object.
         """

@@ -543,12 +575,18 @@ class Graph(object):
                 len(edges), dtype="int64"), edges, reindex)

         sub_edge_feat = {}
-        for key, value in self._edge_feat.items():
-            if eid is None:
-                raise ValueError("Eid can not be None with edge features.")
-            sub_edge_feat[key] = value[eid]
+        if edges is None:
+            if with_edge_feat:
+                for key, value in self._edge_feat.items():
+                    if eid is None:
+                        raise ValueError(
+                            "Eid can not be None with edge features.")
+                    sub_edge_feat[key] = value[eid]
+        else:
+            sub_edge_feat = edge_feats

         sub_node_feat = {}
-        for key, value in self._node_feat.items():
-            sub_node_feat[key] = value[nodes]
+        if with_node_feat:
+            for key, value in self._node_feat.items():
+                sub_node_feat[key] = value[nodes]

@@ -779,3 +817,27 @@ class SubGraph(Graph):
             A list of node ids in parent graph.
         """
         return graph_kernel.map_nodes(nodes, self._to_reindex)
+
+
+class MemmapEdgeIndex(EdgeIndex):
+    def __init__(self, path):
+        self._degree = np.load(path + '/degree.npy', mmap_mode="r")
+        self._sorted_u = np.load(path + '/sorted_u.npy', mmap_mode="r")
+        self._sorted_v = np.load(path + '/sorted_v.npy', mmap_mode="r")
+        self._sorted_eid = np.load(path + '/sorted_eid.npy', mmap_mode="r")
+        self._indptr = np.load(path + '/indptr.npy', mmap_mode="r")
+
+
+class MemmapGraph(Graph):
+    def __init__(self, path):
+        self._num_nodes = np.load(path + '/num_nodes.npy')
+        self._edges = np.load(path + '/edges.npy', mmap_mode="r")
+
+        if os.path.exists(path + '/adj_src'):
+            self._adj_src_index = MemmapEdgeIndex(path + '/adj_src')
+        else:
+            self._adj_src_index = None
+
+        if os.path.exists(path + '/adj_dst'):
+            self._adj_dst_index = MemmapEdgeIndex(path + '/adj_dst')
+        else:
+            self._adj_dst_index = None
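dump() plus MemmapGraph gives a write-once, memory-mapped way to reopen graphs that are too large to rebuild in every worker process. A hedged round-trip sketch (toy graph, hypothetical path):

import pgl
from pgl.graph import MemmapGraph

g = pgl.graph.Graph(num_nodes=4, edges=[(0, 1), (1, 2), (2, 3)])
g.adj_dst_index  # touch the lazily built index so dump() also writes adj_dst/
g.dump("/tmp/toy_graph")  # writes num_nodes.npy, edges.npy, and index dirs

mg = MemmapGraph("/tmp/toy_graph")  # arrays come back via np.load(mmap_mode="r")
print(mg.num_nodes)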
pgl/graph_wrapper.py
浏览文件 @
d347a2bb
...
...
@@ -89,8 +89,8 @@ class BaseGraphWrapper(object):
"""
def
__init__
(
self
):
self
.
_
node_feat_tensor_dict
=
{}
self
.
_
edge_feat_tensor_dict
=
{}
self
.
node_feat_tensor_dict
=
{}
self
.
edge_feat_tensor_dict
=
{}
self
.
_edges_src
=
None
self
.
_edges_dst
=
None
self
.
_num_nodes
=
None
...
...
@@ -98,6 +98,10 @@ class BaseGraphWrapper(object):
self
.
_edge_uniq_dst
=
None
self
.
_edge_uniq_dst_count
=
None
self
.
_node_ids
=
None
self
.
_data_name_prefix
=
""
def
__repr__
(
self
):
return
self
.
_data_name_prefix
def
send
(
self
,
message_func
,
nfeat_list
=
None
,
efeat_list
=
None
):
"""Send message from all src nodes to dst nodes.
...
...
@@ -220,7 +224,7 @@ class BaseGraphWrapper(object):
A dictionary whose keys are the feature names and the values
are feature tensor.
"""
return
self
.
_
edge_feat_tensor_dict
return
self
.
edge_feat_tensor_dict
@
property
def
node_feat
(
self
):
...
...
@@ -230,7 +234,7 @@ class BaseGraphWrapper(object):
A dictionary whose keys are the feature names and the values
are feature tensor.
"""
return
self
.
_
node_feat_tensor_dict
return
self
.
node_feat_tensor_dict
def
indegree
(
self
):
"""Return the indegree tensor for all nodes.
...
...
@@ -298,8 +302,8 @@ class StaticGraphWrapper(BaseGraphWrapper):
def
__init__
(
self
,
name
,
graph
,
place
):
super
(
StaticGraphWrapper
,
self
).
__init__
()
self
.
_data_name_prefix
=
name
self
.
_initializers
=
[]
self
.
__data_name_prefix
=
name
self
.
__create_graph_attr
(
graph
)
def
__create_graph_attr
(
self
,
graph
):
...
...
@@ -326,43 +330,43 @@ class StaticGraphWrapper(BaseGraphWrapper):
self
.
_edges_src
,
init
=
paddle_helper
.
constant
(
dtype
=
"int64"
,
value
=
src
,
name
=
self
.
_
_
data_name_prefix
+
'/edges_src'
)
name
=
self
.
_data_name_prefix
+
'/edges_src'
)
self
.
_initializers
.
append
(
init
)
self
.
_edges_dst
,
init
=
paddle_helper
.
constant
(
dtype
=
"int64"
,
value
=
dst
,
name
=
self
.
_
_
data_name_prefix
+
'/edges_dst'
)
name
=
self
.
_data_name_prefix
+
'/edges_dst'
)
self
.
_initializers
.
append
(
init
)
self
.
_num_nodes
,
init
=
paddle_helper
.
constant
(
dtype
=
"int64"
,
hide_batch_size
=
False
,
value
=
np
.
array
([
graph
.
num_nodes
]),
name
=
self
.
_
_
data_name_prefix
+
'/num_nodes'
)
name
=
self
.
_data_name_prefix
+
'/num_nodes'
)
self
.
_initializers
.
append
(
init
)
self
.
_edge_uniq_dst
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
"/uniq_dst"
,
name
=
self
.
_data_name_prefix
+
"/uniq_dst"
,
dtype
=
"int64"
,
value
=
uniq_dst
)
self
.
_initializers
.
append
(
init
)
self
.
_edge_uniq_dst_count
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
"/uniq_dst_count"
,
name
=
self
.
_data_name_prefix
+
"/uniq_dst_count"
,
dtype
=
"int32"
,
value
=
uniq_dst_count
)
self
.
_initializers
.
append
(
init
)
node_ids_value
=
np
.
arange
(
0
,
graph
.
num_nodes
,
dtype
=
"int64"
)
self
.
_node_ids
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
"/node_ids"
,
name
=
self
.
_data_name_prefix
+
"/node_ids"
,
dtype
=
"int64"
,
value
=
node_ids_value
)
self
.
_initializers
.
append
(
init
)
self
.
_indegree
,
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
"/indegree"
,
name
=
self
.
_data_name_prefix
+
"/indegree"
,
dtype
=
"int64"
,
value
=
indegree
)
self
.
_initializers
.
append
(
init
)
...
...
@@ -373,9 +377,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
for
node_feat_name
,
node_feat_value
in
node_feat
.
items
():
node_feat_shape
=
node_feat_value
.
shape
node_feat_dtype
=
node_feat_value
.
dtype
self
.
_
node_feat_tensor_dict
[
self
.
node_feat_tensor_dict
[
node_feat_name
],
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
'/node_feat/'
+
name
=
self
.
_data_name_prefix
+
'/node_feat/'
+
node_feat_name
,
dtype
=
node_feat_dtype
,
value
=
node_feat_value
)
...
...
@@ -387,9 +391,9 @@ class StaticGraphWrapper(BaseGraphWrapper):
for
edge_feat_name
,
edge_feat_value
in
edge_feat
.
items
():
edge_feat_shape
=
edge_feat_value
.
shape
edge_feat_dtype
=
edge_feat_value
.
dtype
self
.
_
edge_feat_tensor_dict
[
self
.
edge_feat_tensor_dict
[
edge_feat_name
],
init
=
paddle_helper
.
constant
(
name
=
self
.
_
_
data_name_prefix
+
'/edge_feat/'
+
name
=
self
.
_data_name_prefix
+
'/edge_feat/'
+
edge_feat_name
,
dtype
=
edge_feat_dtype
,
value
=
edge_feat_value
)
...
...
@@ -477,8 +481,8 @@ class GraphWrapper(BaseGraphWrapper):
def
__init__
(
self
,
name
,
place
,
node_feat
=
[],
edge_feat
=
[]):
super
(
GraphWrapper
,
self
).
__init__
()
# collect holders for PyReader
self
.
_data_name_prefix
=
name
self
.
_holder_list
=
[]
self
.
__data_name_prefix
=
name
self
.
_place
=
place
self
.
__create_graph_attr_holders
()
for
node_feat_name
,
node_feat_shape
,
node_feat_dtype
in
node_feat
:
...
...
@@ -493,43 +497,43 @@ class GraphWrapper(BaseGraphWrapper):
"""Create data holders for graph attributes.
"""
self
.
_edges_src
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/edges_src'
,
self
.
_data_name_prefix
+
'/edges_src'
,
shape
=
[
None
],
append_batch_size
=
False
,
dtype
=
"int64"
,
stop_gradient
=
True
)
self
.
_edges_dst
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/edges_dst'
,
self
.
_data_name_prefix
+
'/edges_dst'
,
shape
=
[
None
],
append_batch_size
=
False
,
dtype
=
"int64"
,
stop_gradient
=
True
)
self
.
_num_nodes
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/num_nodes'
,
self
.
_data_name_prefix
+
'/num_nodes'
,
shape
=
[
1
],
append_batch_size
=
False
,
dtype
=
'int64'
,
stop_gradient
=
True
)
self
.
_edge_uniq_dst
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
"/uniq_dst"
,
self
.
_data_name_prefix
+
"/uniq_dst"
,
shape
=
[
None
],
append_batch_size
=
False
,
dtype
=
"int64"
,
stop_gradient
=
True
)
self
.
_edge_uniq_dst_count
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
"/uniq_dst_count"
,
self
.
_data_name_prefix
+
"/uniq_dst_count"
,
shape
=
[
None
],
append_batch_size
=
False
,
dtype
=
"int32"
,
stop_gradient
=
True
)
self
.
_node_ids
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
"/node_ids"
,
self
.
_data_name_prefix
+
"/node_ids"
,
shape
=
[
None
],
append_batch_size
=
False
,
dtype
=
"int64"
,
stop_gradient
=
True
)
self
.
_indegree
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
"/indegree"
,
self
.
_data_name_prefix
+
"/indegree"
,
shape
=
[
None
],
append_batch_size
=
False
,
dtype
=
"int64"
,
...
...
@@ -545,12 +549,12 @@ class GraphWrapper(BaseGraphWrapper):
"""Create data holders for node features.
"""
feat_holder
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/node_feat/'
+
node_feat_name
,
self
.
_data_name_prefix
+
'/node_feat/'
+
node_feat_name
,
shape
=
node_feat_shape
,
append_batch_size
=
False
,
dtype
=
node_feat_dtype
,
stop_gradient
=
True
)
self
.
_
node_feat_tensor_dict
[
node_feat_name
]
=
feat_holder
self
.
node_feat_tensor_dict
[
node_feat_name
]
=
feat_holder
self
.
_holder_list
.
append
(
feat_holder
)
def
__create_graph_edge_feat_holders
(
self
,
edge_feat_name
,
edge_feat_shape
,
...
...
@@ -558,12 +562,12 @@ class GraphWrapper(BaseGraphWrapper):
"""Create edge holders for edge features.
"""
feat_holder
=
fluid
.
layers
.
data
(
self
.
_
_
data_name_prefix
+
'/edge_feat/'
+
edge_feat_name
,
self
.
_data_name_prefix
+
'/edge_feat/'
+
edge_feat_name
,
shape
=
edge_feat_shape
,
append_batch_size
=
False
,
dtype
=
edge_feat_dtype
,
stop_gradient
=
True
)
self
.
_
edge_feat_tensor_dict
[
edge_feat_name
]
=
feat_holder
self
.
edge_feat_tensor_dict
[
edge_feat_name
]
=
feat_holder
self
.
_holder_list
.
append
(
feat_holder
)
def
to_feed
(
self
,
graph
):
...
...
@@ -594,20 +598,21 @@ class GraphWrapper(BaseGraphWrapper):
                edge_feat[key] = value[eid]
        node_feat = graph.node_feat
-        feed_dict[self.__data_name_prefix + '/edges_src'] = src
-        feed_dict[self.__data_name_prefix + '/edges_dst'] = dst
-        feed_dict[self.__data_name_prefix + '/num_nodes'] = np.array(graph.num_nodes)
-        feed_dict[self.__data_name_prefix + '/uniq_dst'] = uniq_dst
-        feed_dict[self.__data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
-        feed_dict[self.__data_name_prefix + '/node_ids'] = graph.nodes
-        feed_dict[self.__data_name_prefix + '/indegree'] = indegree
-        for key in self._node_feat_tensor_dict:
-            feed_dict[self.__data_name_prefix + '/node_feat/' +
+        feed_dict[self._data_name_prefix + '/edges_src'] = src
+        feed_dict[self._data_name_prefix + '/edges_dst'] = dst
+        feed_dict[self._data_name_prefix + '/num_nodes'] = np.array(graph.num_nodes)
+        feed_dict[self._data_name_prefix + '/uniq_dst'] = uniq_dst
+        feed_dict[self._data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
+        feed_dict[self._data_name_prefix + '/node_ids'] = graph.nodes
+        feed_dict[self._data_name_prefix + '/indegree'] = indegree
+        for key in self.node_feat_tensor_dict:
+            feed_dict[self._data_name_prefix + '/node_feat/' +
                      key] = node_feat[key]
-        for key in self._edge_feat_tensor_dict:
-            feed_dict[self.__data_name_prefix + '/edge_feat/' +
+        for key in self.edge_feat_tensor_dict:
+            feed_dict[self._data_name_prefix + '/edge_feat/' +
                      key] = edge_feat[key]
        return feed_dict
...
...
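Taken together, the GraphWrapper changes only touch naming: feed keys are still built as "<wrapper name>/<field>", and to_feed keeps its contract. A minimal usage sketch under the PGL 1.x / static-graph Paddle API (the toy graph and feature shapes here are assumptions for illustration):

    import numpy as np
    import paddle.fluid as fluid
    import pgl

    # Toy graph with a 4-dim float feature per node (illustrative values).
    graph = pgl.graph.Graph(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        node_feat={"h": np.random.rand(3, 4).astype("float32")})

    place = fluid.CPUPlace()
    gw = pgl.graph_wrapper.GraphWrapper(
        name="gw", place=place, node_feat=[("h", [None, 4], "float32")])

    # to_feed builds a feed_dict whose keys follow the "gw/..." prefix above.
    feed_dict = gw.to_feed(graph)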
pgl/contrib/heter_graph.py → pgl/heter_graph.py
View file @ d347a2bb
File moved
pgl/contrib/heter_graph_wrapper.py → pgl/heter_graph_wrapper.py
View file @ d347a2bb
...
...
@@ -64,8 +64,8 @@ class HeterGraphWrapper(object):
            import paddle.fluid as fluid
            import numpy as np
-            from pgl.contrib import heter_graph
-            from pgl.contrib import heter_graph_wrapper
+            from pgl import heter_graph
+            from pgl import heter_graph_wrapper
            num_nodes = 4
            node_types = [(0, 'user'), (1, 'item'), (2, 'item'), (3, 'user')]
            edges = {
...
...
pgl/contrib/redis_hetergraph.py → pgl/redis_hetergraph.py
View file @ d347a2bb
...
...
@@ -28,7 +28,7 @@ import pgl.graph as pgraph
import pickle as pkl
from pgl.utils.logger import log
import pgl.graph_kernel as graph_kernel
-from pgl.contrib import heter_graph
+from pgl import heter_graph
import pgl.redis_graph as rg
...
...
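The three file moves above promote the heterogeneous-graph modules out of pgl.contrib into the top-level pgl package; downstream code only needs its import paths updated, e.g.:

    # Old (pre-v1.1) import path:
    #   from pgl.contrib import heter_graph, heter_graph_wrapper
    # New path after this commit:
    from pgl import heter_graph
    from pgl import heter_graph_wrapper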
pgl/sample.py
View file @ d347a2bb
...
...
@@ -24,10 +24,29 @@ from pgl import graph_kernel
__all__ = [
    'graphsage_sample', 'node2vec_sample', 'deepwalk_sample',
-    'metapath_randomwalk'
+    'metapath_randomwalk', 'pinsage_sample'
]
+
+
+def traverse(item):
+    """traverse the list or numpy"""
+    if isinstance(item, list) or isinstance(item, np.ndarray):
+        for i in iter(item):
+            for j in traverse(i):
+                yield j
+    else:
+        yield item
+
+
+def flat_node_and_edge(nodes, eids, weights=None):
+    """flatten the sub-lists to one list"""
+    nodes = list(set(traverse(nodes)))
+    eids = list(traverse(eids))
+    if weights is not None:
+        weights = list(traverse(weights))
+    return nodes, eids, weights


def edge_hash(src, dst):
    """edge_hash
    """
...
...
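traverse is a generator that recursively yields the leaves of nested lists/ndarrays, and flat_node_and_edge uses it to merge per-walk node and edge-id lists, deduplicating nodes via set. A quick sanity check (inputs chosen for illustration):

    import numpy as np
    from pgl.sample import traverse, flat_node_and_edge

    nested = [[0, 1], np.array([1, 2]), [[3]]]
    print(list(traverse(nested)))            # [0, 1, 1, 2, 3]

    # nodes come back deduplicated (set order is arbitrary), eids flattened
    nodes, eids, weights = flat_node_and_edge(nested, [[10], [11, 12]])
    print(sorted(nodes), eids, weights)      # [0, 1, 2, 3] [10, 11, 12] None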
@@ -88,7 +107,6 @@ def graphsage_sample(graph, nodes, samples, ignore_edges=[]):
            start_nodes = list(nodes_set - last_nodes_set)
        layer_nodes = [nodes] + layer_nodes
        layer_eids = [eids] + layer_eids
-        log.debug("flat time: %s" % (time.time() - start))
        start = time.time()
        # Find new nodes
...
...
@@ -317,3 +335,146 @@ def metapath_randomwalk(graph,
            cur_nodes = np.array(nxt_cur_nodes)

    return walk
+
+
+def random_walk_with_start_prob(graph, nodes, max_depth, proba=0.5):
+    """Implementation of random walk with a probability of returning to the origin node.
+
+    This function gets random walk paths for the given nodes and depth.
+
+    Args:
+        nodes: walk starting from nodes
+        max_depth: max walking depth
+        proba: the probability of returning to the origin node
+
+    Return:
+        A list of walks.
+    """
+    walk = []
+    # init
+    for node in nodes:
+        walk.append([node])
+
+    walk_ids = np.arange(0, len(nodes))
+    cur_nodes = np.array(nodes)
+    nodes = np.array(nodes)
+    for l in range(max_depth):
+        # select the walks that have not ended
+        if l >= 1:
+            return_proba = np.random.rand(cur_nodes.shape[0])
+            proba_mask = (return_proba < proba)
+            cur_nodes[proba_mask] = nodes[proba_mask]
+        outdegree = graph.outdegree(cur_nodes)
+        mask = (outdegree != 0)
+        if np.any(mask):
+            cur_walk_ids = walk_ids[mask]
+            outdegree = outdegree[mask]
+        else:
+            # stop when no node has a successor; the next iteration may
+            # reset walks to the origin node and continue from there
+            continue
+        succ = graph.successor(cur_nodes[mask])
+        sample_index = np.floor(
+            np.random.rand(outdegree.shape[0]) * outdegree).astype("int64")
+
+        nxt_cur_nodes = cur_nodes
+        for s, ind, walk_id in zip(succ, sample_index, cur_walk_ids):
+            walk[walk_id].append(s[ind])
+            nxt_cur_nodes[walk_id] = s[ind]
+        cur_nodes = np.array(nxt_cur_nodes)
+    return walk
+
+
+def pinsage_sample(graph,
+                   nodes,
+                   samples,
+                   top_k=10,
+                   proba=0.5,
+                   norm_bias=1.0,
+                   ignore_edges=set()):
+    """Implementation of PinSage sampling.
+
+    Reference paper: "Graph Convolutional Neural Networks for Web-Scale
+    Recommender Systems" (PinSage).
+
+    Args:
+        graph: A pgl graph instance
+        nodes: Sample starting from nodes
+        samples: A list, the walk length used in each layer
+        top_k: select the top_k visit-count nodes to construct the edges
+        proba: the probability of returning to the origin node
+        norm_bias: the normalization bias for the visit count
+        ignore_edges: list of edges (src, dst) that will be ignored.
+
+    Return:
+        A list of subgraphs
+    """
+    start = time.time()
+    num_layers = len(samples)
+    start_nodes = nodes
+    edges, weights = [], []
+    layer_nodes, layer_edges, layer_weights = [], [], []
+    ignore_edge_set = set([edge_hash(src, dst) for src, dst in ignore_edges])
+
+    for layer_idx in reversed(range(num_layers)):
+        if len(start_nodes) == 0:
+            layer_nodes = [nodes] + layer_nodes
+            layer_edges = [edges] + layer_edges
+            layer_weights = [weights] + layer_weights
+            continue
+        walks = random_walk_with_start_prob(
+            graph, start_nodes, samples[layer_idx], proba=proba)
+        walks = [walk[1:] for walk in walks]
+        pred_edges = []
+        pred_weights = []
+        pred_nodes = []
+        for node, walk in zip(start_nodes, walks):
+            walk_nodes = []
+            walk_weights = []
+            count_sum = 0
+
+            for random_walk_node in walk:
+                if len(ignore_edge_set) > 0 and random_walk_node != node and \
+                        edge_hash(random_walk_node, node) in ignore_edge_set:
+                    continue
+                walk_nodes.append(random_walk_node)
+            unique, counts = np.unique(walk_nodes, return_counts=True)
+            frequencies = np.asarray((unique, counts)).T
+            frequencies = frequencies[np.argsort(frequencies[:, 1])]
+            frequencies = frequencies[-1 * top_k:, :]
+            for random_walk_node, random_count in zip(
+                    frequencies[:, 0].tolist(), frequencies[:, 1].tolist()):
+                pred_nodes.append(random_walk_node)
+                pred_edges.append((random_walk_node, node))
+                walk_weights.append(random_count)
+                count_sum += random_count
+            count_sum += len(walk_weights) * norm_bias
+            walk_weights = (np.array(walk_weights) + norm_bias) / count_sum
+            pred_weights.extend(walk_weights.tolist())
+        last_node_set = set(nodes)
+        nodes, edges, weights = flat_node_and_edge([nodes, pred_nodes], \
+            [edges, pred_edges], [weights, pred_weights])
+
+        layer_edges = [edges] + layer_edges
+        layer_weights = [weights] + layer_weights
+        layer_nodes = [nodes] + layer_nodes
+
+        start_nodes = list(set(nodes) - last_node_set)
+        start = time.time()
+
+    feed_dict = {}
+
+    subgraphs = []
+    for i in range(num_layers):
+        edge_feat_dict = {
+            "weight": np.array(
+                layer_weights[i], dtype='float32')
+        }
+        subgraphs.append(
+            graph.subgraph(
+                nodes=layer_nodes[0],
+                edges=layer_edges[i],
+                edge_feats=edge_feat_dict))
+        subgraphs[i].node_feat["index"] = np.array(
+            layer_nodes[0], dtype="int64")
+
+    return subgraphs
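
pinsage_sample estimates neighbor importance by visit counts from random_walk_with_start_prob, keeps each node's top_k most-visited neighbors, and normalizes the counts into "weight" edge features on one subgraph per layer. A minimal sketch of calling it on a toy graph (the graph and parameter values are illustrative assumptions, not from the commit):

    import pgl
    from pgl.sample import pinsage_sample

    # Toy bidirectional chain graph (illustrative).
    edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2), (3, 4), (4, 3)]
    graph = pgl.graph.Graph(num_nodes=5, edges=edges)

    # Two layers; walk lengths of 8 and 4 for the outer and inner layer.
    subgraphs = pinsage_sample(graph, nodes=[0, 1], samples=[8, 4], top_k=3)
    for sg in subgraphs:
        # visit-count based weights arrive as an edge feature
        print(sg.node_feat["index"], sg.edge_feat["weight"])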
pgl/tests/test_hetergraph.py
0 → 100644
View file @ d347a2bb
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_hetergraph"""
import time
import unittest
import json
import os

import numpy as np
from pgl.sample import metapath_randomwalk
from pgl.graph import Graph
from pgl import heter_graph


class HeterGraphTest(unittest.TestCase):
    """HeterGraph test
    """

    @classmethod
    def setUpClass(cls):
        np.random.seed(1)
        edges = {}
        # for test no successor
        edges['c2p'] = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5),
                        (3, 6), (3, 7), (3, 4), (3, 8)]
        edges['p2c'] = [(v, u) for u, v in edges['c2p']]
        edges['p2a'] = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12),
                        (6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]
        edges['a2p'] = [(v, u) for u, v in edges['p2a']]

        # for test speed
        # edges['c2p'] = [(0, 4), (0, 5), (1, 9), (1,8), (2,8), (2,5), (3,6), (3,7), (3,4), (3,8)]
        # edges['p2c'] = [(v,u) for u, v in edges['c2p']]
        # edges['p2a'] = [(4,10), (4,11), (4,12), (4,14), (5,13), (6,13), (6,11), (6,14), (7,12), (7,11), (8,14), (9,13)]
        # edges['a2p'] = [(v,u) for u, v in edges['p2a']]

        node_types = ['c' for _ in range(4)] + ['p' for _ in range(6)
                                                ] + ['a' for _ in range(5)]
        node_types = [(i, t) for i, t in enumerate(node_types)]

        cls.graph = heter_graph.HeterGraph(
            num_nodes=len(node_types), edges=edges, node_types=node_types)

    def test_num_nodes_by_type(self):
        print()
        n_types = {'c': 4, 'p': 6, 'a': 5}
        for nt in n_types:
            num_nodes = self.graph.num_nodes_by_type(nt)
            self.assertEqual(num_nodes, n_types[nt])

    def test_node_batch_iter(self):
        print()
        batch_size = 2
        ground = [[4, 5], [6, 7], [8, 9]]
        for idx, nodes in enumerate(
                self.graph.node_batch_iter(
                    batch_size=batch_size, shuffle=False, n_type='p')):
            self.assertEqual(len(nodes), batch_size)
            self.assertListEqual(list(nodes), ground[idx])

    def test_sample_nodes(self):
        print()
        p_ground = [4, 5, 6, 7, 8, 9]
        sample_num = 10
        nodes = self.graph.sample_nodes(sample_num=sample_num, n_type='p')
        self.assertEqual(len(nodes), sample_num)
        for n in nodes:
            self.assertIn(n, p_ground)

        # test n_type == None
        ground = [i for i in range(15)]
        nodes = self.graph.sample_nodes(sample_num=sample_num, n_type=None)
        self.assertEqual(len(nodes), sample_num)
        for n in nodes:
            self.assertIn(n, ground)


if __name__ == "__main__":
    unittest.main()
pgl/tests/test_metapath_randomwalk.py
0 → 100644
View file @ d347a2bb
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_metapath_randomwalk"""
import time
import unittest
import json
import os

import numpy as np
from pgl.sample import metapath_randomwalk
from pgl.graph import Graph
from pgl import heter_graph

np.random.seed(1)


class MetapathRandomwalkTest(unittest.TestCase):
    """metapath_randomwalk test
    """

    def setUp(self):
        edges = {}
        # for test no successor
        edges['c2p'] = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5),
                        (3, 6), (3, 7), (3, 4), (3, 8)]
        edges['p2c'] = [(v, u) for u, v in edges['c2p']]
        edges['p2a'] = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12),
                        (6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]
        edges['a2p'] = [(v, u) for u, v in edges['p2a']]

        # for test speed
        # edges['c2p'] = [(0, 4), (0, 5), (1, 9), (1,8), (2,8), (2,5), (3,6), (3,7), (3,4), (3,8)]
        # edges['p2c'] = [(v,u) for u, v in edges['c2p']]
        # edges['p2a'] = [(4,10), (4,11), (4,12), (4,14), (5,13), (6,13), (6,11), (6,14), (7,12), (7,11), (8,14), (9,13)]
        # edges['a2p'] = [(v,u) for u, v in edges['p2a']]

        self.node_types = ['c' for _ in range(4)] + [
            'p' for _ in range(6)
        ] + ['a' for _ in range(5)]
        node_types = [(i, t) for i, t in enumerate(self.node_types)]

        self.graph = heter_graph.HeterGraph(
            num_nodes=len(node_types), edges=edges, node_types=node_types)

    def test_metapath_randomwalk(self):
        meta_path = 'c2p-p2a-a2p-p2c'
        path = ['c', 'p', 'a', 'p', 'c']
        start_nodes = [0, 1, 2, 3]
        walk_len = 10
        walks = metapath_randomwalk(
            graph=self.graph,
            start_nodes=start_nodes,
            metapath=meta_path,
            walk_length=walk_len)

        self.assertEqual(len(walks), 4)

        for walk in walks:
            for i in range(len(walk)):
                idx = i % (len(path) - 1)
                self.assertEqual(self.node_types[walk[i]], path[idx])


if __name__ == "__main__":
    unittest.main()
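
Both new test modules follow the standard unittest layout and end in unittest.main(), so they can be run directly (python pgl/tests/test_hetergraph.py) or through the loader, as in this sketch (assumes the repository root is on PYTHONPATH and pgl.tests is importable as a package):

    import unittest

    # Load and run both new test modules in one pass.
    suite = unittest.defaultTestLoader.loadTestsFromNames([
        "pgl.tests.test_hetergraph",
        "pgl.tests.test_metapath_randomwalk",
    ])
    unittest.TextTestRunner(verbosity=2).run(suite)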
pgl/utils/mp_reader.py
View file @ d347a2bb
...
...
@@ -25,6 +25,8 @@ except:
import numpy as np
import time
import paddle.fluid as fluid
+from queue import Queue
+import threading


def serialize_data(data):
...
...
@@ -129,23 +131,40 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000, pipe_size=10):
            p.start()

        reader_num = len(readers)
-        finish_num = 0
-        conn_to_remove = []
-        finish_flag = np.zeros(len(conns), dtype="int32")
-        while finish_num < reader_num:
-            for conn_id, conn in enumerate(conns):
-                if finish_flag[conn_id] > 0:
-                    continue
-                if conn.poll(0.01):
-                    buff = conn.recv()
-                    start = time.time()
+
+        def queue_worker(sub_conn, que):
+            while True:
+                buff = sub_conn.recv()
+                sample = deserialize_data(buff)
+                if sample is None:
+                    que.put(None)
+                    sub_conn.close()
+                    break
+                que.put(sample)
+
+        thread_pool = []
+        output_queue = Queue(maxsize=reader_num)
+        for i in range(reader_num):
+            t = threading.Thread(
+                target=queue_worker, args=(conns[i], output_queue))
+            t.daemon = True
+            t.start()
+            thread_pool.append(t)
+
+        finish_num = 0
+        while finish_num < reader_num:
+            sample = output_queue.get()
+            if sample is None:
+                finish_num += 1
-                    conn.close()
-                    finish_flag[conn_id] = 1
+            else:
+                yield sample
+        for thread in thread_pool:
+            thread.join()

    if use_pipe:
        return pipe_reader
    else:
...
...
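The rewritten pipe_reader replaces the conn.poll(0.01) busy-wait over all child connections with one blocking receiver thread per connection, all feeding a shared Queue; the generator then simply blocks on output_queue.get() until every reader has delivered its None sentinel. A minimal usage sketch (the toy readers are illustrative, not part of the commit):

    from pgl.utils.mp_reader import multiprocess_reader

    def reader_a():
        for i in range(3):
            yield {"x": i}

    def reader_b():
        for i in range(100, 103):
            yield {"x": i}

    # Each reader runs in its own process; samples arrive interleaved.
    joint_reader = multiprocess_reader([reader_a, reader_b], use_pipe=True)
    for sample in joint_reader():
        print(sample)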