diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index 0d67fa6bfec2d3ed8f3a208b28ebdfb4d82c1a6f..58b6ff00f2d2f41fd0a5450edd4a479ed0d8f70d 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -28,11 +28,21 @@ from pgl.utils.logger import log __all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"] -def send(src, dst, nfeat, efeat, message_func): +def send(src, dst, nfeat, efeat, message_func, nfeat_src, nfeat_dst): """Send message from src to dst. """ - src_feat = op.read_rows(nfeat, src) - dst_feat = op.read_rows(nfeat, dst) + for key in nfeat_src.keys(): + if key in nfeat: + log.info("Node-Feature %s both in nfeat_src_list and nfeat_list" % key) + + for key in nfeat_dst.keys(): + if key in nfeat: + log.info("Node-Feature %s both in nfeat_dst_list and nfeat_list" % key) + + nfeat_src.update(nfeat) + nfeat_dst.update(nfeat) + src_feat = op.read_rows(nfeat_src, src) + dst_feat = op.read_rows(nfeat_dst, dst) msg = message_func(src_feat, dst_feat, efeat) return msg @@ -111,7 +121,7 @@ class BaseGraphWrapper(object): def __repr__(self): return self._data_name_prefix - def send(self, message_func, nfeat_list=None, efeat_list=None): + def send(self, message_func, nfeat_list=None, efeat_list=None, nfeat_list_src=None, nfeat_list_dst=None): """Send message from all src nodes to dst nodes. The UDF message function should has the following format. @@ -136,6 +146,8 @@ class BaseGraphWrapper(object): message_func: UDF function. nfeat_list: a list of names or tuple (name, tensor) efeat_list: a list of names or tuple (name, tensor) + nfeat_list_src: a list of names or tuple (name, tensor). The node feature only for src + efeat_list_dst: a list of names or tuple (name, tensor). The node feature only for dst Return: A dictionary of tensor representing the message. Each of the values @@ -144,11 +156,19 @@ class BaseGraphWrapper(object): """ if efeat_list is None: efeat_list = {} + if nfeat_list is None: nfeat_list = {} + if nfeat_list_src is None: + nfeat_list_src = {} + + if nfeat_list_dst is None: + nfeat_list_dst = {} + src, dst = self.edges nfeat = {} + for feat in nfeat_list: if isinstance(feat, str): nfeat[feat] = self.node_feat[feat] @@ -156,6 +176,24 @@ class BaseGraphWrapper(object): name, tensor = feat nfeat[name] = tensor + nfeat_src = {} + + for feat in nfeat_list_src: + if isinstance(feat, str): + nfeat_src[feat] = self.node_feat[feat] + else: + name, tensor = feat + nfeat_src[name] = tensor + + nfeat_dst = {} + + for feat in nfeat_list_dst: + if isinstance(feat, str): + nfeat_dst[feat] = self.node_feat[feat] + else: + name, tensor = feat + nfeat_dst[name] = tensor + efeat = {} for feat in efeat_list: if isinstance(feat, str): @@ -164,7 +202,8 @@ class BaseGraphWrapper(object): name, tensor = feat efeat[name] = tensor - msg = send(src, dst, nfeat, efeat, message_func) + msg = send(src, dst, nfeat, efeat, message_func, + nfeat_src, nfeat_dst) return msg def recv(self, msg, reduce_function):