提交 84b9d61c 编写于 作者: Z Zhong Hui

refine graph saint

上级 a43b5a2e
......@@ -24,7 +24,7 @@ import ssl
ssl._create_default_https_context = ssl._create_unverified_context
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
#from pgl.sample import graph_saint_random_walk_sample
from pgl.sample import graph_saint_random_walk_sample
from ogb.nodeproppred import Evaluator
import tqdm
from collections import namedtuple
......@@ -78,10 +78,10 @@ def k_hop_sampler(graph, samples, batch_nodes):
return subgraph, sub_node_index
#def graph_saint_randomwalk_sampler(graph, batch_nodes, max_depth=3):
# subgraph = graph_saint_random_walk_sample(graph, batch_nodes, max_depth)
# sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
# return subgraph, sub_node_index
def graph_saint_randomwalk_sampler(graph, batch_nodes, max_depth=3):
subgraph = graph_saint_random_walk_sample(graph, batch_nodes, max_depth)
sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
return subgraph, sub_node_index
class ArxivDataGenerator(BaseDataGenerator):
......
......@@ -325,7 +325,7 @@ def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs):
@cython.boundscheck(False)
@cython.wraparound(False)
def adj_extract(
def extract_edges_from_nodes(
np.ndarray[np.int64_t, ndim=1] adj_indptr,
np.ndarray[np.int64_t, ndim=1] sorted_v,
vector[long long] sampled_nodes,
......
......@@ -24,7 +24,7 @@ from pgl import graph_kernel
__all__ = [
'graphsage_sample', 'node2vec_sample', 'deepwalk_sample',
'metapath_randomwalk', 'pinsage_sample'
'metapath_randomwalk', 'pinsage_sample', 'graph_saint_random_walk_sample'
]
......@@ -478,6 +478,13 @@ def pinsage_sample(graph,
return subgraphs
def extract_edges_from_nodes(graph, sample_nodes):
eids = graph_kernel.extract_edges_from_nodes(
graph._adj_dst_index._indptr, graph._adj_dst_index._sorted_v,
sample_nodes)
return eids
def graph_saint_random_walk_sample(graph,
nodes,
max_depth,
......@@ -504,9 +511,7 @@ def graph_saint_random_walk_sample(graph,
for walk in walks:
sample_nodes.extend(walk)
sample_nodes = np.unique(sample_nodes)
eids = graph_kernel.adj_extract(graph._adj_dst_index._indptr,
graph._adj_dst_index._sorted_v,
sample_nodes)
eids = extract_edges_from_nodes(graph, sample_nodes)
subgraph = graph.subgraph(
nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True)
subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册