提交 707ece69 编写于 作者: Y yelrose

Remove Place in GraphWrapper

上级 1dfe4882
......@@ -65,7 +65,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper(
name="graph",
place=place,
node_feat=dataset.graph.node_feat_info())
output = pgl.layers.gcn(gw,
......
......@@ -170,7 +170,7 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", fluid.CPUPlace(), node_feat=[('feats', [None, 602], np.dtype('float32'))])
"sub_graph", node_feat=[('feats', [None, 602], np.dtype('float32'))])
model_loss, model_acc = build_graph_model(
graph_wrapper,
num_class=data["num_class"],
......
......@@ -204,8 +204,8 @@ def main(args):
graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model(
graph_wrapper,
num_class=data["num_class"],
......
......@@ -231,7 +231,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model(
......
......@@ -227,7 +227,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model(
......
......@@ -49,7 +49,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper(
"gw",
place,
node_feat=[('norm', [None, 1], "float32")],
edge_feat=[('weights', [None, 1], "float32")])
......
......@@ -88,7 +88,7 @@ def build_graph_model(args):
graph_wrappers.append(
pgl.graph_wrapper.GraphWrapper(
"layer_0", fluid.CPUPlace(), node_feat=node_feature_info))
"layer_0", node_feat=node_feature_info))
#edge_feat=[("f", [None, 1], "float32")]))
num_embed = args.num_nodes
......
......@@ -516,7 +516,6 @@ class GraphWrapper(BaseGraphWrapper):
})
graph_wrapper = GraphWrapper(name="graph",
place=place,
node_feat=graph.node_feat_info(),
edge_feat=graph.edge_feat_info())
......@@ -531,12 +530,11 @@ class GraphWrapper(BaseGraphWrapper):
ret = exe.run(fetch_list=[...], feed=feed_dict )
"""
def __init__(self, name, place, node_feat=[], edge_feat=[]):
def __init__(self, name, node_feat=[], edge_feat=[], **kwargs):
super(GraphWrapper, self).__init__()
# collect holders for PyReader
self._data_name_prefix = name
self._holder_list = []
self._place = place
self.__create_graph_attr_holders()
for node_feat_name, node_feat_shape, node_feat_dtype in node_feat:
self.__create_graph_node_feat_holders(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册