提交 cfd11f59 编写于 作者: L liweibin

increase the sampling speed of metapath random walk

上级 0fb10359
......@@ -40,8 +40,10 @@ class Dataset(object):
def __init__(self, config):
self.config = config
self.walk_files = config['input_path'] + config['walk_path']
self.word2id_file = config['input_path'] + config['word2id_file']
self.walk_files = os.path.join(config['input_path'],
config['walk_path'])
self.word2id_file = os.path.join(config['input_path'],
config['word2id_file'])
self.word2freq = {}
self.word2id = {}
......@@ -65,7 +67,7 @@ class Dataset(object):
for walk_file in glob.glob(self.walk_files):
with open(walk_file, 'r') as reader:
for walk in reader:
walk = walk.strip().split(' ')
walk = walk.strip().split()
if len(walk) > 1:
self.sentences_count += 1
for word in walk:
......@@ -123,7 +125,7 @@ class Dataset(object):
for filename in walkpath_files:
with open(filename) as reader:
for line in reader:
words = line.strip().split(' ')
words = line.strip().split()
if len(words) > 1:
word_ids = [
self.word2id[w] for w in words if w in self.word2id
......
......@@ -13,9 +13,12 @@ sampler:
new_author_label_file: author_label.txt
new_venue_label_file: venue_label.txt
walk_saved_path: walks/
walk_batch_size: 1000
num_walks: 1000
walk_length: 100
metapath: conf-paper-author-paper-conf
num_sample_workers: 16
first_node_type: conf
metapath: c2p-p2a-a2p-p2c #conf-paper-author-paper-conf
optimizer:
type: Adam
......
......@@ -101,7 +101,7 @@ class SkipgramModel(object):
pos_score = fl.squeeze(pos_logits, axes=[1])
pos_score = fl.clip(pos_score, min=-10, max=10)
pos_score = -1.0 * fl.logsigmoid(pos_score)
pos_score = -self.neg_num * fl.logsigmoid(pos_score)
neg_logits = fl.matmul(
embed_src, weight_negs,
......@@ -111,4 +111,4 @@ class SkipgramModel(object):
neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score)
neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True)
self.loss = fl.reduce_mean(pos_score + neg_score)
self.loss = fl.reduce_mean(pos_score + neg_score) / self.neg_num / 2
......@@ -18,6 +18,7 @@ training metapath2vec model.
import multiprocessing
from multiprocessing import Pool
from multiprocessing import Process
import argparse
import sys
import os
......@@ -77,9 +78,14 @@ class Sampler(object):
self.config['data_path'] + 'paper_conf.txt', self.paper_id2index,
self.conf_id2index)
edges_by_types['edge'] = paper_author_edges + paper_conf_edges
logging.info('%d edges have been loaded.' %
(len(edges_by_types['edge'])))
# edges_by_types['edge'] = paper_author_edges + paper_conf_edges
edges_by_types['p2c'] = paper_conf_edges
edges_by_types['c2p'] = [(dst, src) for src, dst in paper_conf_edges]
edges_by_types['p2a'] = paper_author_edges
edges_by_types['a2p'] = [(dst, src) for src, dst in paper_author_edges]
# logging.info('%d edges have been loaded.' %
# (len(edges_by_types['edge'])))
node_features = {
'index': np.array([i for i in range(num_nodes)]).reshape(
......@@ -110,7 +116,7 @@ class Sampler(object):
return id2index, name2index, node_types
def load_edges(self, file_, src2index, dst2index, symmetry=True):
def load_edges(self, file_, src2index, dst2index, symmetry=False):
"""Load edges from file.
"""
edges = []
......@@ -143,41 +149,65 @@ class Sampler(object):
return index_label_list
def generate_walks(args):
"""Generate metapath random walk and save to file.
def walk_generator(graph, batch_size, metapath, n_type, walk_length):
"""Generate metapath random walk.
"""
g, meta_path, filename, walk_length = args
walks = []
node_types = g._node_types
first_type = meta_path.split('-')[0]
nodes = np.where(node_types == first_type)[0]
if len(nodes) > 4000:
nodes = np.random.choice(nodes, 4000, replace=False)
logging.info('%d number of start nodes' % (len(nodes)))
logging.info('save walks in file: %s' % (filename))
np.random.seed(os.getpid())
while True:
for start_nodes in graph.node_batch_iter(
batch_size=batch_size, n_type=n_type):
walks = metapath_randomwalk(
graph=graph,
start_nodes=start_nodes,
metapath=metapath,
walk_length=walk_length)
yield walks
def walk_to_files(g, batch_size, metapath, n_type, walk_length, max_num,
filename):
"""Generate metapath randomwalk and save in files"""
# g, batch_size, metapath, n_type, walk_length, max_num, filename = args
with open(filename, 'w') as writer:
for start_node in nodes:
walk = metapath_randomwalk(g, start_node, meta_path, walk_length)
walk = [str(walk[i]) for i in range(0, len(walk), 2)] # skip paper
writer.write(' '.join(walk) + '\n')
cc = 0
for walks in walk_generator(g, batch_size, metapath, n_type,
walk_length):
for walk in walks:
writer.write("%s\n" % "\t".join([str(i) for i in walk]))
cc += 1
if cc == max_num:
return
return
def multiprocess_generate_walks_to_files(graph, n_type, meta_path, num_walks,
walk_length, batch_size,
num_sample_workers, saved_path):
"""Use multiprocess to generate metapath random walk to files.
"""
num_nodes_by_type = graph.num_nodes_by_type(n_type)
logging.info("num_nodes_by_type: %s" % num_nodes_by_type)
max_num = (num_walks * num_nodes_by_type // num_sample_workers) + 1
logging.info("max sample number of every worker: %s" % max_num)
def multiprocess_generate_walks(sampler, edge_type, meta_path, num_walks,
walk_length, saved_path):
"""Use multiprocess to generate metapath random walk.
"""
args = []
for i in range(num_walks):
filename = saved_path + '%04d' % (i)
args.append(
(sampler.graph[edge_type], meta_path, filename, walk_length))
pool = Pool(16)
pool.map(generate_walks, args)
pool.close()
pool.join()
for i in range(num_sample_workers):
filename = os.path.join(saved_path, 'part-%05d' % (i))
args.append((graph, batch_size, meta_path, n_type, walk_length,
max_num, filename))
ps = []
for i in range(num_sample_workers):
p = Process(target=walk_to_files, args=args[i])
p.start()
ps.append(p)
for i in range(num_sample_workers):
ps[i].join()
# pool = Pool(num_sample_workers)
# pool.map(walk_to_files, args)
# pool.close()
# pool.join()
if __name__ == "__main__":
......@@ -220,13 +250,15 @@ if __name__ == "__main__":
begin = time.time()
logging.info('multi process sampling')
multiprocess_generate_walks(
sampler=sampler,
edge_type='edge',
multiprocess_generate_walks_to_files(
graph=sampler.graph,
n_type=config['first_node_type'],
meta_path=config['metapath'],
num_walks=config['num_walks'],
walk_length=config['walk_length'],
saved_path=config['walk_saved_path'])
batch_size=config['walk_batch_size'],
num_sample_workers=config['num_sample_workers'],
saved_path=config['walk_saved_path'], )
logging.info('total time: %.4f' % (time.time() - begin))
logging.info('generating multi class data')
......
......@@ -21,7 +21,7 @@ import time
import pgl.graph_kernel as graph_kernel
from pgl.graph import Graph
__all__ = ['HeterGraph']
__all__ = ['HeterGraph', 'SubHeterGraph']
def _hide_num_nodes(shape):
......@@ -32,31 +32,6 @@ def _hide_num_nodes(shape):
return shape
class NodeGraph(Graph):
"""Implementation of a graph that has multple node types.
Args:
num_nodes: number of nodes in the graph
edges: list of (u, v) tuples
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of numpy array as edge features
"""
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None):
super(NodeGraph, self).__init__(num_nodes, edges, node_feat, edge_feat)
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
class HeterGraph(object):
"""Implementation of heterogeneous graph structure in pgl
......@@ -102,6 +77,16 @@ class HeterGraph(object):
self._num_nodes = num_nodes
self._edges_dict = edges
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
self._nodes_type_dict = {}
for n_type in np.unique(self._node_types):
self._nodes_type_dict[n_type] = np.where(
self._node_types == n_type)[0]
if node_feat is not None:
self._node_feat = node_feat
else:
......@@ -113,30 +98,262 @@ class HeterGraph(object):
self._edge_feat = {}
self._multi_graph = {}
for key, value in self._edges_dict.items():
if not self._edge_feat:
edge_feat = None
else:
edge_feat = self._edge_feat[key]
self._multi_graph[key] = NodeGraph(
self._multi_graph[key] = Graph(
num_nodes=self._num_nodes,
edges=value,
node_types=node_types,
node_feat=self._node_feat,
edge_feat=edge_feat)
self._edge_types = self.edge_types_info()
@property
def edge_types(self):
"""Return a list of edge types.
"""
return self._edge_types
@property
def num_nodes(self):
"""Return the number of nodes.
"""
return self._num_nodes
@property
def num_edges(self):
"""Return edges number of all edge types.
"""
n_edges = {}
for e_type in self._edge_types:
n_edges[e_type] = self._multi_graph[e_type].num_edges
return n_edges
@property
def node_types(self):
"""Return the node types.
"""
return self._node_types
@property
def edge_feat(self, edge_type=None):
"""Return edge features of all edge types.
"""
return self._edge_feat
@property
def node_feat(self):
"""Return a dictionary of node features.
"""
return self._node_feat
@property
def nodes(self):
"""Return all nodes id from 0 to :code:`num_nodes - 1`
"""
return np.arange(self._num_nodes, dtype='int64')
def __getitem__(self, edge_type):
"""__getitem__
"""
return self._multi_graph[edge_type]
def num_nodes_by_type(self, n_type=None):
"""Return the number of nodes with the specified node type.
"""
if n_type not in self._nodes_type_dict:
raise ("%s is not in valid node type" % n_type)
else:
return len(self._nodes_type_dict[n_type])
def indegree(self, nodes=None, edge_type=None):
"""Return the indegree of the given nodes with the specified edge_type.
Args:
nodes: Return the indegree of given nodes.
if nodes is None, return indegree for all nodes.
edge_types: Return the indegree with specified edge_type.
if edge_type is None, return the total indegree of the given nodes.
Return:
A numpy.ndarray as the given nodes' indegree.
"""
if edge_type is None:
indegrees = []
for e_type in self._edge_types:
indegrees.append(self._multi_graph[e_type].indegree(nodes))
indegrees = np.sum(np.vstack(indegrees), axis=0)
return indegrees
else:
return self._multi_graph[edge_type].indegree(nodes)
def outdegree(self, nodes=None, edge_type=None):
"""Return the outdegree of the given nodes with the specified edge_type.
Args:
nodes: Return the outdegree of given nodes,
if nodes is None, return outdegree for all nodes
edge_types: Return the outdegree with specified edge_type.
if edge_type is None, return the total outdegree of the given nodes.
Return:
A numpy.array as the given nodes' outdegree.
"""
if edge_type is None:
outdegrees = []
for e_type in self._edge_types:
outdegrees.append(self._multi_graph[e_type].outdegree(nodes))
outdegrees = np.sum(np.vstack(outdegrees), axis=0)
return outdegrees
else:
return self._multi_graph[edge_type].outdegree(nodes)
def successor(self, edge_type, nodes=None, return_eids=False):
"""Find successor of given nodes with the specified edge_type.
Args:
nodes: Return the successor of given nodes,
if nodes is None, return successor for all nodes
edge_types: Return the successor with specified edge_type.
if edge_type is None, return the total successor of the given nodes
and eids are invalid in this way.
return_eids: If True return nodes together with corresponding eid
"""
return self._multi_graph[edge_type].successor(nodes, return_eids)
def sample_successor(self,
edge_type,
nodes,
max_degree,
return_eids=False,
shuffle=False):
"""Sample successors of given nodes with the specified edge_type.
Args:
edge_type: The specified edge_type.
nodes: Given nodes whose successors will be sampled.
max_degree: The max sampled successors for each nodes.
return_eids: Whether to return the corresponding eids.
Return:
Return a list of numpy.ndarray and each numpy.ndarray represent a list
of sampled successor ids for given nodes with specified edge type.
If :code:`return_eids=True`, there will be an additional list of
numpy.ndarray and each numpy.ndarray represent a list of eids that
connected nodes to their successors.
"""
return self._multi_graph[edge_type].sample_successor(
nodes=nodes,
max_degree=max_degree,
return_eids=return_eids,
shuffle=shuffle)
def predecessor(self, edge_type, nodes=None, return_eids=False):
"""Find predecessor of given nodes with the specified edge_type.
Args:
nodes: Return the predecessor of given nodes,
if nodes is None, return predecessor for all nodes
edge_types: Return the predecessor with specified edge_type.
return_eids: If True return nodes together with corresponding eid
"""
return self._multi_graph[edge_type].predecessor(nodes, return_eids)
def sample_predecessor(self,
edge_type,
nodes,
max_degree,
return_eids=False,
shuffle=False):
"""Sample predecessors of given nodes with the specified edge_type.
Args:
edge_type: The specified edge_type.
nodes: Given nodes whose predecessors will be sampled.
max_degree: The max sampled predecessors for each nodes.
return_eids: Whether to return the corresponding eids.
Return:
Return a list of numpy.ndarray and each numpy.ndarray represent a list
of sampled predecessor ids for given nodes with specified edge type.
If :code:`return_eids=True`, there will be an additional list of
numpy.ndarray and each numpy.ndarray represent a list of eids that
connected nodes to their predecessors.
"""
return self._multi_graph[edge_type].sample_predecessor(
nodes=nodes,
max_degree=max_degree,
return_eids=return_eids,
shuffle=shuffle)
def node_batch_iter(self, batch_size, shuffle=True, n_type=None):
"""Node batch iterator
Iterate all nodes by batch with the specified node type.
Args:
batch_size: The batch size of each batch of nodes.
shuffle: Whether shuffle the nodes.
n_type: Iterate the nodes with the specified node type. If n_type is None,
iterate all nodes by batch.
Return:
Batch iterator
"""
if n_type is None:
nodes = np.arange(self._num_nodes, dtype="int64")
else:
nodes = self._nodes_type_dict[n_type]
if shuffle:
np.random.shuffle(nodes)
start = 0
while start < len(nodes):
yield nodes[start:start + batch_size]
start += batch_size
def sample_nodes(self, sample_num, n_type=None):
"""Sample nodes with the specified n_type from the graph
This function helps to sample nodes with the specified n_type from the graph.
If n_type is None, this function will sample nodes from all nodes.
Nodes might be duplicated.
Args:
sample_num: The number of samples
n_type: The nodes of type to be sampled
Return:
A list of nodes
"""
if n_type is not None:
return np.random.choice(
self._nodes_type_dict[n_type], size=sample_num)
else:
return np.random.randint(
low=0, high=self._num_nodes, size=sample_num)
def node_feat_info(self):
"""Return the information of node feature for HeterGraphWrapper.
......@@ -186,3 +403,60 @@ class HeterGraph(object):
edge_types_info.append(key)
return edge_types_info
class SubHeterGraph(HeterGraph):
"""Implementation of SubHeterGraph in pgl.
SubHeterGraph is inherit from :code:`HeterGraph`.
Args:
num_nodes: number of nodes in a heterogeneous graph
edges: dict, every element in dict is a list of (u, v) tuples.
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of dict as edge features for every edge type
reindex: A dictionary that maps parent hetergraph node id to subhetergraph node id.
"""
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None,
reindex=None):
super(SubHeterGraph, self).__init__(
num_nodes=num_nodes,
edges=edges,
node_types=node_types,
node_feat=node_feat,
edge_feat=edge_feat)
if reindex is None:
reindex = {}
self._from_reindex = reindex
self._to_reindex = {u: v for v, u in reindex.items()}
def reindex_from_parrent_nodes(self, nodes):
"""Map the given parent graph node id to subgraph id.
Args:
nodes: A list of nodes from parent graph.
Return:
A list of subgraph ids.
"""
return graph_kernel.map_nodes(nodes, self._from_reindex)
def reindex_to_parrent_nodes(self, nodes):
"""Map the given subgraph node id to parent graph id.
Args:
nodes: A list of nodes in this subgraph.
Return:
A list of node ids in parent graph.
"""
return graph_kernel.map_nodes(nodes, self._to_reindex)
......@@ -256,43 +256,64 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0):
return walk
def metapath_randomwalk(graph, start_node, metapath, walk_length):
def metapath_randomwalk(graph,
start_nodes,
metapath,
walk_length,
alias_name=None,
events_name=None):
"""Implementation of metapath random walk in heterogeneous graph.
Args:
graph: instance of pgl heterogeneous graph
start_node: start node to generate walk
start_nodes: start nodes to generate walk
metapath: meta path for sample nodes.
e.g: "user-item-user"
e.g: "c2p-p2a-a2p-p2c"
walk_length: the walk length
Return:
a list of metapath walk, each element is a node id.
a list of metapath walks.
"""
np.random.seed()
edge_types = metapath.split('-')
walk = []
metapath = metapath.split('-')
assert metapath[0] == metapath[
-1], "The last meta path item should be the same as the first one"
mp_len = len(metapath) - 1
walk.append(start_node)
for i in range(1, walk_length):
cur_node = walk[-1]
succs = graph.successor(cur_node)
if succs.size > 0:
succs_node_types = graph._node_types[succs]
for node in start_nodes:
walk.append([node])
cur_walk_ids = np.arange(0, len(start_nodes))
cur_nodes = np.array(start_nodes)
mp_len = len(edge_types)
for i in range(0, walk_length - 1):
g = graph[edge_types[i % mp_len]]
cur_succs = g.successor(cur_nodes)
mask = [len(succ) > 0 for succ in cur_succs]
if np.any(mask):
cur_walk_ids = cur_walk_ids[mask]
cur_nodes = cur_nodes[mask]
cur_succs = cur_succs[mask]
else:
# no successor of current node
# stop when all nodes have no successor
break
succs_nodes = succs[np.where(succs_node_types == metapath[i % mp_len])[
0]]
if succs_nodes.size > 0:
walk.append(np.random.choice(succs_nodes))
if alias_name is not None and events_name is not None:
sample_index = [
alias_sample([1], g.node_feat[alias_name][node],
g.node_feat[events_name][node])[0]
for node in cur_nodes
]
else:
# no successor of such node type
break
outdegree = [len(cur_succ) for cur_succ in cur_succs]
sample_index = np.floor(
np.random.rand(cur_succs.shape[0]) * outdegree).astype("int64")
nxt_cur_nodes = []
for s, ind, walk_id in zip(cur_succs, sample_index, cur_walk_ids):
walk[walk_id].append(s[ind])
nxt_cur_nodes.append(s[ind])
cur_nodes = np.array(nxt_cur_nodes)
return walk
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册