未验证 提交 c87716c4 编写于 作者: H Huang Zhengjie 提交者: GitHub

Merge pull request #22 from Liwb5/develop

Develop
...@@ -195,7 +195,7 @@ def run_epoch(batch_iter, ...@@ -195,7 +195,7 @@ def run_epoch(batch_iter,
if num_trainer > 1: if num_trainer > 1:
num_samples = sum( num_samples = sum(
[len(batch["node_index"]) for batch in batch_feed_dict]) [len(_batch["node_index"]) for _batch in batch_feed_dict])
else: else:
num_samples = len(batch_feed_dict["node_index"]) num_samples = len(batch_feed_dict["node_index"])
total_loss += batch_loss * num_samples total_loss += batch_loss * num_samples
......
...@@ -40,8 +40,10 @@ class Dataset(object): ...@@ -40,8 +40,10 @@ class Dataset(object):
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.walk_files = config['input_path'] + config['walk_path'] self.walk_files = os.path.join(config['input_path'],
self.word2id_file = config['input_path'] + config['word2id_file'] config['walk_path'])
self.word2id_file = os.path.join(config['input_path'],
config['word2id_file'])
self.word2freq = {} self.word2freq = {}
self.word2id = {} self.word2id = {}
...@@ -65,7 +67,7 @@ class Dataset(object): ...@@ -65,7 +67,7 @@ class Dataset(object):
for walk_file in glob.glob(self.walk_files): for walk_file in glob.glob(self.walk_files):
with open(walk_file, 'r') as reader: with open(walk_file, 'r') as reader:
for walk in reader: for walk in reader:
walk = walk.strip().split(' ') walk = walk.strip().split()
if len(walk) > 1: if len(walk) > 1:
self.sentences_count += 1 self.sentences_count += 1
for word in walk: for word in walk:
...@@ -123,7 +125,7 @@ class Dataset(object): ...@@ -123,7 +125,7 @@ class Dataset(object):
for filename in walkpath_files: for filename in walkpath_files:
with open(filename) as reader: with open(filename) as reader:
for line in reader: for line in reader:
words = line.strip().split(' ') words = line.strip().split()
if len(words) > 1: if len(words) > 1:
word_ids = [ word_ids = [
self.word2id[w] for w in words if w in self.word2id self.word2id[w] for w in words if w in self.word2id
......
...@@ -13,9 +13,12 @@ sampler: ...@@ -13,9 +13,12 @@ sampler:
new_author_label_file: author_label.txt new_author_label_file: author_label.txt
new_venue_label_file: venue_label.txt new_venue_label_file: venue_label.txt
walk_saved_path: walks/ walk_saved_path: walks/
walk_batch_size: 1000
num_walks: 1000 num_walks: 1000
walk_length: 100 walk_length: 100
metapath: conf-paper-author-paper-conf num_sample_workers: 16
first_node_type: conf
metapath: c2p-p2a-a2p-p2c #conf-paper-author-paper-conf
optimizer: optimizer:
type: Adam type: Adam
......
...@@ -101,7 +101,7 @@ class SkipgramModel(object): ...@@ -101,7 +101,7 @@ class SkipgramModel(object):
pos_score = fl.squeeze(pos_logits, axes=[1]) pos_score = fl.squeeze(pos_logits, axes=[1])
pos_score = fl.clip(pos_score, min=-10, max=10) pos_score = fl.clip(pos_score, min=-10, max=10)
pos_score = -1.0 * fl.logsigmoid(pos_score) pos_score = -self.neg_num * fl.logsigmoid(pos_score)
neg_logits = fl.matmul( neg_logits = fl.matmul(
embed_src, weight_negs, embed_src, weight_negs,
...@@ -111,4 +111,4 @@ class SkipgramModel(object): ...@@ -111,4 +111,4 @@ class SkipgramModel(object):
neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score) neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score)
neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True) neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True)
self.loss = fl.reduce_mean(pos_score + neg_score) self.loss = fl.reduce_mean(pos_score + neg_score) / self.neg_num / 2
...@@ -18,6 +18,7 @@ training metapath2vec model. ...@@ -18,6 +18,7 @@ training metapath2vec model.
import multiprocessing import multiprocessing
from multiprocessing import Pool from multiprocessing import Pool
from multiprocessing import Process
import argparse import argparse
import sys import sys
import os import os
...@@ -77,9 +78,14 @@ class Sampler(object): ...@@ -77,9 +78,14 @@ class Sampler(object):
self.config['data_path'] + 'paper_conf.txt', self.paper_id2index, self.config['data_path'] + 'paper_conf.txt', self.paper_id2index,
self.conf_id2index) self.conf_id2index)
edges_by_types['edge'] = paper_author_edges + paper_conf_edges # edges_by_types['edge'] = paper_author_edges + paper_conf_edges
logging.info('%d edges have been loaded.' % edges_by_types['p2c'] = paper_conf_edges
(len(edges_by_types['edge']))) edges_by_types['c2p'] = [(dst, src) for src, dst in paper_conf_edges]
edges_by_types['p2a'] = paper_author_edges
edges_by_types['a2p'] = [(dst, src) for src, dst in paper_author_edges]
# logging.info('%d edges have been loaded.' %
# (len(edges_by_types['edge'])))
node_features = { node_features = {
'index': np.array([i for i in range(num_nodes)]).reshape( 'index': np.array([i for i in range(num_nodes)]).reshape(
...@@ -110,7 +116,7 @@ class Sampler(object): ...@@ -110,7 +116,7 @@ class Sampler(object):
return id2index, name2index, node_types return id2index, name2index, node_types
def load_edges(self, file_, src2index, dst2index, symmetry=True): def load_edges(self, file_, src2index, dst2index, symmetry=False):
"""Load edges from file. """Load edges from file.
""" """
edges = [] edges = []
...@@ -143,41 +149,65 @@ class Sampler(object): ...@@ -143,41 +149,65 @@ class Sampler(object):
return index_label_list return index_label_list
def generate_walks(args): def walk_generator(graph, batch_size, metapath, n_type, walk_length):
"""Generate metapath random walk and save to file. """Generate metapath random walk.
""" """
g, meta_path, filename, walk_length = args np.random.seed(os.getpid())
walks = [] while True:
node_types = g._node_types for start_nodes in graph.node_batch_iter(
first_type = meta_path.split('-')[0] batch_size=batch_size, n_type=n_type):
nodes = np.where(node_types == first_type)[0] walks = metapath_randomwalk(
if len(nodes) > 4000: graph=graph,
nodes = np.random.choice(nodes, 4000, replace=False) start_nodes=start_nodes,
metapath=metapath,
logging.info('%d number of start nodes' % (len(nodes))) walk_length=walk_length)
logging.info('save walks in file: %s' % (filename)) yield walks
def walk_to_files(g, batch_size, metapath, n_type, walk_length, max_num,
filename):
"""Generate metapath randomwalk and save in files"""
# g, batch_size, metapath, n_type, walk_length, max_num, filename = args
with open(filename, 'w') as writer: with open(filename, 'w') as writer:
for start_node in nodes: cc = 0
walk = metapath_randomwalk(g, start_node, meta_path, walk_length) for walks in walk_generator(g, batch_size, metapath, n_type,
walk = [str(walk[i]) for i in range(0, len(walk), 2)] # skip paper walk_length):
writer.write(' '.join(walk) + '\n') for walk in walks:
writer.write("%s\n" % "\t".join([str(i) for i in walk]))
cc += 1
if cc == max_num:
return
return
def multiprocess_generate_walks_to_files(graph, n_type, meta_path, num_walks,
walk_length, batch_size,
num_sample_workers, saved_path):
"""Use multiprocess to generate metapath random walk to files.
"""
num_nodes_by_type = graph.num_nodes_by_type(n_type)
logging.info("num_nodes_by_type: %s" % num_nodes_by_type)
max_num = (num_walks * num_nodes_by_type // num_sample_workers) + 1
logging.info("max sample number of every worker: %s" % max_num)
def multiprocess_generate_walks(sampler, edge_type, meta_path, num_walks,
walk_length, saved_path):
"""Use multiprocess to generate metapath random walk.
"""
args = [] args = []
for i in range(num_walks): for i in range(num_sample_workers):
filename = saved_path + '%04d' % (i) filename = os.path.join(saved_path, 'part-%05d' % (i))
args.append( args.append((graph, batch_size, meta_path, n_type, walk_length,
(sampler.graph[edge_type], meta_path, filename, walk_length)) max_num, filename))
pool = Pool(16) ps = []
pool.map(generate_walks, args) for i in range(num_sample_workers):
pool.close() p = Process(target=walk_to_files, args=args[i])
pool.join() p.start()
ps.append(p)
for i in range(num_sample_workers):
ps[i].join()
# pool = Pool(num_sample_workers)
# pool.map(walk_to_files, args)
# pool.close()
# pool.join()
if __name__ == "__main__": if __name__ == "__main__":
...@@ -220,13 +250,15 @@ if __name__ == "__main__": ...@@ -220,13 +250,15 @@ if __name__ == "__main__":
begin = time.time() begin = time.time()
logging.info('multi process sampling') logging.info('multi process sampling')
multiprocess_generate_walks( multiprocess_generate_walks_to_files(
sampler=sampler, graph=sampler.graph,
edge_type='edge', n_type=config['first_node_type'],
meta_path=config['metapath'], meta_path=config['metapath'],
num_walks=config['num_walks'], num_walks=config['num_walks'],
walk_length=config['walk_length'], walk_length=config['walk_length'],
saved_path=config['walk_saved_path']) batch_size=config['walk_batch_size'],
num_sample_workers=config['num_sample_workers'],
saved_path=config['walk_saved_path'], )
logging.info('total time: %.4f' % (time.time() - begin)) logging.info('total time: %.4f' % (time.time() - begin))
logging.info('generating multi class data') logging.info('generating multi class data')
......
...@@ -21,7 +21,7 @@ import time ...@@ -21,7 +21,7 @@ import time
import pgl.graph_kernel as graph_kernel import pgl.graph_kernel as graph_kernel
from pgl.graph import Graph from pgl.graph import Graph
__all__ = ['HeterGraph'] __all__ = ['HeterGraph', 'SubHeterGraph']
def _hide_num_nodes(shape): def _hide_num_nodes(shape):
...@@ -32,31 +32,6 @@ def _hide_num_nodes(shape): ...@@ -32,31 +32,6 @@ def _hide_num_nodes(shape):
return shape return shape
class NodeGraph(Graph):
"""Implementation of a graph that has multple node types.
Args:
num_nodes: number of nodes in the graph
edges: list of (u, v) tuples
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of numpy array as edge features
"""
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None):
super(NodeGraph, self).__init__(num_nodes, edges, node_feat, edge_feat)
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
class HeterGraph(object): class HeterGraph(object):
"""Implementation of heterogeneous graph structure in pgl """Implementation of heterogeneous graph structure in pgl
...@@ -102,6 +77,16 @@ class HeterGraph(object): ...@@ -102,6 +77,16 @@ class HeterGraph(object):
self._num_nodes = num_nodes self._num_nodes = num_nodes
self._edges_dict = edges self._edges_dict = edges
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
self._nodes_type_dict = {}
for n_type in np.unique(self._node_types):
self._nodes_type_dict[n_type] = np.where(
self._node_types == n_type)[0]
if node_feat is not None: if node_feat is not None:
self._node_feat = node_feat self._node_feat = node_feat
else: else:
...@@ -113,30 +98,262 @@ class HeterGraph(object): ...@@ -113,30 +98,262 @@ class HeterGraph(object):
self._edge_feat = {} self._edge_feat = {}
self._multi_graph = {} self._multi_graph = {}
for key, value in self._edges_dict.items(): for key, value in self._edges_dict.items():
if not self._edge_feat: if not self._edge_feat:
edge_feat = None edge_feat = None
else: else:
edge_feat = self._edge_feat[key] edge_feat = self._edge_feat[key]
self._multi_graph[key] = NodeGraph( self._multi_graph[key] = Graph(
num_nodes=self._num_nodes, num_nodes=self._num_nodes,
edges=value, edges=value,
node_types=node_types,
node_feat=self._node_feat, node_feat=self._node_feat,
edge_feat=edge_feat) edge_feat=edge_feat)
self._edge_types = self.edge_types_info()
@property
def edge_types(self):
"""Return a list of edge types.
"""
return self._edge_types
@property @property
def num_nodes(self): def num_nodes(self):
"""Return the number of nodes. """Return the number of nodes.
""" """
return self._num_nodes return self._num_nodes
@property
def num_edges(self):
"""Return edges number of all edge types.
"""
n_edges = {}
for e_type in self._edge_types:
n_edges[e_type] = self._multi_graph[e_type].num_edges
return n_edges
@property
def node_types(self):
"""Return the node types.
"""
return self._node_types
@property
def edge_feat(self, edge_type=None):
"""Return edge features of all edge types.
"""
return self._edge_feat
@property
def node_feat(self):
"""Return a dictionary of node features.
"""
return self._node_feat
@property
def nodes(self):
"""Return all nodes id from 0 to :code:`num_nodes - 1`
"""
return np.arange(self._num_nodes, dtype='int64')
def __getitem__(self, edge_type): def __getitem__(self, edge_type):
"""__getitem__ """__getitem__
""" """
return self._multi_graph[edge_type] return self._multi_graph[edge_type]
def num_nodes_by_type(self, n_type=None):
"""Return the number of nodes with the specified node type.
"""
if n_type not in self._nodes_type_dict:
raise ("%s is not in valid node type" % n_type)
else:
return len(self._nodes_type_dict[n_type])
def indegree(self, nodes=None, edge_type=None):
"""Return the indegree of the given nodes with the specified edge_type.
Args:
nodes: Return the indegree of given nodes.
if nodes is None, return indegree for all nodes.
edge_types: Return the indegree with specified edge_type.
if edge_type is None, return the total indegree of the given nodes.
Return:
A numpy.ndarray as the given nodes' indegree.
"""
if edge_type is None:
indegrees = []
for e_type in self._edge_types:
indegrees.append(self._multi_graph[e_type].indegree(nodes))
indegrees = np.sum(np.vstack(indegrees), axis=0)
return indegrees
else:
return self._multi_graph[edge_type].indegree(nodes)
def outdegree(self, nodes=None, edge_type=None):
"""Return the outdegree of the given nodes with the specified edge_type.
Args:
nodes: Return the outdegree of given nodes,
if nodes is None, return outdegree for all nodes
edge_types: Return the outdegree with specified edge_type.
if edge_type is None, return the total outdegree of the given nodes.
Return:
A numpy.array as the given nodes' outdegree.
"""
if edge_type is None:
outdegrees = []
for e_type in self._edge_types:
outdegrees.append(self._multi_graph[e_type].outdegree(nodes))
outdegrees = np.sum(np.vstack(outdegrees), axis=0)
return outdegrees
else:
return self._multi_graph[edge_type].outdegree(nodes)
def successor(self, edge_type, nodes=None, return_eids=False):
"""Find successor of given nodes with the specified edge_type.
Args:
nodes: Return the successor of given nodes,
if nodes is None, return successor for all nodes
edge_types: Return the successor with specified edge_type.
if edge_type is None, return the total successor of the given nodes
and eids are invalid in this way.
return_eids: If True return nodes together with corresponding eid
"""
return self._multi_graph[edge_type].successor(nodes, return_eids)
def sample_successor(self,
edge_type,
nodes,
max_degree,
return_eids=False,
shuffle=False):
"""Sample successors of given nodes with the specified edge_type.
Args:
edge_type: The specified edge_type.
nodes: Given nodes whose successors will be sampled.
max_degree: The max sampled successors for each nodes.
return_eids: Whether to return the corresponding eids.
Return:
Return a list of numpy.ndarray and each numpy.ndarray represent a list
of sampled successor ids for given nodes with specified edge type.
If :code:`return_eids=True`, there will be an additional list of
numpy.ndarray and each numpy.ndarray represent a list of eids that
connected nodes to their successors.
"""
return self._multi_graph[edge_type].sample_successor(
nodes=nodes,
max_degree=max_degree,
return_eids=return_eids,
shuffle=shuffle)
def predecessor(self, edge_type, nodes=None, return_eids=False):
"""Find predecessor of given nodes with the specified edge_type.
Args:
nodes: Return the predecessor of given nodes,
if nodes is None, return predecessor for all nodes
edge_types: Return the predecessor with specified edge_type.
return_eids: If True return nodes together with corresponding eid
"""
return self._multi_graph[edge_type].predecessor(nodes, return_eids)
def sample_predecessor(self,
edge_type,
nodes,
max_degree,
return_eids=False,
shuffle=False):
"""Sample predecessors of given nodes with the specified edge_type.
Args:
edge_type: The specified edge_type.
nodes: Given nodes whose predecessors will be sampled.
max_degree: The max sampled predecessors for each nodes.
return_eids: Whether to return the corresponding eids.
Return:
Return a list of numpy.ndarray and each numpy.ndarray represent a list
of sampled predecessor ids for given nodes with specified edge type.
If :code:`return_eids=True`, there will be an additional list of
numpy.ndarray and each numpy.ndarray represent a list of eids that
connected nodes to their predecessors.
"""
return self._multi_graph[edge_type].sample_predecessor(
nodes=nodes,
max_degree=max_degree,
return_eids=return_eids,
shuffle=shuffle)
def node_batch_iter(self, batch_size, shuffle=True, n_type=None):
"""Node batch iterator
Iterate all nodes by batch with the specified node type.
Args:
batch_size: The batch size of each batch of nodes.
shuffle: Whether shuffle the nodes.
n_type: Iterate the nodes with the specified node type. If n_type is None,
iterate all nodes by batch.
Return:
Batch iterator
"""
if n_type is None:
nodes = np.arange(self._num_nodes, dtype="int64")
else:
nodes = self._nodes_type_dict[n_type]
if shuffle:
np.random.shuffle(nodes)
start = 0
while start < len(nodes):
yield nodes[start:start + batch_size]
start += batch_size
def sample_nodes(self, sample_num, n_type=None):
"""Sample nodes with the specified n_type from the graph
This function helps to sample nodes with the specified n_type from the graph.
If n_type is None, this function will sample nodes from all nodes.
Nodes might be duplicated.
Args:
sample_num: The number of samples
n_type: The nodes of type to be sampled
Return:
A list of nodes
"""
if n_type is not None:
return np.random.choice(
self._nodes_type_dict[n_type], size=sample_num)
else:
return np.random.randint(
low=0, high=self._num_nodes, size=sample_num)
def node_feat_info(self): def node_feat_info(self):
"""Return the information of node feature for HeterGraphWrapper. """Return the information of node feature for HeterGraphWrapper.
...@@ -186,3 +403,60 @@ class HeterGraph(object): ...@@ -186,3 +403,60 @@ class HeterGraph(object):
edge_types_info.append(key) edge_types_info.append(key)
return edge_types_info return edge_types_info
class SubHeterGraph(HeterGraph):
"""Implementation of SubHeterGraph in pgl.
SubHeterGraph is inherit from :code:`HeterGraph`.
Args:
num_nodes: number of nodes in a heterogeneous graph
edges: dict, every element in dict is a list of (u, v) tuples.
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of dict as edge features for every edge type
reindex: A dictionary that maps parent hetergraph node id to subhetergraph node id.
"""
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None,
reindex=None):
super(SubHeterGraph, self).__init__(
num_nodes=num_nodes,
edges=edges,
node_types=node_types,
node_feat=node_feat,
edge_feat=edge_feat)
if reindex is None:
reindex = {}
self._from_reindex = reindex
self._to_reindex = {u: v for v, u in reindex.items()}
def reindex_from_parrent_nodes(self, nodes):
"""Map the given parent graph node id to subgraph id.
Args:
nodes: A list of nodes from parent graph.
Return:
A list of subgraph ids.
"""
return graph_kernel.map_nodes(nodes, self._from_reindex)
def reindex_to_parrent_nodes(self, nodes):
"""Map the given subgraph node id to parent graph id.
Args:
nodes: A list of nodes in this subgraph.
Return:
A list of node ids in parent graph.
"""
return graph_kernel.map_nodes(nodes, self._to_reindex)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""redis_hetergraph"""
import pgl
import redis
from redis import BlockingConnectionPool, StrictRedis
from redis._compat import b, unicode, bytes, long, basestring
from rediscluster.nodemanager import NodeManager
from rediscluster.crc import crc16
from collections import OrderedDict
import threading
import numpy as np
import time
import json
import pgl.graph as pgraph
import pickle as pkl
from pgl.utils.logger import log
import pgl.graph_kernel as graph_kernel
from pgl.contrib import heter_graph
import pgl.redis_graph as rg
class RedisHeterGraph(rg.RedisGraph):
"""Redis Heterogeneous Graph"""
def __init__(self, name, edge_types, redis_config, num_parts):
super(RedisHeterGraph, self).__init__(name, redis_config, num_parts)
self._num_edges = {}
self.edge_types = edge_types
self.e_type = None
self._edge_feat_info = {}
self._edge_feat_dtype = {}
self._edge_feat_shape = {}
def num_edges_by_type(self, e_type):
"""get edge number by specified edge type"""
if e_type not in self._num_edges:
self._num_edges[e_type] = int(
self._rs.get("%s:num_edges" % e_type))
return self._num_edges[e_type]
def num_edges(self):
"""num_edges"""
num_edges = {}
for e_type in self.edge_types:
num_edges[e_type] = self.num_edges_by_type(e_type)
return num_edges
def edge_feat_info_by_type(self, e_type):
"""get edge features information by specified edge type"""
if e_type not in self._edge_feat_info:
buff = self._rs.get("%s:ef:infos" % e_type)
if buff is not None:
self._edge_feat_info[e_type] = json.loads(buff.decode())
else:
self._edge_feat_info[e_type] = []
return self._edge_feat_info[e_type]
def edge_feat_info(self):
"""edge_feat_info"""
edge_feat_info = {}
for e_type in self.edge_types:
edge_feat_info[e_type] = self.edge_feat_info_by_type(e_type)
return edge_feat_info
def edge_feat_shape(self, e_type, key):
"""edge_feat_shape"""
if e_type not in self._edge_feat_shape:
e_feat_shape = {}
for k, shape, _ in self.edge_feat_info()[e_type]:
e_feat_shape[k] = shape
self._edge_feat_shape[e_type] = e_feat_shape
return self._edge_feat_shape[e_type][key]
def edge_feat_dtype(self, e_type, key):
"""edge_feat_dtype"""
if e_type not in self._edge_feat_dtype:
e_feat_dtype = {}
for k, _, dtype in self.edge_feat_info()[e_type]:
e_feat_dtype[k] = dtype
self._edge_feat_dtype[e_type] = e_feat_dtype
return self._edge_feat_dtype[e_type][key]
def sample_predecessor(self, e_type, nodes, max_degree, return_eids=False):
"""sample predecessor with the specified edge type"""
query = ["%s:d:%s" % (e_type, n) for n in nodes]
rets = rg.hmget_sample_helper(self._rs, query, self.num_parts,
max_degree)
v = []
eid = []
for buff in rets:
if buff is None:
v.append(np.array([], dtype="int64"))
eid.append(np.array([], dtype="int64"))
else:
npret = np.frombuffer(
buff, dtype="int64").reshape([-1, 2]).astype("int64")
v.append(npret[:, 0])
eid.append(npret[:, 1])
if return_eids:
return np.array(v), np.array(eid)
else:
return np.array(v)
def sample_successor(self, e_type, nodes, max_degree, return_eids=False):
"""sample successor with the specified edge type"""
query = ["%s:s:%s" % (e_type, n) for n in nodes]
rets = rg.hmget_sample_helper(self._rs, query, self.num_parts,
max_degree)
v = []
eid = []
for buff in rets:
if buff is None:
v.append(np.array([], dtype="int64"))
eid.append(np.array([], dtype="int64"))
else:
npret = np.frombuffer(
buff, dtype="int64").reshape([-1, 2]).astype("int64")
v.append(npret[:, 0])
eid.append(npret[:, 1])
if return_eids:
return np.array(v), np.array(eid)
else:
return np.array(v)
def predecessor(self, e_type, nodes, return_eids=False):
"""predecessor with the specified edge type"""
query = ["%s:d:%s" % (e_type, n) for n in nodes]
ret = rg.hmget_helper(self._rs, query, self.num_parts)
v = []
eid = []
for buff in ret:
if buff is not None:
npret = np.frombuffer(
buff, dtype="int64").reshape([-1, 2]).astype("int64")
v.append(npret[:, 0])
eid.append(npret[:, 1])
else:
v.append(np.array([], dtype="int64"))
eid.append(np.array([], dtype="int64"))
if return_eids:
return np.array(v), np.array(eid)
else:
return np.array(v)
def successor(self, e_type, nodes, return_eids=False):
"""successor with the specified edge type"""
query = ["%s:s:%s" % (e_type, n) for n in nodes]
ret = rg.hmget_helper(self._rs, query, self.num_parts)
v = []
eid = []
for buff in ret:
if buff is not None:
npret = np.frombuffer(
buff, dtype="int64").reshape([-1, 2]).astype("int64")
v.append(npret[:, 0])
eid.append(npret[:, 1])
else:
v.append(np.array([], dtype="int64"))
eid.append(np.array([], dtype="int64"))
if return_eids:
return np.array(v), np.array(eid)
else:
return np.array(v)
def get_edges_by_id(self, e_type, eids):
"""get_edges_by_id"""
queries = ["%s:e:%s" % (e_type, e) for e in eids]
ret = rg.hmget_helper(self._rs, queries, self.num_parts)
o = np.asarray(ret, dtype="int64")
dst = o % self.num_nodes
src = o // self.num_nodes
data = np.hstack(
[src.reshape([-1, 1]), dst.reshape([-1, 1])]).astype("int64")
return data
def get_edge_feat_by_id(self, e_type, key, eids):
"""get_edge_feat_by_id"""
queries = ["%s:ef:%s:%i" % (e_type, key, e) for e in eids]
ret = rg.hmget_helper(self._rs, queries, self.num_parts)
if ret is None:
return None
else:
ret = b"".join(ret)
data = np.frombuffer(ret, dtype=self.edge_feat_dtype(e_type, key))
data = data.reshape(self.edge_feat_shape(e_type, key))
return data
def get_node_types(self, nodes):
"""get_node_types """
queries = ["nt:%i" % n for n in nodes]
ret = rg.hmget_helper(self._rs, queries, self.num_parts)
node_types = []
for buff in ret:
if buff:
node_types.append(buff.decode())
else:
node_types = None
return node_types
def subgraph(self, nodes, eid, edges=None):
"""Generate heterogeneous subgraph with nodes and edge ids.
WARNING: ALL NODES IN EID MUST BE INCLUDED BY NODES
Args:
nodes: Node ids which will be included in the subgraph.
eid: Edge ids which will be included in the subgraph.
Return:
A :code:`pgl.heter_graph.Subgraph` object.
"""
reindex = {}
for ind, node in enumerate(nodes):
reindex[node] = ind
_node_types = self.get_node_types(nodes)
if _node_types is None:
node_types = None
else:
node_types = []
for idx, t in zip(nodes, _node_types):
node_types.append([reindex[idx], t])
if edges is None:
edges = {}
for e_type, eid_list in eid.items():
edges[e_type] = self.get_edges_by_id(e_type, eid_list)
sub_edges = {}
for e_type, edges_list in edges.items():
sub_edges[e_type] = graph_kernel.map_edges(
np.arange(
len(edges_list), dtype="int64"), edges_list, reindex)
sub_edge_feat = {}
for e_type, edge_feat_info in self.edge_feat_info().items():
type_edge_feat = {}
for key, _, _ in edge_feat_info:
type_edge_feat[key] = self.get_edge_feat_by_id(e_type, key,
eid)
sub_edge_feat[e_type] = type_edge_feat
sub_node_feat = {}
for key, _, _ in self.node_feat_info():
sub_node_feat[key] = self.get_node_feat_by_id(key, nodes)
subgraph = heter_graph.SubHeterGraph(
num_nodes=len(nodes),
edges=sub_edges,
node_types=node_types,
node_feat=sub_node_feat,
edge_feat=sub_edge_feat,
reindex=reindex)
return subgraph
...@@ -43,8 +43,8 @@ class EdgeIndex(object): ...@@ -43,8 +43,8 @@ class EdgeIndex(object):
""" """
def __init__(self, u, v, num_nodes): def __init__(self, u, v, num_nodes):
self._v, self._eid, self._degree, self._sorted_u,\ self._degree, self._sorted_v, self._sorted_u, \
self._sorted_v, self._sorted_eid = graph_kernel.build_index(u, v, num_nodes) self._sorted_eid, self._indptr = graph_kernel.build_index(u, v, num_nodes)
@property @property
def degree(self): def degree(self):
...@@ -52,17 +52,25 @@ class EdgeIndex(object): ...@@ -52,17 +52,25 @@ class EdgeIndex(object):
""" """
return self._degree return self._degree
@property def view_v(self, u=None):
def v(self): """Return the compressed v for given u.
"""Return the compressed v.
""" """
return self._v if u is None:
return np.split(self._sorted_v, self._indptr[1:])
else:
u = np.array(u, dtype="int64")
return graph_kernel.slice_by_index(
self._sorted_v, self._indptr, index=u)
@property def view_eid(self, u=None):
def eid(self): """Return the compressed edge id for given u.
"""Return the edge id.
""" """
return self._eid if u is None:
return np.split(self._sorted_eid, self._indptr[1:])
else:
u = np.array(u, dtype="int64")
return graph_kernel.slice_by_index(
self._sorted_eid, self._indptr, index=u)
def triples(self): def triples(self):
"""Return the sorted (u, v, eid) tuples. """Return the sorted (u, v, eid) tuples.
...@@ -287,17 +295,11 @@ class Graph(object): ...@@ -287,17 +295,11 @@ class Graph(object):
[]] []]
""" """
if nodes is None: if return_eids:
if return_eids: return self.adj_src_index.view_v(
return self.adj_src_index.v, self.adj_src_index.eid nodes), self.adj_src_index.view_eid(nodes)
else:
return self.adj_src_index.v
else: else:
if return_eids: return self.adj_src_index.view_v(nodes)
return self.adj_src_index.v[nodes], self.adj_src_index.eid[
nodes]
else:
return self.adj_src_index.v[nodes]
def sample_successor(self, def sample_successor(self,
nodes, nodes,
...@@ -385,17 +387,11 @@ class Graph(object): ...@@ -385,17 +387,11 @@ class Graph(object):
[2]] [2]]
""" """
if nodes is None: if return_eids:
if return_eids: return self.adj_dst_index.view_v(
return self.adj_dst_index.v, self.adj_dst_index.eid nodes), self.adj_dst_index.view_eid(nodes)
else:
return self.adj_dst_index.v
else: else:
if return_eids: return self.adj_dst_index.view_v(nodes)
return self.adj_dst_index.v[nodes], self.adj_dst_index.eid[
nodes]
else:
return self.adj_dst_index.v[nodes]
def sample_predecessor(self, def sample_predecessor(self,
nodes, nodes,
......
...@@ -53,14 +53,21 @@ def build_index(np.ndarray[np.int64_t, ndim=1] u, ...@@ -53,14 +53,21 @@ def build_index(np.ndarray[np.int64_t, ndim=1] u,
_tmp_eid[indptr[u[i]] + count[u[i]]] = i _tmp_eid[indptr[u[i]] + count[u[i]]] = i
_tmp_u[indptr[u[i]] + count[u[i]]] = u[i] _tmp_u[indptr[u[i]] + count[u[i]]] = u[i]
count[u[i]] += 1 count[u[i]] += 1
return degree, _tmp_v, _tmp_u, _tmp_eid, indptr
cdef list output_eid = [] @cython.boundscheck(False)
cdef list output_v = [] @cython.wraparound(False)
for i in xrange(n_size): def slice_by_index(np.ndarray[np.int64_t, ndim=1] u,
output_eid.append(_tmp_eid[indptr[i]:indptr[i+1]]) np.ndarray[np.int64_t, ndim=1] indptr,
output_v.append(_tmp_v[indptr[i]:indptr[i+1]]) np.ndarray[np.int64_t, ndim=1] index):
return np.array(output_v), np.array(output_eid), degree, _tmp_u, _tmp_v, _tmp_eid cdef list output = []
cdef long long i
cdef long long h = len(index)
cdef long long j
for i in xrange(h):
j = index[i]
output.append(u[indptr[j]:indptr[j+1]])
return np.array(output)
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.wraparound(False) @cython.wraparound(False)
...@@ -253,22 +260,10 @@ def sample_subset_with_eid(list nids, list eids, long long maxdegree, shuffle=Fa ...@@ -253,22 +260,10 @@ def sample_subset_with_eid(list nids, list eids, long long maxdegree, shuffle=Fa
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.wraparound(False) @cython.wraparound(False)
def skip_gram_gen_pair(vector[long long] walk_path, long win_size=5): def skip_gram_gen_pair(vector[long long] walk, long win_size=5):
"""Return node paris generated by skip-gram algorithm.
This function will auto remove the pair which src node is the same
as dst node.
Args:
walk_path: List of nodes as a walk path.
win_size: the windows size used in skip-gram.
Return:
A tuple of (src node list, dst node list).
"""
cdef vector[long long] src cdef vector[long long] src
cdef vector[long long] dst cdef vector[long long] dst
cdef long long l = len(walk_path) cdef long long l = len(walk)
cdef long long real_win_size, left, right, i cdef long long real_win_size, left, right, i
cdef np.ndarray[np.int64_t, ndim=1] rnd = np.random.randint(1, win_size+1, cdef np.ndarray[np.int64_t, ndim=1] rnd = np.random.randint(1, win_size+1,
dtype=np.int64, size=l) dtype=np.int64, size=l)
...@@ -282,23 +277,15 @@ def skip_gram_gen_pair(vector[long long] walk_path, long win_size=5): ...@@ -282,23 +277,15 @@ def skip_gram_gen_pair(vector[long long] walk_path, long win_size=5):
if right >= l: if right >= l:
right = l - 1 right = l - 1
for j in xrange(left, right+1): for j in xrange(left, right+1):
if walk_path[i] == walk_path[j]: if walk[i] == walk[j]:
continue continue
src.push_back(walk_path[i]) src.push_back(walk[i])
dst.push_back(walk_path[j]) dst.push_back(walk[j])
return src, dst return src, dst
@cython.boundscheck(False) @cython.boundscheck(False)
@cython.wraparound(False) @cython.wraparound(False)
def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs): def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs):
"""Return the alias table and event table for alias sampling.
Args:
porobs: A list of float numbers as the probability.
Return:
A tuple of (alias table, event table).
"""
cdef long long l = len(probs) cdef long long l = len(probs)
cdef np.ndarray[np.float64_t, ndim=1] alias = probs * l cdef np.ndarray[np.float64_t, ndim=1] alias = probs * l
cdef np.ndarray[np.int64_t, ndim=1] events = np.zeros(l, dtype=np.int64) cdef np.ndarray[np.int64_t, ndim=1] events = np.zeros(l, dtype=np.int64)
......
...@@ -256,43 +256,64 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0): ...@@ -256,43 +256,64 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0):
return walk return walk
def metapath_randomwalk(graph, start_node, metapath, walk_length): def metapath_randomwalk(graph,
start_nodes,
metapath,
walk_length,
alias_name=None,
events_name=None):
"""Implementation of metapath random walk in heterogeneous graph. """Implementation of metapath random walk in heterogeneous graph.
Args: Args:
graph: instance of pgl heterogeneous graph graph: instance of pgl heterogeneous graph
start_node: start node to generate walk start_nodes: start nodes to generate walk
metapath: meta path for sample nodes. metapath: meta path for sample nodes.
e.g: "user-item-user" e.g: "c2p-p2a-a2p-p2c"
walk_length: the walk length walk_length: the walk length
Return: Return:
a list of metapath walk, each element is a node id. a list of metapath walks.
""" """
np.random.seed()
edge_types = metapath.split('-')
walk = [] walk = []
metapath = metapath.split('-') for node in start_nodes:
assert metapath[0] == metapath[ walk.append([node])
-1], "The last meta path item should be the same as the first one"
mp_len = len(metapath) - 1 cur_walk_ids = np.arange(0, len(start_nodes))
cur_nodes = np.array(start_nodes)
walk.append(start_node) mp_len = len(edge_types)
for i in range(1, walk_length): for i in range(0, walk_length - 1):
cur_node = walk[-1] g = graph[edge_types[i % mp_len]]
succs = graph.successor(cur_node)
if succs.size > 0: cur_succs = g.successor(cur_nodes)
succs_node_types = graph._node_types[succs] mask = [len(succ) > 0 for succ in cur_succs]
if np.any(mask):
cur_walk_ids = cur_walk_ids[mask]
cur_nodes = cur_nodes[mask]
cur_succs = cur_succs[mask]
else: else:
# no successor of current node # stop when all nodes have no successor
break break
succs_nodes = succs[np.where(succs_node_types == metapath[i % mp_len])[ if alias_name is not None and events_name is not None:
0]] sample_index = [
if succs_nodes.size > 0: alias_sample([1], g.node_feat[alias_name][node],
walk.append(np.random.choice(succs_nodes)) g.node_feat[events_name][node])[0]
for node in cur_nodes
]
else: else:
# no successor of such node type outdegree = [len(cur_succ) for cur_succ in cur_succs]
break sample_index = np.floor(
np.random.rand(cur_succs.shape[0]) * outdegree).astype("int64")
nxt_cur_nodes = []
for s, ind, walk_id in zip(cur_succs, sample_index, cur_walk_ids):
walk[walk_id].append(s[ind])
nxt_cur_nodes.append(s[ind])
cur_nodes = np.array(nxt_cur_nodes)
return walk return walk
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册