From 88b4ada05f9a429ac90e7a68dcb0e415bcd56626 Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Thu, 6 Aug 2020 12:08:04 +0800 Subject: [PATCH] add partial node feature --- pgl/graph_wrapper.py | 61 +++++++++++++------------------------------- 1 file changed, 18 insertions(+), 43 deletions(-) diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index 58b6ff0..2df923c 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -27,22 +27,24 @@ from pgl.utils.logger import log __all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"] - -def send(src, dst, nfeat, efeat, message_func, nfeat_src, nfeat_dst): +class ReadRows(object): + """Memory Efficient ReadRows + """ + def __init__(self, nfeat, index): + self.nfeat = nfeat + self.loaded_nfeat = {} + self.index = index + + def __getitem__(self, key): + if key not in self.loaded_nfeat: + self.loaded_nfeat[key] = op.read_rows(self.nfeat[key], self.index) + return self.loaded_nfeat[key] + +def send(src, dst, nfeat, efeat, message_func): """Send message from src to dst. """ - for key in nfeat_src.keys(): - if key in nfeat: - log.info("Node-Feature %s both in nfeat_src_list and nfeat_list" % key) - - for key in nfeat_dst.keys(): - if key in nfeat: - log.info("Node-Feature %s both in nfeat_dst_list and nfeat_list" % key) - - nfeat_src.update(nfeat) - nfeat_dst.update(nfeat) - src_feat = op.read_rows(nfeat_src, src) - dst_feat = op.read_rows(nfeat_dst, dst) + src_feat = ReadRows(nfeat, src) + dst_feat = ReadRows(nfeat, dst) msg = message_func(src_feat, dst_feat, efeat) return msg @@ -121,7 +123,7 @@ class BaseGraphWrapper(object): def __repr__(self): return self._data_name_prefix - def send(self, message_func, nfeat_list=None, efeat_list=None, nfeat_list_src=None, nfeat_list_dst=None): + def send(self, message_func, nfeat_list=None, efeat_list=None): """Send message from all src nodes to dst nodes. The UDF message function should has the following format. @@ -146,8 +148,6 @@ class BaseGraphWrapper(object): message_func: UDF function. nfeat_list: a list of names or tuple (name, tensor) efeat_list: a list of names or tuple (name, tensor) - nfeat_list_src: a list of names or tuple (name, tensor). The node feature only for src - efeat_list_dst: a list of names or tuple (name, tensor). The node feature only for dst Return: A dictionary of tensor representing the message. Each of the values @@ -160,12 +160,6 @@ class BaseGraphWrapper(object): if nfeat_list is None: nfeat_list = {} - if nfeat_list_src is None: - nfeat_list_src = {} - - if nfeat_list_dst is None: - nfeat_list_dst = {} - src, dst = self.edges nfeat = {} @@ -176,24 +170,6 @@ class BaseGraphWrapper(object): name, tensor = feat nfeat[name] = tensor - nfeat_src = {} - - for feat in nfeat_list_src: - if isinstance(feat, str): - nfeat_src[feat] = self.node_feat[feat] - else: - name, tensor = feat - nfeat_src[name] = tensor - - nfeat_dst = {} - - for feat in nfeat_list_dst: - if isinstance(feat, str): - nfeat_dst[feat] = self.node_feat[feat] - else: - name, tensor = feat - nfeat_dst[name] = tensor - efeat = {} for feat in efeat_list: if isinstance(feat, str): @@ -202,8 +178,7 @@ class BaseGraphWrapper(object): name, tensor = feat efeat[name] = tensor - msg = send(src, dst, nfeat, efeat, message_func, - nfeat_src, nfeat_dst) + msg = send(src, dst, nfeat, efeat, message_func) return msg def recv(self, msg, reduce_function): -- GitLab