@cython.boundscheck(False)
@cython.wraparound(False)
def adj_extract(
        np.ndarray[np.int64_t, ndim=1] adj_indptr,
        np.ndarray[np.int64_t, ndim=1] sorted_v,
        vector[long long] sampled_nodes,
):
    """Collect the adjacency positions that connect two sampled nodes.

    For every node ``v`` in ``sampled_nodes`` the slice
    ``sorted_v[adj_indptr[v]:adj_indptr[v + 1]]`` is scanned, and each
    position ``j`` whose ``sorted_v[j]`` is itself a sampled node is kept.

    Refers: https://github.com/GraphSAINT/GraphSAINT

    Args:
        adj_indptr: 1-D int64 array; ``adj_indptr[v]`` and
            ``adj_indptr[v + 1]`` delimit node ``v``'s range in ``sorted_v``.
            The number of nodes is taken to be ``adj_indptr.size - 1``.
        sorted_v: 1-D int64 array of node ids indexed by the ranges above.
        sampled_nodes: node ids of the sub-graph to extract.

    Return:
        A list of positions ``j`` into ``sorted_v``, in the order the
        sampled nodes (and then their ranges) were scanned.
        NOTE(review): callers treat these as edge ids; that only holds if
        the position in ``sorted_v`` coincides with the graph's edge id --
        confirm against how the index is built.

    Raises:
        ValueError: if a sampled node id is outside
            ``[0, adj_indptr.size - 1)``.
    """
    cdef long long i, v, j
    cdef long long num_v_orig, num_v_sub
    cdef long long start_neigh, end_neigh
    cdef vector[int] _arr_bit
    cdef vector[long long] ret_edge_index
    num_v_orig = adj_indptr.size - 1
    num_v_sub = sampled_nodes.size()

    # Nothing sampled: return early so an empty ``adj_indptr`` never reaches
    # the vector constructor with a negative length.
    if num_v_sub == 0:
        return ret_edge_index

    # Bounds checking is disabled for this function, so an invalid node id
    # would silently corrupt memory; validate while the GIL is still held.
    for i in range(num_v_sub):
        v = sampled_nodes[i]
        if v < 0 or v >= num_v_orig:
            raise ValueError(
                "sampled node id %d is out of range [0, %d)" %
                (v, num_v_orig))

    # Membership table: only "marked or not" is ever read, so store a flag
    # instead of narrowing the long long loop index into an int.
    _arr_bit = vector[int](num_v_orig, 0)
    with nogil:
        for i in range(num_v_sub):
            _arr_bit[sampled_nodes[i]] = 1
        for i in range(num_v_sub):
            v = sampled_nodes[i]
            start_neigh = adj_indptr[v]
            end_neigh = adj_indptr[v + 1]
            for j in range(start_neigh, end_neigh):
                # Assumes every value in sorted_v is a valid node id.
                if _arr_bit[sorted_v[j]] != 0:
                    ret_edge_index.push_back(j)
    return ret_edge_index
def graph_saint_random_walk_sample(graph,
                                   nodes,
                                   max_depth,
                                   alias_name=None,
                                   events_name=None):
    """Build a GraphSAINT sub-graph from random walks.

    Random walks are generated with ``deepwalk_sample``; every node that
    appears on any walk is collected, de-duplicated, and the sub-graph
    spanned by those nodes is returned.

    Reference Paper: https://arxiv.org/abs/1907.04931

    Args:
        graph: A pgl graph instance
        nodes: Walk starting from nodes
        max_depth: Max walking depth
        alias_name: forwarded unchanged to ``deepwalk_sample``
        events_name: forwarded unchanged to ``deepwalk_sample``

    Return:
        The sub-graph produced by ``graph.subgraph`` for the sampled nodes,
        with ``node_feat["index"]`` holding the sampled node ids as int64.
    """
    # The return value is discarded; presumably this call is made so that
    # graph._adj_dst_index exists before it is read below -- TODO confirm.
    graph.indegree()

    paths = deepwalk_sample(graph, nodes, max_depth, alias_name, events_name)
    visited = [node for path in paths for node in path]
    sample_nodes = np.unique(visited)

    dst_index = graph._adj_dst_index
    # NOTE(review): adj_extract's result is handed to graph.subgraph as edge
    # ids; confirm the kernel's output really is in edge-id space.
    eids = graph_kernel.adj_extract(dst_index._indptr, dst_index._sorted_v,
                                    sample_nodes)

    subgraph = graph.subgraph(
        nodes=sample_nodes,
        eid=eids,
        with_node_feat=True,
        with_edge_feat=True)
    subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64")
    return subgraph