提交 361da2cc 编写于 作者: Y Yelrose

add partial node feature

上级 fb01c859
......@@ -28,11 +28,21 @@ from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
def send(src, dst, nfeat, efeat, message_func):
def send(src, dst, nfeat, efeat, message_func, nfeat_src, nfeat_dst):
"""Send message from src to dst.
"""
src_feat = op.read_rows(nfeat, src)
dst_feat = op.read_rows(nfeat, dst)
for key in nfeat_src.keys():
if key in nfeat:
log.info("Node-Feature %s both in nfeat_src_list and nfeat_list" % key)
for key in nfeat_dst.keys():
if key in nfeat:
log.info("Node-Feature %s both in nfeat_dst_list and nfeat_list" % key)
nfeat_src.update(nfeat)
nfeat_dst.update(nfeat)
src_feat = op.read_rows(nfeat_src, src)
dst_feat = op.read_rows(nfeat_dst, dst)
msg = message_func(src_feat, dst_feat, efeat)
return msg
......@@ -111,7 +121,7 @@ class BaseGraphWrapper(object):
def __repr__(self):
return self._data_name_prefix
def send(self, message_func, nfeat_list=None, efeat_list=None):
def send(self, message_func, nfeat_list=None, efeat_list=None, nfeat_list_src=None, nfeat_list_dst=None):
"""Send message from all src nodes to dst nodes.
The UDF message function should has the following format.
......@@ -136,6 +146,8 @@ class BaseGraphWrapper(object):
message_func: UDF function.
nfeat_list: a list of names or tuple (name, tensor)
efeat_list: a list of names or tuple (name, tensor)
nfeat_list_src: a list of names or tuple (name, tensor). The node feature only for src
efeat_list_dst: a list of names or tuple (name, tensor). The node feature only for dst
Return:
A dictionary of tensor representing the message. Each of the values
......@@ -144,11 +156,19 @@ class BaseGraphWrapper(object):
"""
if efeat_list is None:
efeat_list = {}
if nfeat_list is None:
nfeat_list = {}
if nfeat_list_src is None:
nfeat_list_src = {}
if nfeat_list_dst is None:
nfeat_list_dst = {}
src, dst = self.edges
nfeat = {}
for feat in nfeat_list:
if isinstance(feat, str):
nfeat[feat] = self.node_feat[feat]
......@@ -156,6 +176,24 @@ class BaseGraphWrapper(object):
name, tensor = feat
nfeat[name] = tensor
nfeat_src = {}
for feat in nfeat_list_src:
if isinstance(feat, str):
nfeat_src[feat] = self.node_feat[feat]
else:
name, tensor = feat
nfeat_src[name] = tensor
nfeat_dst = {}
for feat in nfeat_list_dst:
if isinstance(feat, str):
nfeat_dst[feat] = self.node_feat[feat]
else:
name, tensor = feat
nfeat_dst[name] = tensor
efeat = {}
for feat in efeat_list:
if isinstance(feat, str):
......@@ -164,7 +202,8 @@ class BaseGraphWrapper(object):
name, tensor = feat
efeat[name] = tensor
msg = send(src, dst, nfeat, efeat, message_func)
msg = send(src, dst, nfeat, efeat, message_func,
nfeat_src, nfeat_dst)
return msg
def recv(self, msg, reduce_function):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册