Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
89338e44
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
89338e44
编写于
7月 02, 2020
作者:
Z
Zhong Hui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix graph saint
上级
84b9d61c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
48 addition
and
5 deletion
+48
-5
pgl/graph_kernel.pyx
pgl/graph_kernel.pyx
+2
-2
pgl/sample.py
pgl/sample.py
+3
-3
pgl/tests/test_graph_saint_sample.py
pgl/tests/test_graph_saint_sample.py
+43
-0
未找到文件。
pgl/graph_kernel.pyx
浏览文件 @
89338e44
...
...
@@ -322,12 +322,12 @@ def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs):
smaller_num
.
push_back
(
l_i
)
return
alias
,
events
@
cython
.
boundscheck
(
False
)
@
cython
.
wraparound
(
False
)
def
extract_edges_from_nodes
(
np
.
ndarray
[
np
.
int64_t
,
ndim
=
1
]
adj_indptr
,
np
.
ndarray
[
np
.
int64_t
,
ndim
=
1
]
sorted_v
,
np
.
ndarray
[
np
.
int64_t
,
ndim
=
1
]
sorted_eid
,
vector
[
long
long
]
sampled_nodes
,
):
"""
...
...
@@ -357,7 +357,7 @@ def extract_edges_from_nodes(
j
=
start_neigh
while
j
<
end_neigh
:
if
_arr_bit
[
sorted_v
[
j
]]
>
-
1
:
ret_edge_index
.
push_back
(
j
)
ret_edge_index
.
push_back
(
sorted_eid
[
j
]
)
j
=
j
+
1
i
=
i
+
1
return
ret_edge_index
pgl/sample.py
浏览文件 @
89338e44
...
...
@@ -480,8 +480,8 @@ def pinsage_sample(graph,
def
extract_edges_from_nodes
(
graph
,
sample_nodes
):
eids
=
graph_kernel
.
extract_edges_from_nodes
(
graph
.
_adj_dst_index
.
_indptr
,
graph
.
_adj_dst
_index
.
_sorted_v
,
sample_nodes
)
graph
.
adj_src_index
.
_indptr
,
graph
.
adj_src
_index
.
_sorted_v
,
graph
.
adj_src_index
.
_sorted_eid
,
sample_nodes
)
return
eids
...
...
@@ -505,7 +505,7 @@ def graph_saint_random_walk_sample(graph,
Return:
a subgraph of sampled nodes.
"""
graph
.
in
degree
()
graph
.
out
degree
()
walks
=
deepwalk_sample
(
graph
,
nodes
,
max_depth
,
alias_name
,
events_name
)
sample_nodes
=
[]
for
walk
in
walks
:
...
...
pgl/tests/test_graph_saint_sample.py
0 → 100644
浏览文件 @
89338e44
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""graph saint sample test"""
import
unittest
import
pgl
import
numpy
as
np
import
paddle.fluid
as
fluid
from
pgl.sample
import
graph_saint_random_walk_sample
class
GraphSaintSampleTest
(
unittest
.
TestCase
):
"""ScatterAddTest"""
def
test_randomwalk_sampler
(
self
):
"""test_scatter_add"""
g
=
pgl
.
graph
.
Graph
(
num_nodes
=
8
,
edges
=
[(
1
,
2
),
(
2
,
3
),
(
0
,
2
),
(
0
,
1
),
(
6
,
7
),
(
4
,
5
),
(
6
,
4
),
(
7
,
4
),
(
3
,
4
)])
subgraph
=
graph_saint_random_walk_sample
(
g
,
[
6
,
7
],
2
)
print
(
'reinded'
,
subgraph
.
_from_reindex
)
print
(
'sub_edges'
,
subgraph
.
edges
)
assert
len
(
subgraph
.
nodes
)
==
4
assert
len
(
subgraph
.
edges
)
==
4
true_edges
=
np
.
array
([[
0
,
1
],
[
2
,
3
],
[
2
,
0
],
[
3
,
0
]])
assert
"{}"
.
format
(
subgraph
.
edges
)
==
"{}"
.
format
(
true_edges
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录