未验证 提交 c0f98318 编写于 作者: H Huang Zhengjie 提交者: GitHub

Merge pull request #112 from Yelrose/master

add partial node feature
...@@ -27,12 +27,11 @@ from pgl.utils.logger import log ...@@ -27,12 +27,11 @@ from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"] __all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
def send(src, dst, nfeat, efeat, message_func): def send(src, dst, nfeat, efeat, message_func):
"""Send message from src to dst. """Send message from src to dst.
""" """
src_feat = op.read_rows(nfeat, src) src_feat = op.RowReader(nfeat, src)
dst_feat = op.read_rows(nfeat, dst) dst_feat = op.RowReader(nfeat, dst)
msg = message_func(src_feat, dst_feat, efeat) msg = message_func(src_feat, dst_feat, efeat)
return msg return msg
...@@ -144,11 +143,13 @@ class BaseGraphWrapper(object): ...@@ -144,11 +143,13 @@ class BaseGraphWrapper(object):
""" """
if efeat_list is None: if efeat_list is None:
efeat_list = {} efeat_list = {}
if nfeat_list is None: if nfeat_list is None:
nfeat_list = {} nfeat_list = {}
src, dst = self.edges src, dst = self.edges
nfeat = {} nfeat = {}
for feat in nfeat_list: for feat in nfeat_list:
if isinstance(feat, str): if isinstance(feat, str):
nfeat[feat] = self.node_feat[feat] nfeat[feat] = self.node_feat[feat]
......
...@@ -68,3 +68,18 @@ def read_rows(data, index): ...@@ -68,3 +68,18 @@ def read_rows(data, index):
return new_data return new_data
else: else:
return paddle_helper.gather(data, index) return paddle_helper.gather(data, index)
class RowReader(object):
"""Memory Efficient RowReader
"""
def __init__(self, nfeat, index):
self.nfeat = nfeat
self.loaded_nfeat = {}
self.index = index
def __getitem__(self, key):
if key not in self.loaded_nfeat:
self.loaded_nfeat[key] = read_rows(self.nfeat[key], self.index)
return self.loaded_nfeat[key]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册