提交 cfd11f59 编写于 作者: L liweibin

increase the sampling speed of metapath random walk

上级 0fb10359
...@@ -40,8 +40,10 @@ class Dataset(object): ...@@ -40,8 +40,10 @@ class Dataset(object):
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.walk_files = config['input_path'] + config['walk_path'] self.walk_files = os.path.join(config['input_path'],
self.word2id_file = config['input_path'] + config['word2id_file'] config['walk_path'])
self.word2id_file = os.path.join(config['input_path'],
config['word2id_file'])
self.word2freq = {} self.word2freq = {}
self.word2id = {} self.word2id = {}
...@@ -65,7 +67,7 @@ class Dataset(object): ...@@ -65,7 +67,7 @@ class Dataset(object):
for walk_file in glob.glob(self.walk_files): for walk_file in glob.glob(self.walk_files):
with open(walk_file, 'r') as reader: with open(walk_file, 'r') as reader:
for walk in reader: for walk in reader:
walk = walk.strip().split(' ') walk = walk.strip().split()
if len(walk) > 1: if len(walk) > 1:
self.sentences_count += 1 self.sentences_count += 1
for word in walk: for word in walk:
...@@ -123,7 +125,7 @@ class Dataset(object): ...@@ -123,7 +125,7 @@ class Dataset(object):
for filename in walkpath_files: for filename in walkpath_files:
with open(filename) as reader: with open(filename) as reader:
for line in reader: for line in reader:
words = line.strip().split(' ') words = line.strip().split()
if len(words) > 1: if len(words) > 1:
word_ids = [ word_ids = [
self.word2id[w] for w in words if w in self.word2id self.word2id[w] for w in words if w in self.word2id
......
...@@ -13,9 +13,12 @@ sampler: ...@@ -13,9 +13,12 @@ sampler:
new_author_label_file: author_label.txt new_author_label_file: author_label.txt
new_venue_label_file: venue_label.txt new_venue_label_file: venue_label.txt
walk_saved_path: walks/ walk_saved_path: walks/
walk_batch_size: 1000
num_walks: 1000 num_walks: 1000
walk_length: 100 walk_length: 100
metapath: conf-paper-author-paper-conf num_sample_workers: 16
first_node_type: conf
metapath: c2p-p2a-a2p-p2c #conf-paper-author-paper-conf
optimizer: optimizer:
type: Adam type: Adam
......
...@@ -101,7 +101,7 @@ class SkipgramModel(object): ...@@ -101,7 +101,7 @@ class SkipgramModel(object):
pos_score = fl.squeeze(pos_logits, axes=[1]) pos_score = fl.squeeze(pos_logits, axes=[1])
pos_score = fl.clip(pos_score, min=-10, max=10) pos_score = fl.clip(pos_score, min=-10, max=10)
pos_score = -1.0 * fl.logsigmoid(pos_score) pos_score = -self.neg_num * fl.logsigmoid(pos_score)
neg_logits = fl.matmul( neg_logits = fl.matmul(
embed_src, weight_negs, embed_src, weight_negs,
...@@ -111,4 +111,4 @@ class SkipgramModel(object): ...@@ -111,4 +111,4 @@ class SkipgramModel(object):
neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score) neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score)
neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True) neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True)
self.loss = fl.reduce_mean(pos_score + neg_score) self.loss = fl.reduce_mean(pos_score + neg_score) / self.neg_num / 2
...@@ -18,6 +18,7 @@ training metapath2vec model. ...@@ -18,6 +18,7 @@ training metapath2vec model.
import multiprocessing import multiprocessing
from multiprocessing import Pool from multiprocessing import Pool
from multiprocessing import Process
import argparse import argparse
import sys import sys
import os import os
...@@ -77,9 +78,14 @@ class Sampler(object): ...@@ -77,9 +78,14 @@ class Sampler(object):
self.config['data_path'] + 'paper_conf.txt', self.paper_id2index, self.config['data_path'] + 'paper_conf.txt', self.paper_id2index,
self.conf_id2index) self.conf_id2index)
edges_by_types['edge'] = paper_author_edges + paper_conf_edges # edges_by_types['edge'] = paper_author_edges + paper_conf_edges
logging.info('%d edges have been loaded.' % edges_by_types['p2c'] = paper_conf_edges
(len(edges_by_types['edge']))) edges_by_types['c2p'] = [(dst, src) for src, dst in paper_conf_edges]
edges_by_types['p2a'] = paper_author_edges
edges_by_types['a2p'] = [(dst, src) for src, dst in paper_author_edges]
# logging.info('%d edges have been loaded.' %
# (len(edges_by_types['edge'])))
node_features = { node_features = {
'index': np.array([i for i in range(num_nodes)]).reshape( 'index': np.array([i for i in range(num_nodes)]).reshape(
...@@ -110,7 +116,7 @@ class Sampler(object): ...@@ -110,7 +116,7 @@ class Sampler(object):
return id2index, name2index, node_types return id2index, name2index, node_types
def load_edges(self, file_, src2index, dst2index, symmetry=True): def load_edges(self, file_, src2index, dst2index, symmetry=False):
"""Load edges from file. """Load edges from file.
""" """
edges = [] edges = []
...@@ -143,41 +149,65 @@ class Sampler(object): ...@@ -143,41 +149,65 @@ class Sampler(object):
return index_label_list return index_label_list
def generate_walks(args): def walk_generator(graph, batch_size, metapath, n_type, walk_length):
"""Generate metapath random walk and save to file. """Generate metapath random walk.
""" """
g, meta_path, filename, walk_length = args np.random.seed(os.getpid())
walks = [] while True:
node_types = g._node_types for start_nodes in graph.node_batch_iter(
first_type = meta_path.split('-')[0] batch_size=batch_size, n_type=n_type):
nodes = np.where(node_types == first_type)[0] walks = metapath_randomwalk(
if len(nodes) > 4000: graph=graph,
nodes = np.random.choice(nodes, 4000, replace=False) start_nodes=start_nodes,
metapath=metapath,
logging.info('%d number of start nodes' % (len(nodes))) walk_length=walk_length)
logging.info('save walks in file: %s' % (filename)) yield walks
def walk_to_files(g, batch_size, metapath, n_type, walk_length, max_num,
filename):
"""Generate metapath randomwalk and save in files"""
# g, batch_size, metapath, n_type, walk_length, max_num, filename = args
with open(filename, 'w') as writer: with open(filename, 'w') as writer:
for start_node in nodes: cc = 0
walk = metapath_randomwalk(g, start_node, meta_path, walk_length) for walks in walk_generator(g, batch_size, metapath, n_type,
walk = [str(walk[i]) for i in range(0, len(walk), 2)] # skip paper walk_length):
writer.write(' '.join(walk) + '\n') for walk in walks:
writer.write("%s\n" % "\t".join([str(i) for i in walk]))
cc += 1
if cc == max_num:
return
return
def multiprocess_generate_walks_to_files(graph, n_type, meta_path, num_walks,
walk_length, batch_size,
num_sample_workers, saved_path):
"""Use multiprocess to generate metapath random walk to files.
"""
num_nodes_by_type = graph.num_nodes_by_type(n_type)
logging.info("num_nodes_by_type: %s" % num_nodes_by_type)
max_num = (num_walks * num_nodes_by_type // num_sample_workers) + 1
logging.info("max sample number of every worker: %s" % max_num)
def multiprocess_generate_walks(sampler, edge_type, meta_path, num_walks,
walk_length, saved_path):
"""Use multiprocess to generate metapath random walk.
"""
args = [] args = []
for i in range(num_walks): for i in range(num_sample_workers):
filename = saved_path + '%04d' % (i) filename = os.path.join(saved_path, 'part-%05d' % (i))
args.append( args.append((graph, batch_size, meta_path, n_type, walk_length,
(sampler.graph[edge_type], meta_path, filename, walk_length)) max_num, filename))
pool = Pool(16) ps = []
pool.map(generate_walks, args) for i in range(num_sample_workers):
pool.close() p = Process(target=walk_to_files, args=args[i])
pool.join() p.start()
ps.append(p)
for i in range(num_sample_workers):
ps[i].join()
# pool = Pool(num_sample_workers)
# pool.map(walk_to_files, args)
# pool.close()
# pool.join()
if __name__ == "__main__": if __name__ == "__main__":
...@@ -220,13 +250,15 @@ if __name__ == "__main__": ...@@ -220,13 +250,15 @@ if __name__ == "__main__":
begin = time.time() begin = time.time()
logging.info('multi process sampling') logging.info('multi process sampling')
multiprocess_generate_walks( multiprocess_generate_walks_to_files(
sampler=sampler, graph=sampler.graph,
edge_type='edge', n_type=config['first_node_type'],
meta_path=config['metapath'], meta_path=config['metapath'],
num_walks=config['num_walks'], num_walks=config['num_walks'],
walk_length=config['walk_length'], walk_length=config['walk_length'],
saved_path=config['walk_saved_path']) batch_size=config['walk_batch_size'],
num_sample_workers=config['num_sample_workers'],
saved_path=config['walk_saved_path'], )
logging.info('total time: %.4f' % (time.time() - begin)) logging.info('total time: %.4f' % (time.time() - begin))
logging.info('generating multi class data') logging.info('generating multi class data')
......
...@@ -21,7 +21,7 @@ import time ...@@ -21,7 +21,7 @@ import time
import pgl.graph_kernel as graph_kernel import pgl.graph_kernel as graph_kernel
from pgl.graph import Graph from pgl.graph import Graph
__all__ = ['HeterGraph'] __all__ = ['HeterGraph', 'SubHeterGraph']
def _hide_num_nodes(shape): def _hide_num_nodes(shape):
...@@ -32,31 +32,6 @@ def _hide_num_nodes(shape): ...@@ -32,31 +32,6 @@ def _hide_num_nodes(shape):
return shape return shape
class NodeGraph(Graph):
"""Implementation of a graph that has multple node types.
Args:
num_nodes: number of nodes in the graph
edges: list of (u, v) tuples
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of numpy array as edge features
"""
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None):
super(NodeGraph, self).__init__(num_nodes, edges, node_feat, edge_feat)
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
class HeterGraph(object): class HeterGraph(object):
"""Implementation of heterogeneous graph structure in pgl """Implementation of heterogeneous graph structure in pgl
...@@ -102,6 +77,16 @@ class HeterGraph(object): ...@@ -102,6 +77,16 @@ class HeterGraph(object):
self._num_nodes = num_nodes self._num_nodes = num_nodes
self._edges_dict = edges self._edges_dict = edges
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
self._nodes_type_dict = {}
for n_type in np.unique(self._node_types):
self._nodes_type_dict[n_type] = np.where(
self._node_types == n_type)[0]
if node_feat is not None: if node_feat is not None:
self._node_feat = node_feat self._node_feat = node_feat
else: else:
...@@ -113,30 +98,262 @@ class HeterGraph(object): ...@@ -113,30 +98,262 @@ class HeterGraph(object):
self._edge_feat = {} self._edge_feat = {}
self._multi_graph = {} self._multi_graph = {}
for key, value in self._edges_dict.items(): for key, value in self._edges_dict.items():
if not self._edge_feat: if not self._edge_feat:
edge_feat = None edge_feat = None
else: else:
edge_feat = self._edge_feat[key] edge_feat = self._edge_feat[key]
self._multi_graph[key] = NodeGraph( self._multi_graph[key] = Graph(
num_nodes=self._num_nodes, num_nodes=self._num_nodes,
edges=value, edges=value,
node_types=node_types,
node_feat=self._node_feat, node_feat=self._node_feat,
edge_feat=edge_feat) edge_feat=edge_feat)
self._edge_types = self.edge_types_info()
@property
def edge_types(self):
"""Return a list of edge types.
"""
return self._edge_types
@property @property
def num_nodes(self): def num_nodes(self):
"""Return the number of nodes. """Return the number of nodes.
""" """
return self._num_nodes return self._num_nodes
@property
def num_edges(self):
"""Return edges number of all edge types.
"""
n_edges = {}
for e_type in self._edge_types:
n_edges[e_type] = self._multi_graph[e_type].num_edges
return n_edges
@property
def node_types(self):
"""Return the node types.
"""
return self._node_types
@property
def edge_feat(self, edge_type=None):
"""Return edge features of all edge types.
"""
return self._edge_feat
@property
def node_feat(self):
"""Return a dictionary of node features.
"""
return self._node_feat
@property
def nodes(self):
"""Return all nodes id from 0 to :code:`num_nodes - 1`
"""
return np.arange(self._num_nodes, dtype='int64')
def __getitem__(self, edge_type): def __getitem__(self, edge_type):
"""__getitem__ """__getitem__
""" """
return self._multi_graph[edge_type] return self._multi_graph[edge_type]
def num_nodes_by_type(self, n_type=None):
"""Return the number of nodes with the specified node type.
"""
if n_type not in self._nodes_type_dict:
raise ("%s is not in valid node type" % n_type)
else:
return len(self._nodes_type_dict[n_type])
def indegree(self, nodes=None, edge_type=None):
"""Return the indegree of the given nodes with the specified edge_type.
Args:
nodes: Return the indegree of given nodes.
if nodes is None, return indegree for all nodes.
edge_types: Return the indegree with specified edge_type.
if edge_type is None, return the total indegree of the given nodes.
Return:
A numpy.ndarray as the given nodes' indegree.
"""
if edge_type is None:
indegrees = []
for e_type in self._edge_types:
indegrees.append(self._multi_graph[e_type].indegree(nodes))
indegrees = np.sum(np.vstack(indegrees), axis=0)
return indegrees
else:
return self._multi_graph[edge_type].indegree(nodes)
def outdegree(self, nodes=None, edge_type=None):
"""Return the outdegree of the given nodes with the specified edge_type.
Args:
nodes: Return the outdegree of given nodes,
if nodes is None, return outdegree for all nodes
edge_types: Return the outdegree with specified edge_type.
if edge_type is None, return the total outdegree of the given nodes.
Return:
A numpy.array as the given nodes' outdegree.
"""
if edge_type is None:
outdegrees = []
for e_type in self._edge_types:
outdegrees.append(self._multi_graph[e_type].outdegree(nodes))
outdegrees = np.sum(np.vstack(outdegrees), axis=0)
return outdegrees
else:
return self._multi_graph[edge_type].outdegree(nodes)
def successor(self, edge_type, nodes=None, return_eids=False):
"""Find successor of given nodes with the specified edge_type.
Args:
nodes: Return the successor of given nodes,
if nodes is None, return successor for all nodes
edge_types: Return the successor with specified edge_type.
if edge_type is None, return the total successor of the given nodes
and eids are invalid in this way.
return_eids: If True return nodes together with corresponding eid
"""
return self._multi_graph[edge_type].successor(nodes, return_eids)
def sample_successor(self,
edge_type,
nodes,
max_degree,
return_eids=False,
shuffle=False):
"""Sample successors of given nodes with the specified edge_type.
Args:
edge_type: The specified edge_type.
nodes: Given nodes whose successors will be sampled.
max_degree: The max sampled successors for each nodes.
return_eids: Whether to return the corresponding eids.
Return:
Return a list of numpy.ndarray and each numpy.ndarray represent a list
of sampled successor ids for given nodes with specified edge type.
If :code:`return_eids=True`, there will be an additional list of
numpy.ndarray and each numpy.ndarray represent a list of eids that
connected nodes to their successors.
"""
return self._multi_graph[edge_type].sample_successor(
nodes=nodes,
max_degree=max_degree,
return_eids=return_eids,
shuffle=shuffle)
def predecessor(self, edge_type, nodes=None, return_eids=False):
"""Find predecessor of given nodes with the specified edge_type.
Args:
nodes: Return the predecessor of given nodes,
if nodes is None, return predecessor for all nodes
edge_types: Return the predecessor with specified edge_type.
return_eids: If True return nodes together with corresponding eid
"""
return self._multi_graph[edge_type].predecessor(nodes, return_eids)
def sample_predecessor(self,
edge_type,
nodes,
max_degree,
return_eids=False,
shuffle=False):
"""Sample predecessors of given nodes with the specified edge_type.
Args:
edge_type: The specified edge_type.
nodes: Given nodes whose predecessors will be sampled.
max_degree: The max sampled predecessors for each nodes.
return_eids: Whether to return the corresponding eids.
Return:
Return a list of numpy.ndarray and each numpy.ndarray represent a list
of sampled predecessor ids for given nodes with specified edge type.
If :code:`return_eids=True`, there will be an additional list of
numpy.ndarray and each numpy.ndarray represent a list of eids that
connected nodes to their predecessors.
"""
return self._multi_graph[edge_type].sample_predecessor(
nodes=nodes,
max_degree=max_degree,
return_eids=return_eids,
shuffle=shuffle)
def node_batch_iter(self, batch_size, shuffle=True, n_type=None):
"""Node batch iterator
Iterate all nodes by batch with the specified node type.
Args:
batch_size: The batch size of each batch of nodes.
shuffle: Whether shuffle the nodes.
n_type: Iterate the nodes with the specified node type. If n_type is None,
iterate all nodes by batch.
Return:
Batch iterator
"""
if n_type is None:
nodes = np.arange(self._num_nodes, dtype="int64")
else:
nodes = self._nodes_type_dict[n_type]
if shuffle:
np.random.shuffle(nodes)
start = 0
while start < len(nodes):
yield nodes[start:start + batch_size]
start += batch_size
def sample_nodes(self, sample_num, n_type=None):
"""Sample nodes with the specified n_type from the graph
This function helps to sample nodes with the specified n_type from the graph.
If n_type is None, this function will sample nodes from all nodes.
Nodes might be duplicated.
Args:
sample_num: The number of samples
n_type: The nodes of type to be sampled
Return:
A list of nodes
"""
if n_type is not None:
return np.random.choice(
self._nodes_type_dict[n_type], size=sample_num)
else:
return np.random.randint(
low=0, high=self._num_nodes, size=sample_num)
def node_feat_info(self): def node_feat_info(self):
"""Return the information of node feature for HeterGraphWrapper. """Return the information of node feature for HeterGraphWrapper.
...@@ -186,3 +403,60 @@ class HeterGraph(object): ...@@ -186,3 +403,60 @@ class HeterGraph(object):
edge_types_info.append(key) edge_types_info.append(key)
return edge_types_info return edge_types_info
class SubHeterGraph(HeterGraph):
"""Implementation of SubHeterGraph in pgl.
SubHeterGraph is inherit from :code:`HeterGraph`.
Args:
num_nodes: number of nodes in a heterogeneous graph
edges: dict, every element in dict is a list of (u, v) tuples.
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of dict as edge features for every edge type
reindex: A dictionary that maps parent hetergraph node id to subhetergraph node id.
"""
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None,
reindex=None):
super(SubHeterGraph, self).__init__(
num_nodes=num_nodes,
edges=edges,
node_types=node_types,
node_feat=node_feat,
edge_feat=edge_feat)
if reindex is None:
reindex = {}
self._from_reindex = reindex
self._to_reindex = {u: v for v, u in reindex.items()}
def reindex_from_parrent_nodes(self, nodes):
"""Map the given parent graph node id to subgraph id.
Args:
nodes: A list of nodes from parent graph.
Return:
A list of subgraph ids.
"""
return graph_kernel.map_nodes(nodes, self._from_reindex)
def reindex_to_parrent_nodes(self, nodes):
"""Map the given subgraph node id to parent graph id.
Args:
nodes: A list of nodes in this subgraph.
Return:
A list of node ids in parent graph.
"""
return graph_kernel.map_nodes(nodes, self._to_reindex)
...@@ -256,43 +256,64 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0): ...@@ -256,43 +256,64 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0):
return walk return walk
def metapath_randomwalk(graph, start_node, metapath, walk_length): def metapath_randomwalk(graph,
start_nodes,
metapath,
walk_length,
alias_name=None,
events_name=None):
"""Implementation of metapath random walk in heterogeneous graph. """Implementation of metapath random walk in heterogeneous graph.
Args: Args:
graph: instance of pgl heterogeneous graph graph: instance of pgl heterogeneous graph
start_node: start node to generate walk start_nodes: start nodes to generate walk
metapath: meta path for sample nodes. metapath: meta path for sample nodes.
e.g: "user-item-user" e.g: "c2p-p2a-a2p-p2c"
walk_length: the walk length walk_length: the walk length
Return: Return:
a list of metapath walk, each element is a node id. a list of metapath walks.
""" """
np.random.seed()
edge_types = metapath.split('-')
walk = [] walk = []
metapath = metapath.split('-') for node in start_nodes:
assert metapath[0] == metapath[ walk.append([node])
-1], "The last meta path item should be the same as the first one"
mp_len = len(metapath) - 1 cur_walk_ids = np.arange(0, len(start_nodes))
cur_nodes = np.array(start_nodes)
walk.append(start_node) mp_len = len(edge_types)
for i in range(1, walk_length): for i in range(0, walk_length - 1):
cur_node = walk[-1] g = graph[edge_types[i % mp_len]]
succs = graph.successor(cur_node)
if succs.size > 0: cur_succs = g.successor(cur_nodes)
succs_node_types = graph._node_types[succs] mask = [len(succ) > 0 for succ in cur_succs]
if np.any(mask):
cur_walk_ids = cur_walk_ids[mask]
cur_nodes = cur_nodes[mask]
cur_succs = cur_succs[mask]
else: else:
# no successor of current node # stop when all nodes have no successor
break break
succs_nodes = succs[np.where(succs_node_types == metapath[i % mp_len])[ if alias_name is not None and events_name is not None:
0]] sample_index = [
if succs_nodes.size > 0: alias_sample([1], g.node_feat[alias_name][node],
walk.append(np.random.choice(succs_nodes)) g.node_feat[events_name][node])[0]
for node in cur_nodes
]
else: else:
# no successor of such node type outdegree = [len(cur_succ) for cur_succ in cur_succs]
break sample_index = np.floor(
np.random.rand(cur_succs.shape[0]) * outdegree).astype("int64")
nxt_cur_nodes = []
for s, ind, walk_id in zip(cur_succs, sample_index, cur_walk_ids):
walk[walk_id].append(s[ind])
nxt_cur_nodes.append(s[ind])
cur_nodes = np.array(nxt_cur_nodes)
return walk return walk
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册