提交 16c41716 编写于 作者: L liweibin

update pgl

上级 0bdc0da9
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Generate pgl apis
"""
__version__ = "1.0.0"
__version__ = "1.0.1"
from pgl import layers
from pgl import graph_wrapper
from pgl import graph
......
......@@ -14,11 +14,12 @@
"""
This package implement Heterogeneous Graph structure for handling Heterogeneous graph data.
"""
import time
import numpy as np
import pickle as pkl
import time
import pgl.graph_kernel as graph_kernel
from pgl import graph
from pgl.graph import Graph
__all__ = ['HeterGraph']
......@@ -31,123 +32,111 @@ def _hide_num_nodes(shape):
return shape
class HeterGraph(object):
"""Implementation of graph structure in pgl
This is a simple implementation of heterogeneous graph structure in pgl
class NodeGraph(Graph):
"""Implementation of a graph that has multple node types.
Args:
num_nodes_every_type: dict, number of nodes for every node type
num_nodes: number of nodes in the graph
edges: list of (u, v) tuples
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of numpy array as edge features
"""
edges_every_type: dict, every element is a list of (u, v) tuples.
def __init__(self,
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None):
super(NodeGraph, self).__init__(num_nodes, edges, node_feat, edge_feat)
if isinstance(node_types, list):
self._node_types = np.array(node_types, dtype=object)[:, 1]
else:
self._node_types = node_types
node_feat_every_type: features for every node type.
class HeterGraph(object):
"""Implementation of heterogeneous graph structure in pgl
This is a simple implementation of heterogeneous graph structure in pgl.
Args:
num_nodes: number of nodes in a heterogeneous graph
edges: dict, every element in dict is a list of (u, v) tuples.
node_types (optional): list of (u, node_type) tuples to specify the node type of every node
node_feat (optional): a dict of numpy array as node features
edge_feat (optional): a dict of dict as edge features for every edge type
Examples:
.. code-block:: python
import numpy as np
num_nodes_every_type = {'type1':3,'type2':4, 'type3':2}
edges_every_type = {
('type1','type2', 'edges_type1'): [(0,1), (1,2)],
('type1', 'type3', 'edges_type2'): [(1,2), (3,1)],
}
node_feat_every_type = {
'type1': {'features1': np.random.randn(3, 4),
'features2': np.random.randn(3, 4)},
'type2': {'features3': np.random.randn(4, 4)},
'type3': {'features1': np.random.randn(2, 4),
'features2': np.random.randn(2, 4)}
num_nodes = 4
node_types = [(0, 'user'), (1, 'item'), (2, 'item'), (3, 'user')]
edges = {
'edges_type1': [(0,1), (3,2)],
'edges_type2': [(1,2), (3,1)],
}
edges_feat_every_type = {
('type1','type2','edges_type1'): {'h': np.random.randn(2, 4)},
('type1', 'type3', 'edges_type2'): {'h':np.random.randn(2, 4)},
node_feat = {'feature': np.random.randn(4, 16)}
edges_feat = {
'edges_type1': {'h': np.random.randn(2, 16)},
'edges_type2': {'h': np.random.randn(2, 16)},
}
g = heter_graph.HeterGraph(
num_nodes_every_type=num_nodes_every_type,
edges_every_type=edges_every_type,
node_feat_every_type=node_feat_every_type,
edge_feat_every_type=edges_feat_every_type)
num_nodes=num_nodes,
edges=edges,
node_types=node_types,
node_feat=node_feat,
edge_feat=edges_feat)
"""
def __init__(self,
num_nodes_every_type,
edges_every_type,
node_feat_every_type=None,
edge_feat_every_type=None):
self._num_nodes_dict = num_nodes_every_type
self._edges_dict = edges_every_type
if node_feat_every_type is not None:
self._node_feat = node_feat_every_type
num_nodes,
edges,
node_types=None,
node_feat=None,
edge_feat=None):
self._num_nodes = num_nodes
self._edges_dict = edges
if node_feat is not None:
self._node_feat = node_feat
else:
self._node_feat = {}
if edge_feat_every_type is not None:
self._edge_feat = edge_feat_every_type
if edge_feat is not None:
self._edge_feat = edge_feat
else:
self._edge_feat = {}
self._multi_graph = {}
for key, value in self._edges_dict.items():
if not self._node_feat:
node_feat = None
else:
node_feat = self._node_feat[key[0]]
if not self._edge_feat:
edge_feat = None
else:
edge_feat = self._edge_feat[key]
self._multi_graph[key] = graph.Graph(
num_nodes=self._num_nodes_dict[key[1]],
self._multi_graph[key] = NodeGraph(
num_nodes=self._num_nodes,
edges=value,
node_feat=node_feat,
node_types=node_types,
node_feat=self._node_feat,
edge_feat=edge_feat)
@property
def num_nodes(self):
"""Return the number of nodes.
"""
return self._num_nodes
def __getitem__(self, edge_type):
"""__getitem__
"""
return self._multi_graph[edge_type]
def meta_path_random_walk(self, start_node, edge_types, meta_path,
max_depth):
"""Meta path random walk sampling.
Args:
start_nodes: int, node to begin random walk.
edge_types: list, the edge types to be sampled.
meta_path: 'user-item-user'
max_depth: the max length of every walk.
"""
edges_type_list = []
node_type_list = meta_path.split('-')
for i in range(1, len(node_type_list)):
edges_type_list.append(
(node_type_list[i - 1], node_type_list[i], edge_types[i - 1]))
no_neighbors_flag = False
walk = [start_node]
for i in range(max_depth):
for e_type in edges_type_list:
cur_node = [walk[-1]]
nxt_node = self._multi_graph[e_type].sample_successor(
cur_node, max_degree=1) # list of np.array
nxt_node = nxt_node[0]
if len(nxt_node) == 0:
no_neighbors_flag = True
break
else:
walk.append(nxt_node.tolist()[0])
if no_neighbors_flag:
break
return walk
def node_feat_info(self):
"""Return the information of node feature for HeterGraphWrapper.
......@@ -155,17 +144,13 @@ class HeterGraph(object):
function is used to help constructing HeterGraphWrapper
Return:
A dict of list of tuple (name, shape, dtype) for all given node feature.
A list of tuple (name, shape, dtype) for all given node feature.
"""
node_feat_info = {}
for node_type_name, feat_dict in self._node_feat.items():
tmp_node_feat_info = []
for feat_name, feat in feat_dict.items():
full_name = feat_name
tmp_node_feat_info.append(
(full_name, _hide_num_nodes(feat.shape), feat.dtype))
node_feat_info[node_type_name] = tmp_node_feat_info
node_feat_info = []
for feat_name, feat in self._node_feat.items():
node_feat_info.append(
(feat_name, _hide_num_nodes(feat.shape), feat.dtype))
return node_feat_info
......@@ -193,7 +178,7 @@ class HeterGraph(object):
"""Return the information of all edge types.
Return:
A list of tuple ('srctype','dsttype', 'edges_type') for all edge types.
A list of all edge types.
"""
edge_types_info = []
......
......@@ -26,6 +26,7 @@ from pgl.utils.logger import log
from pgl.graph_wrapper import GraphWrapper
ALL = "__ALL__"
__all__ = ["HeterGraphWrapper"]
def is_all(arg):
......@@ -34,89 +35,6 @@ def is_all(arg):
return isinstance(arg, str) and arg == ALL
class BipartiteGraphWrapper(GraphWrapper):
"""Implement a bipartite graph wrapper that creates a graph data holders.
"""
def __init__(self, name, place, node_feat=[], edge_feat=[]):
super(BipartiteGraphWrapper, self).__init__(name, place, node_feat,
edge_feat)
def send(self,
message_func,
src_nfeat_list=None,
dst_nfeat_list=None,
efeat_list=None):
"""Send message from all src nodes to dst nodes.
The UDF message function should has the following format.
.. code-block:: python
def message_func(src_feat, dst_feat, edge_feat):
'''
Args:
src_feat: the node feat dict attached to the src nodes.
dst_feat: the node feat dict attached to the dst nodes.
edge_feat: the edge feat dict attached to the
corresponding (src, dst) edges.
Return:
It should return a tensor or a dictionary of tensor. And each tensor
should have a shape of (num_edges, dims).
'''
pass
Args:
message_func: UDF function.
src_nfeat_list: a list of tuple (name, tensor) for src nodes
dst_nfeat_list: a list of tuple (name, tensor) for dst nodes
efeat_list: a list of names or tuple (name, tensor)
Return:
A dictionary of tensor representing the message. Each of the values
in the dictionary has a shape (num_edges, dim) which should be collected
by :code:`recv` function.
"""
if efeat_list is None:
efeat_list = {}
if src_nfeat_list is None:
src_nfeat_list = {}
if dst_nfeat_list is None:
dst_nfeat_list = {}
src, dst = self.edges
src_feat = {}
for feat in src_nfeat_list:
if isinstance(feat, str):
src_feat[feat] = self.node_feat[feat]
else:
name, tensor = feat
src_feat[name] = tensor
dst_feat = {}
for feat in dst_nfeat_list:
if isinstance(feat, str):
dst_feat[feat] = self.node_feat[feat]
else:
name, tensor = feat
dst_feat[name] = tensor
efeat = {}
for feat in efeat_list:
if isinstance(feat, str):
efeat[feat] = self.edge_feat[feat]
else:
name, tensor = feat
efeat[name] = tensor
src_feat = op.read_rows(src_feat, src)
dst_feat = op.read_rows(dst_feat, dst)
msg = message_func(src_feat, dst_feat, efeat)
return msg
class HeterGraphWrapper(object):
"""Implement a heterogeneous graph wrapper that creates a graph data holders
that attributes and features in the heterogeneous graph.
......@@ -146,33 +64,30 @@ class HeterGraphWrapper(object):
import paddle.fluid as fluid
import numpy as np
num_nodes_every_type = {'type1':3,'type2':4, 'type3':2}
edges_every_type = {
('type1','type2', 'edges_type1'): [(0,1), (1,2)],
('type1', 'type3', 'edges_type2'): [(1,2), (3,1)],
}
node_feat_every_type = {
'type1': {'features1': np.random.randn(3, 4),
'features2': np.random.randn(3, 4)},
'type2': {'features3': np.random.randn(4, 4)},
'type3': {'features1': np.random.randn(2, 4),
'features2': np.random.randn(2, 4)}
from pgl.contrib import heter_graph
from pgl.contrib import heter_graph_wrapper
num_nodes = 4
node_types = [(0, 'user'), (1, 'item'), (2, 'item'), (3, 'user')]
edges = {
'edges_type1': [(0,1), (3,2)],
'edges_type2': [(1,2), (3,1)],
}
edges_feat_every_type = {
('type1','type2','edges_type1'): {'h': np.random.randn(2, 4)},
('type1', 'type3', 'edges_type2'): {'h':np.random.randn(2, 4)},
node_feat = {'feature': np.random.randn(4, 16)}
edges_feat = {
'edges_type1': {'h': np.random.randn(2, 16)},
'edges_type2': {'h': np.random.randn(2, 16)},
}
g = heter_graph.HeterGraph(
num_nodes_every_type=num_nodes_every_type,
edges_every_type=edges_every_type,
node_feat_every_type=node_feat_every_type,
edge_feat_every_type=edges_feat_every_type)
num_nodes=num_nodes,
edges=edges,
node_types=node_types,
node_feat=node_feat,
edge_feat=edges_feat)
place = fluid.CPUPlace()
gw = pgl.heter_graph_wrapper.HeterGraphWrapper(
gw = heter_graph_wrapper.HeterGraphWrapper(
name='heter_graph',
place = place,
edge_types = g.edge_types_info(),
......@@ -186,10 +101,9 @@ class HeterGraphWrapper(object):
self._edge_types = edge_types
self._multi_gw = {}
for edge_type in self._edge_types:
type_name = self.__data_name_prefix + '/' + edge_type[
0] + '_' + edge_type[1]
type_name = self.__data_name_prefix + '/' + edge_type
if node_feat:
n_feat = node_feat[edge_type[0]]
n_feat = node_feat
else:
n_feat = {}
......@@ -198,7 +112,7 @@ class HeterGraphWrapper(object):
else:
e_feat = {}
self._multi_gw[edge_type] = BipartiteGraphWrapper(
self._multi_gw[edge_type] = GraphWrapper(
name=type_name,
place=self._place,
node_feat=n_feat,
......
......@@ -596,8 +596,7 @@ class GraphWrapper(BaseGraphWrapper):
feed_dict[self.__data_name_prefix + '/edges_src'] = src
feed_dict[self.__data_name_prefix + '/edges_dst'] = dst
feed_dict[self.__data_name_prefix + '/num_nodes'] = np.array(
graph.num_nodes)
feed_dict[self.__data_name_prefix + '/num_nodes'] = np.array(graph.num_nodes)
feed_dict[self.__data_name_prefix + '/uniq_dst'] = uniq_dst
feed_dict[self.__data_name_prefix + '/uniq_dst_count'] = uniq_dst_count
feed_dict[self.__data_name_prefix + '/node_ids'] = graph.nodes
......
......@@ -16,6 +16,8 @@
from pgl.layers import conv
from pgl.layers.conv import *
from pgl.layers.set2set import *
__all__ = []
__all__ += conv.__all__
__all__ += set2set.__all__
......@@ -22,7 +22,10 @@ import pgl
from pgl.utils.logger import log
from pgl import graph_kernel
__all__ = ['graphsage_sample', 'node2vec_sample', 'deepwalk_sample']
__all__ = [
'graphsage_sample', 'node2vec_sample', 'deepwalk_sample',
'metapath_randomwalk'
]
def edge_hash(src, dst):
......@@ -251,3 +254,45 @@ def node2vec_sample(graph, nodes, max_depth, p=1.0, q=1.0):
prev_nodes, prev_succs = cur_nodes, cur_succs
cur_nodes = nxt_nodes
return walk
def metapath_randomwalk(graph, start_node, metapath, walk_length):
"""Implementation of metapath random walk in heterogeneous graph.
Args:
graph: instance of pgl heterogeneous graph
start_node: start node to generate walk
metapath: meta path for sample nodes.
e.g: "user-item-user"
walk_length: the walk length
Return:
a list of metapath walk, each element is a node id.
"""
np.random.seed()
walk = []
metapath = metapath.split('-')
assert metapath[0] == metapath[
-1], "The last meta path item should be the same as the first one"
mp_len = len(metapath) - 1
walk.append(start_node)
for i in range(1, walk_length):
cur_node = walk[-1]
succs = graph.successor(cur_node)
if succs.size > 0:
succs_node_types = graph._node_types[succs]
else:
# no successor of current node
break
succs_nodes = succs[np.where(succs_node_types == metapath[i % mp_len])[
0]]
if succs_nodes.size > 0:
walk.append(np.random.choice(succs_nodes))
else:
# no successor of such node type
break
return walk
......@@ -226,7 +226,6 @@ def scatter_add(input, index, updates):
output = fluid.layers.scatter(input, index, updates, mode='add')
return output
def scatter_max(input, index, updates):
"""Scatter max updates to input by given index.
......@@ -245,3 +244,4 @@ def scatter_max(input, index, updates):
output = fluid.layers.scatter(input, index, updates, mode='max')
return output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册