未验证 提交 c0f98318 编写于 作者: H Huang Zhengjie 提交者: GitHub

Merge pull request #112 from Yelrose/master

add partial node feature
......@@ -27,12 +27,11 @@ from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
def send(src, dst, nfeat, efeat, message_func):
"""Send message from src to dst.
"""
src_feat = op.read_rows(nfeat, src)
dst_feat = op.read_rows(nfeat, dst)
src_feat = op.RowReader(nfeat, src)
dst_feat = op.RowReader(nfeat, dst)
msg = message_func(src_feat, dst_feat, efeat)
return msg
......@@ -144,11 +143,13 @@ class BaseGraphWrapper(object):
"""
if efeat_list is None:
efeat_list = {}
if nfeat_list is None:
nfeat_list = {}
src, dst = self.edges
nfeat = {}
for feat in nfeat_list:
if isinstance(feat, str):
nfeat[feat] = self.node_feat[feat]
......
......@@ -68,3 +68,18 @@ def read_rows(data, index):
return new_data
else:
return paddle_helper.gather(data, index)
class RowReader(object):
"""Memory Efficient RowReader
"""
def __init__(self, nfeat, index):
self.nfeat = nfeat
self.loaded_nfeat = {}
self.index = index
def __getitem__(self, key):
if key not in self.loaded_nfeat:
self.loaded_nfeat[key] = read_rows(self.nfeat[key], self.index)
return self.loaded_nfeat[key]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册