提交 89338e44 编写于 作者: Z Zhong Hui

fix graph saint

上级 84b9d61c
......@@ -322,12 +322,12 @@ def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs):
smaller_num.push_back(l_i)
return alias, events
@cython.boundscheck(False)
@cython.wraparound(False)
def extract_edges_from_nodes(
np.ndarray[np.int64_t, ndim=1] adj_indptr,
np.ndarray[np.int64_t, ndim=1] sorted_v,
np.ndarray[np.int64_t, ndim=1] sorted_eid,
vector[long long] sampled_nodes,
):
"""
......@@ -357,7 +357,7 @@ def extract_edges_from_nodes(
j = start_neigh
while j < end_neigh:
if _arr_bit[sorted_v[j]] > -1:
ret_edge_index.push_back(j)
ret_edge_index.push_back(sorted_eid[j])
j = j + 1
i = i + 1
return ret_edge_index
......@@ -480,8 +480,8 @@ def pinsage_sample(graph,
def extract_edges_from_nodes(graph, sample_nodes):
eids = graph_kernel.extract_edges_from_nodes(
graph._adj_dst_index._indptr, graph._adj_dst_index._sorted_v,
sample_nodes)
graph.adj_src_index._indptr, graph.adj_src_index._sorted_v,
graph.adj_src_index._sorted_eid, sample_nodes)
return eids
......@@ -505,7 +505,7 @@ def graph_saint_random_walk_sample(graph,
Return:
a subgraph of sampled nodes.
"""
graph.indegree()
graph.outdegree()
walks = deepwalk_sample(graph, nodes, max_depth, alias_name, events_name)
sample_nodes = []
for walk in walks:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""graph saint sample test"""
import unittest
import pgl
import numpy as np
import paddle.fluid as fluid
from pgl.sample import graph_saint_random_walk_sample
class GraphSaintSampleTest(unittest.TestCase):
"""ScatterAddTest"""
def test_randomwalk_sampler(self):
"""test_scatter_add"""
g = pgl.graph.Graph(
num_nodes=8,
edges=[(1, 2), (2, 3), (0, 2), (0, 1), (6, 7), (4, 5), (6, 4),
(7, 4), (3, 4)])
subgraph = graph_saint_random_walk_sample(g, [6, 7], 2)
print('reinded', subgraph._from_reindex)
print('sub_edges', subgraph.edges)
assert len(subgraph.nodes) == 4
assert len(subgraph.edges) == 4
true_edges = np.array([[0, 1], [2, 3], [2, 0], [3, 0]])
assert "{}".format(subgraph.edges) == "{}".format(true_edges)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册