提交 69546bdc 编写于 作者: Y Yelrose

add partial node feature

上级 88b4ada0
......@@ -27,24 +27,11 @@ from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
class ReadRows(object):
"""Memory Efficient ReadRows
"""
def __init__(self, nfeat, index):
self.nfeat = nfeat
self.loaded_nfeat = {}
self.index = index
def __getitem__(self, key):
if key not in self.loaded_nfeat:
self.loaded_nfeat[key] = op.read_rows(self.nfeat[key], self.index)
return self.loaded_nfeat[key]
def send(src, dst, nfeat, efeat, message_func):
"""Send message from src to dst.
"""
src_feat = ReadRows(nfeat, src)
dst_feat = ReadRows(nfeat, dst)
src_feat = op.RowReader(nfeat, src)
dst_feat = op.RowReader(nfeat, dst)
msg = message_func(src_feat, dst_feat, efeat)
return msg
......
......@@ -68,3 +68,18 @@ def read_rows(data, index):
return new_data
else:
return paddle_helper.gather(data, index)
class RowReader(object):
"""Memory Efficient RowReader
"""
def __init__(self, nfeat, index):
self.nfeat = nfeat
self.loaded_nfeat = {}
self.index = index
def __getitem__(self, key):
if key not in self.loaded_nfeat:
self.loaded_nfeat[key] = read_rows(self.nfeat[key], self.index)
return self.loaded_nfeat[key]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册