提交 361da2cc 编写于 作者: Y Yelrose

add partial node feature

上级 fb01c859
...@@ -28,11 +28,21 @@ from pgl.utils.logger import log ...@@ -28,11 +28,21 @@ from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"] __all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
def send(src, dst, nfeat, efeat, message_func): def send(src, dst, nfeat, efeat, message_func, nfeat_src, nfeat_dst):
"""Send message from src to dst. """Send message from src to dst.
""" """
src_feat = op.read_rows(nfeat, src) for key in nfeat_src.keys():
dst_feat = op.read_rows(nfeat, dst) if key in nfeat:
log.info("Node-Feature %s both in nfeat_src_list and nfeat_list" % key)
for key in nfeat_dst.keys():
if key in nfeat:
log.info("Node-Feature %s both in nfeat_dst_list and nfeat_list" % key)
nfeat_src.update(nfeat)
nfeat_dst.update(nfeat)
src_feat = op.read_rows(nfeat_src, src)
dst_feat = op.read_rows(nfeat_dst, dst)
msg = message_func(src_feat, dst_feat, efeat) msg = message_func(src_feat, dst_feat, efeat)
return msg return msg
...@@ -111,7 +121,7 @@ class BaseGraphWrapper(object): ...@@ -111,7 +121,7 @@ class BaseGraphWrapper(object):
def __repr__(self): def __repr__(self):
return self._data_name_prefix return self._data_name_prefix
def send(self, message_func, nfeat_list=None, efeat_list=None): def send(self, message_func, nfeat_list=None, efeat_list=None, nfeat_list_src=None, nfeat_list_dst=None):
"""Send message from all src nodes to dst nodes. """Send message from all src nodes to dst nodes.
The UDF message function should has the following format. The UDF message function should has the following format.
...@@ -136,6 +146,8 @@ class BaseGraphWrapper(object): ...@@ -136,6 +146,8 @@ class BaseGraphWrapper(object):
message_func: UDF function. message_func: UDF function.
nfeat_list: a list of names or tuple (name, tensor) nfeat_list: a list of names or tuple (name, tensor)
efeat_list: a list of names or tuple (name, tensor) efeat_list: a list of names or tuple (name, tensor)
nfeat_list_src: a list of names or tuple (name, tensor). The node feature only for src
efeat_list_dst: a list of names or tuple (name, tensor). The node feature only for dst
Return: Return:
A dictionary of tensor representing the message. Each of the values A dictionary of tensor representing the message. Each of the values
...@@ -144,11 +156,19 @@ class BaseGraphWrapper(object): ...@@ -144,11 +156,19 @@ class BaseGraphWrapper(object):
""" """
if efeat_list is None: if efeat_list is None:
efeat_list = {} efeat_list = {}
if nfeat_list is None: if nfeat_list is None:
nfeat_list = {} nfeat_list = {}
if nfeat_list_src is None:
nfeat_list_src = {}
if nfeat_list_dst is None:
nfeat_list_dst = {}
src, dst = self.edges src, dst = self.edges
nfeat = {} nfeat = {}
for feat in nfeat_list: for feat in nfeat_list:
if isinstance(feat, str): if isinstance(feat, str):
nfeat[feat] = self.node_feat[feat] nfeat[feat] = self.node_feat[feat]
...@@ -156,6 +176,24 @@ class BaseGraphWrapper(object): ...@@ -156,6 +176,24 @@ class BaseGraphWrapper(object):
name, tensor = feat name, tensor = feat
nfeat[name] = tensor nfeat[name] = tensor
nfeat_src = {}
for feat in nfeat_list_src:
if isinstance(feat, str):
nfeat_src[feat] = self.node_feat[feat]
else:
name, tensor = feat
nfeat_src[name] = tensor
nfeat_dst = {}
for feat in nfeat_list_dst:
if isinstance(feat, str):
nfeat_dst[feat] = self.node_feat[feat]
else:
name, tensor = feat
nfeat_dst[name] = tensor
efeat = {} efeat = {}
for feat in efeat_list: for feat in efeat_list:
if isinstance(feat, str): if isinstance(feat, str):
...@@ -164,7 +202,8 @@ class BaseGraphWrapper(object): ...@@ -164,7 +202,8 @@ class BaseGraphWrapper(object):
name, tensor = feat name, tensor = feat
efeat[name] = tensor efeat[name] = tensor
msg = send(src, dst, nfeat, efeat, message_func) msg = send(src, dst, nfeat, efeat, message_func,
nfeat_src, nfeat_dst)
return msg return msg
def recv(self, msg, reduce_function): def recv(self, msg, reduce_function):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册