提交 88b4ada0 编写于 作者: Y Yelrose

add partial node feature

上级 361da2cc
......@@ -27,22 +27,24 @@ from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
def send(src, dst, nfeat, efeat, message_func, nfeat_src, nfeat_dst):
class ReadRows(object):
"""Memory Efficient ReadRows
"""
def __init__(self, nfeat, index):
self.nfeat = nfeat
self.loaded_nfeat = {}
self.index = index
def __getitem__(self, key):
if key not in self.loaded_nfeat:
self.loaded_nfeat[key] = op.read_rows(self.nfeat[key], self.index)
return self.loaded_nfeat[key]
def send(src, dst, nfeat, efeat, message_func):
"""Send message from src to dst.
"""
for key in nfeat_src.keys():
if key in nfeat:
log.info("Node-Feature %s both in nfeat_src_list and nfeat_list" % key)
for key in nfeat_dst.keys():
if key in nfeat:
log.info("Node-Feature %s both in nfeat_dst_list and nfeat_list" % key)
nfeat_src.update(nfeat)
nfeat_dst.update(nfeat)
src_feat = op.read_rows(nfeat_src, src)
dst_feat = op.read_rows(nfeat_dst, dst)
src_feat = ReadRows(nfeat, src)
dst_feat = ReadRows(nfeat, dst)
msg = message_func(src_feat, dst_feat, efeat)
return msg
......@@ -121,7 +123,7 @@ class BaseGraphWrapper(object):
def __repr__(self):
return self._data_name_prefix
def send(self, message_func, nfeat_list=None, efeat_list=None, nfeat_list_src=None, nfeat_list_dst=None):
def send(self, message_func, nfeat_list=None, efeat_list=None):
"""Send message from all src nodes to dst nodes.
The UDF message function should has the following format.
......@@ -146,8 +148,6 @@ class BaseGraphWrapper(object):
message_func: UDF function.
nfeat_list: a list of names or tuple (name, tensor)
efeat_list: a list of names or tuple (name, tensor)
nfeat_list_src: a list of names or tuple (name, tensor). The node feature only for src
efeat_list_dst: a list of names or tuple (name, tensor). The node feature only for dst
Return:
A dictionary of tensor representing the message. Each of the values
......@@ -160,12 +160,6 @@ class BaseGraphWrapper(object):
if nfeat_list is None:
nfeat_list = {}
if nfeat_list_src is None:
nfeat_list_src = {}
if nfeat_list_dst is None:
nfeat_list_dst = {}
src, dst = self.edges
nfeat = {}
......@@ -176,24 +170,6 @@ class BaseGraphWrapper(object):
name, tensor = feat
nfeat[name] = tensor
nfeat_src = {}
for feat in nfeat_list_src:
if isinstance(feat, str):
nfeat_src[feat] = self.node_feat[feat]
else:
name, tensor = feat
nfeat_src[name] = tensor
nfeat_dst = {}
for feat in nfeat_list_dst:
if isinstance(feat, str):
nfeat_dst[feat] = self.node_feat[feat]
else:
name, tensor = feat
nfeat_dst[name] = tensor
efeat = {}
for feat in efeat_list:
if isinstance(feat, str):
......@@ -202,8 +178,7 @@ class BaseGraphWrapper(object):
name, tensor = feat
efeat[name] = tensor
msg = send(src, dst, nfeat, efeat, message_func,
nfeat_src, nfeat_dst)
msg = send(src, dst, nfeat, efeat, message_func)
return msg
def recv(self, msg, reduce_function):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册