Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
969c1880
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
969c1880
编写于
9月 09, 2020
作者:
W
Webbley
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mmap mode for heter graph
上级
bccff22a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
218 addition
and
2 deletion
+218
-2
pgl/heter_graph.py
pgl/heter_graph.py
+45
-2
pgl/tests/test_MmapHeterGraph.py
pgl/tests/test_MmapHeterGraph.py
+173
-0
未找到文件。
pgl/heter_graph.py
浏览文件 @
969c1880
...
...
@@ -14,12 +14,13 @@
"""
This package implement Heterogeneous Graph structure for handling Heterogeneous graph data.
"""
import
os
import
time
import
numpy
as
np
import
pickle
as
pkl
import
time
import
pgl.graph_kernel
as
graph_kernel
from
pgl.graph
import
Graph
from
pgl.graph
import
Graph
,
MemmapGraph
__all__
=
[
'HeterGraph'
,
'SubHeterGraph'
]
...
...
@@ -113,6 +114,30 @@ class HeterGraph(object):
self
.
_edge_types
=
self
.
edge_types_info
()
def
dump
(
self
,
path
,
indegree
=
False
,
outdegree
=
False
):
if
indegree
:
for
e_type
,
g
in
self
.
_multi_graph
.
items
():
g
.
indegree
()
if
outdegree
:
for
e_type
,
g
in
self
.
_multi_graph
.
items
():
g
.
outdegree
()
if
not
os
.
path
.
exists
(
path
):
os
.
makedirs
(
path
)
np
.
save
(
os
.
path
.
join
(
path
,
"num_nodes.npy"
),
self
.
_num_nodes
)
np
.
save
(
os
.
path
.
join
(
path
,
"node_types.npy"
),
self
.
_node_types
)
with
open
(
os
.
path
.
join
(
path
,
"edge_types.pkl"
),
'wb'
)
as
f
:
pkl
.
dump
(
self
.
_edge_types
,
f
)
with
open
(
os
.
path
.
join
(
path
,
"nodes_type_dict.pkl"
),
'wb'
)
as
f
:
pkl
.
dump
(
self
.
_nodes_type_dict
,
f
)
for
e_type
,
g
in
self
.
_multi_graph
.
items
():
sub_path
=
os
.
path
.
join
(
path
,
e_type
)
g
.
dump
(
sub_path
)
@
property
def
edge_types
(
self
):
"""Return a list of edge types.
...
...
@@ -399,7 +424,7 @@ class HeterGraph(object):
"""
edge_types_info
=
[]
for
key
,
_
in
self
.
_
edges_dict
.
items
():
for
key
,
_
in
self
.
_
multi_graph
.
items
():
edge_types_info
.
append
(
key
)
return
edge_types_info
...
...
@@ -460,3 +485,21 @@ class SubHeterGraph(HeterGraph):
A list of node ids in parent graph.
"""
return
graph_kernel
.
map_nodes
(
nodes
,
self
.
_to_reindex
)
class
MemmapHeterGraph
(
HeterGraph
):
def
__init__
(
self
,
path
):
self
.
_num_nodes
=
np
.
load
(
os
.
path
.
join
(
path
,
'num_nodes.npy'
))
self
.
_node_types
=
np
.
load
(
os
.
path
.
join
(
path
,
'node_types.npy'
),
allow_pickle
=
True
)
with
open
(
os
.
path
.
join
(
path
,
'edge_types.pkl'
),
'rb'
)
as
f
:
self
.
_edge_types
=
pkl
.
load
(
f
)
with
open
(
os
.
path
.
join
(
path
,
"nodes_type_dict.pkl"
),
'rb'
)
as
f
:
self
.
_nodes_type_dict
=
pkl
.
load
(
f
)
self
.
_multi_graph
=
{}
for
e_type
in
self
.
_edge_types
:
sub_path
=
os
.
path
.
join
(
path
,
e_type
)
self
.
_multi_graph
[
e_type
]
=
MemmapGraph
(
sub_path
)
pgl/tests/test_MmapHeterGraph.py
0 → 100644
浏览文件 @
969c1880
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test_hetergraph"""
import
time
import
unittest
import
json
import
os
import
numpy
as
np
from
pgl.sample
import
metapath_randomwalk
from
pgl.graph
import
Graph
from
pgl
import
heter_graph
from
pgl.heter_graph
import
MemmapHeterGraph
def
test_dump
():
np
.
random
.
seed
(
1
)
edges
=
{}
# for test no successor
edges
[
'c2p'
]
=
[(
1
,
4
),
(
0
,
5
),
(
1
,
9
),
(
1
,
8
),
(
2
,
8
),
(
2
,
5
),
(
3
,
6
),
(
3
,
7
),
(
3
,
4
),
(
3
,
8
)]
edges
[
'p2c'
]
=
[(
v
,
u
)
for
u
,
v
in
edges
[
'c2p'
]]
edges
[
'p2a'
]
=
[(
4
,
10
),
(
4
,
11
),
(
4
,
12
),
(
4
,
14
),
(
4
,
13
),
(
6
,
12
),
(
6
,
11
),
(
6
,
14
),
(
7
,
12
),
(
7
,
11
),
(
8
,
14
),
(
9
,
10
)]
edges
[
'a2p'
]
=
[(
v
,
u
)
for
u
,
v
in
edges
[
'p2a'
]]
node_types
=
[
'c'
for
_
in
range
(
4
)]
+
[
'p'
for
_
in
range
(
6
)
]
+
[
'a'
for
_
in
range
(
5
)]
node_types
=
[(
i
,
t
)
for
i
,
t
in
enumerate
(
node_types
)]
graph
=
heter_graph
.
HeterGraph
(
num_nodes
=
len
(
node_types
),
edges
=
edges
,
node_types
=
node_types
)
graph
.
dump
(
"./hetergraph_mmap"
,
outdegree
=
True
)
def
test_load
():
graph
=
MemmapHeterGraph
(
"./hetergraph_mmap"
)
class
MmapHeterGraphTest
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
graph
=
MemmapHeterGraph
(
"./hetergraph_mmap"
)
def
test_num_nodes_by_type
(
self
):
print
()
n_types
=
{
'c'
:
4
,
'p'
:
6
,
'a'
:
5
}
for
nt
in
n_types
:
num_nodes
=
self
.
graph
.
num_nodes_by_type
(
nt
)
self
.
assertEqual
(
num_nodes
,
n_types
[
nt
])
def
test_node_batch_iter
(
self
):
print
()
batch_size
=
2
ground
=
[[
4
,
5
],
[
6
,
7
],
[
8
,
9
]]
for
idx
,
nodes
in
enumerate
(
self
.
graph
.
node_batch_iter
(
batch_size
=
batch_size
,
shuffle
=
False
,
n_type
=
'p'
)):
self
.
assertEqual
(
len
(
nodes
),
batch_size
)
self
.
assertListEqual
(
list
(
nodes
),
ground
[
idx
])
def
test_sample_successor
(
self
):
print
()
nodes
=
[
4
,
5
,
8
]
md
=
2
succes
=
self
.
graph
.
sample_successor
(
edge_type
=
'p2a'
,
nodes
=
nodes
,
max_degree
=
md
,
return_eids
=
False
)
self
.
assertIsInstance
(
succes
,
list
)
ground
=
[[
10
,
11
,
12
,
14
,
13
],
[],
[
14
]]
for
succ
,
g
in
zip
(
succes
,
ground
):
self
.
assertIsInstance
(
succ
,
np
.
ndarray
)
for
i
in
succ
:
self
.
assertIn
(
i
,
g
)
nodes
=
[
4
]
succes
=
self
.
graph
.
sample_successor
(
edge_type
=
'p2a'
,
nodes
=
nodes
,
max_degree
=
md
,
return_eids
=
False
)
self
.
assertIsInstance
(
succes
,
list
)
ground
=
[[
10
,
11
,
12
,
14
,
13
]]
for
succ
,
g
in
zip
(
succes
,
ground
):
self
.
assertIsInstance
(
succ
,
np
.
ndarray
)
for
i
in
succ
:
self
.
assertIn
(
i
,
g
)
def
test_successor
(
self
):
print
()
nodes
=
[
4
,
5
,
8
]
e_type
=
'p2a'
succes
=
self
.
graph
.
successor
(
edge_type
=
e_type
,
nodes
=
nodes
,
)
self
.
assertIsInstance
(
succes
,
np
.
ndarray
)
ground
=
[[
10
,
11
,
12
,
14
,
13
],
[],
[
14
]]
for
succ
,
g
in
zip
(
succes
,
ground
):
self
.
assertIsInstance
(
succ
,
np
.
ndarray
)
self
.
assertCountEqual
(
succ
,
g
)
nodes
=
[
4
]
e_type
=
'p2a'
succes
=
self
.
graph
.
successor
(
edge_type
=
e_type
,
nodes
=
nodes
,
)
self
.
assertIsInstance
(
succes
,
np
.
ndarray
)
ground
=
[[
10
,
11
,
12
,
14
,
13
]]
for
succ
,
g
in
zip
(
succes
,
ground
):
self
.
assertIsInstance
(
succ
,
np
.
ndarray
)
self
.
assertCountEqual
(
succ
,
g
)
def
test_predecessor
(
self
):
print
()
nodes
=
[
11
,
12
,
13
]
e_type
=
'p2a'
pre
=
self
.
graph
.
predecessor
(
edge_type
=
e_type
,
nodes
=
nodes
,
)
self
.
assertIsInstance
(
pre
,
np
.
ndarray
)
print
(
pre
)
ground
=
[[
4
,
6
,
7
],
[
4
,
6
,
7
],
[
4
]]
for
succ
,
g
in
zip
(
pre
,
ground
):
self
.
assertIsInstance
(
succ
,
np
.
ndarray
)
self
.
assertCountEqual
(
succ
,
g
)
nodes
=
[
11
]
e_type
=
'p2a'
pre
=
self
.
graph
.
predecessor
(
edge_type
=
e_type
,
nodes
=
nodes
,
)
print
(
pre
)
self
.
assertIsInstance
(
pre
,
np
.
ndarray
)
ground
=
[[
4
,
6
,
7
]]
for
p
,
g
in
zip
(
pre
,
ground
):
self
.
assertIsInstance
(
p
,
np
.
ndarray
)
self
.
assertCountEqual
(
p
,
g
)
def
test_sample_nodes
(
self
):
print
()
p_ground
=
[
4
,
5
,
6
,
7
,
8
,
9
]
sample_num
=
10
nodes
=
self
.
graph
.
sample_nodes
(
sample_num
=
sample_num
,
n_type
=
'p'
)
self
.
assertEqual
(
len
(
nodes
),
sample_num
)
for
n
in
nodes
:
self
.
assertIn
(
n
,
p_ground
)
# test n_type == None
ground
=
[
i
for
i
in
range
(
15
)]
nodes
=
self
.
graph
.
sample_nodes
(
sample_num
=
sample_num
,
n_type
=
None
)
self
.
assertEqual
(
len
(
nodes
),
sample_num
)
for
n
in
nodes
:
self
.
assertIn
(
n
,
ground
)
if
__name__
==
"__main__"
:
unittest
.
main
()
# test_dump()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录