diff --git a/pgl/graph_kernel.pyx b/pgl/graph_kernel.pyx index 854082efa9694f4c0fe161e2fa330026e27dcdb5..22aee9589cc49c2dd98c3f3912c5ff7c09298f15 100644 --- a/pgl/graph_kernel.pyx +++ b/pgl/graph_kernel.pyx @@ -322,12 +322,12 @@ def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs): smaller_num.push_back(l_i) return alias, events - @cython.boundscheck(False) @cython.wraparound(False) def extract_edges_from_nodes( np.ndarray[np.int64_t, ndim=1] adj_indptr, np.ndarray[np.int64_t, ndim=1] sorted_v, + np.ndarray[np.int64_t, ndim=1] sorted_eid, vector[long long] sampled_nodes, ): """ @@ -357,7 +357,7 @@ def extract_edges_from_nodes( j = start_neigh while j < end_neigh: if _arr_bit[sorted_v[j]] > -1: - ret_edge_index.push_back(j) + ret_edge_index.push_back(sorted_eid[j]) j = j + 1 i = i + 1 return ret_edge_index diff --git a/pgl/sample.py b/pgl/sample.py index 741a576661b02302d5813057c0da683e10d2e9e0..81241d5dc6f8224283abebeaa35da69644e9d1a1 100644 --- a/pgl/sample.py +++ b/pgl/sample.py @@ -480,8 +480,8 @@ def pinsage_sample(graph, def extract_edges_from_nodes(graph, sample_nodes): eids = graph_kernel.extract_edges_from_nodes( - graph._adj_dst_index._indptr, graph._adj_dst_index._sorted_v, - sample_nodes) + graph.adj_src_index._indptr, graph.adj_src_index._sorted_v, + graph.adj_src_index._sorted_eid, sample_nodes) return eids @@ -505,7 +505,7 @@ def graph_saint_random_walk_sample(graph, Return: a subgraph of sampled nodes. """ - graph.indegree() + graph.outdegree() walks = deepwalk_sample(graph, nodes, max_depth, alias_name, events_name) sample_nodes = [] for walk in walks: diff --git a/pgl/tests/test_graph_saint_sample.py b/pgl/tests/test_graph_saint_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..cf188a6d94224e2416a48736f959d3de0982ac1b --- /dev/null +++ b/pgl/tests/test_graph_saint_sample.py @@ -0,0 +1,43 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""graph saint sample test""" + +import unittest + +import pgl +import numpy as np +import paddle.fluid as fluid +from pgl.sample import graph_saint_random_walk_sample + + +class GraphSaintSampleTest(unittest.TestCase): + """ScatterAddTest""" + + def test_randomwalk_sampler(self): + """test_scatter_add""" + g = pgl.graph.Graph( + num_nodes=8, + edges=[(1, 2), (2, 3), (0, 2), (0, 1), (6, 7), (4, 5), (6, 4), + (7, 4), (3, 4)]) + subgraph = graph_saint_random_walk_sample(g, [6, 7], 2) + print('reinded', subgraph._from_reindex) + print('sub_edges', subgraph.edges) + assert len(subgraph.nodes) == 4 + assert len(subgraph.edges) == 4 + true_edges = np.array([[0, 1], [2, 3], [2, 0], [3, 0]]) + assert "{}".format(subgraph.edges) == "{}".format(true_edges) + + +if __name__ == '__main__': + unittest.main()