From 707ece69dba27eced7d85de0db6009c08a2b9fec Mon Sep 17 00:00:00 2001 From: yelrose <270018958@qq.com> Date: Tue, 14 Apr 2020 12:15:53 +0800 Subject: [PATCH] Remove Place in GraphWrapper --- examples/dgi/train.py | 1 - examples/distribute_graphsage/train.py | 2 +- examples/graphsage/train.py | 2 +- examples/graphsage/train_multi.py | 1 - examples/graphsage/train_scale.py | 1 - examples/stgcn/main.py | 1 - examples/unsup_graphsage/train.py | 2 +- pgl/graph_wrapper.py | 4 +--- 8 files changed, 4 insertions(+), 10 deletions(-) diff --git a/examples/dgi/train.py b/examples/dgi/train.py index 6774209..a23e4e7 100644 --- a/examples/dgi/train.py +++ b/examples/dgi/train.py @@ -65,7 +65,6 @@ def main(args): with fluid.program_guard(train_program, startup_program): gw = pgl.graph_wrapper.GraphWrapper( name="graph", - place=place, node_feat=dataset.graph.node_feat_info()) output = pgl.layers.gcn(gw, diff --git a/examples/distribute_graphsage/train.py b/examples/distribute_graphsage/train.py index fa52e3e..4faafdd 100644 --- a/examples/distribute_graphsage/train.py +++ b/examples/distribute_graphsage/train.py @@ -170,7 +170,7 @@ def main(args): with fluid.program_guard(train_program, startup_program): graph_wrapper = pgl.graph_wrapper.GraphWrapper( - "sub_graph", fluid.CPUPlace(), node_feat=[('feats', [None, 602], np.dtype('float32'))]) + "sub_graph", node_feat=[('feats', [None, 602], np.dtype('float32'))]) model_loss, model_acc = build_graph_model( graph_wrapper, num_class=data["num_class"], diff --git a/examples/graphsage/train.py b/examples/graphsage/train.py index da20f6e..463e0b6 100644 --- a/examples/graphsage/train.py +++ b/examples/graphsage/train.py @@ -204,8 +204,8 @@ def main(args): graph_wrapper = pgl.graph_wrapper.GraphWrapper( "sub_graph", - fluid.CPUPlace(), node_feat=data['graph'].node_feat_info()) + model_loss, model_acc = build_graph_model( graph_wrapper, num_class=data["num_class"], diff --git a/examples/graphsage/train_multi.py b/examples/graphsage/train_multi.py index eda3a34..1f8fe69 100644 --- a/examples/graphsage/train_multi.py +++ b/examples/graphsage/train_multi.py @@ -231,7 +231,6 @@ def main(args): with fluid.program_guard(train_program, startup_program): graph_wrapper = pgl.graph_wrapper.GraphWrapper( "sub_graph", - fluid.CPUPlace(), node_feat=data['graph'].node_feat_info()) model_loss, model_acc = build_graph_model( diff --git a/examples/graphsage/train_scale.py b/examples/graphsage/train_scale.py index f0625d0..c6fce99 100644 --- a/examples/graphsage/train_scale.py +++ b/examples/graphsage/train_scale.py @@ -227,7 +227,6 @@ def main(args): with fluid.program_guard(train_program, startup_program): graph_wrapper = pgl.graph_wrapper.GraphWrapper( "sub_graph", - fluid.CPUPlace(), node_feat=data['graph'].node_feat_info()) model_loss, model_acc = build_graph_model( diff --git a/examples/stgcn/main.py b/examples/stgcn/main.py index 6be8df9..26adb6a 100644 --- a/examples/stgcn/main.py +++ b/examples/stgcn/main.py @@ -49,7 +49,6 @@ def main(args): with fluid.program_guard(train_program, startup_program): gw = pgl.graph_wrapper.GraphWrapper( "gw", - place, node_feat=[('norm', [None, 1], "float32")], edge_feat=[('weights', [None, 1], "float32")]) diff --git a/examples/unsup_graphsage/train.py b/examples/unsup_graphsage/train.py index a53ffdc..cc7351b 100644 --- a/examples/unsup_graphsage/train.py +++ b/examples/unsup_graphsage/train.py @@ -88,7 +88,7 @@ def build_graph_model(args): graph_wrappers.append( pgl.graph_wrapper.GraphWrapper( - "layer_0", fluid.CPUPlace(), node_feat=node_feature_info)) + "layer_0", node_feat=node_feature_info)) #edge_feat=[("f", [None, 1], "float32")])) num_embed = args.num_nodes diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index 3f30da4..e8e6cb4 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -516,7 +516,6 @@ class GraphWrapper(BaseGraphWrapper): }) graph_wrapper = GraphWrapper(name="graph", - place=place, node_feat=graph.node_feat_info(), edge_feat=graph.edge_feat_info()) @@ -531,12 +530,11 @@ class GraphWrapper(BaseGraphWrapper): ret = exe.run(fetch_list=[...], feed=feed_dict ) """ - def __init__(self, name, place, node_feat=[], edge_feat=[]): + def __init__(self, name, node_feat=[], edge_feat=[], **kwargs): super(GraphWrapper, self).__init__() # collect holders for PyReader self._data_name_prefix = name self._holder_list = [] - self._place = place self.__create_graph_attr_holders() for node_feat_name, node_feat_shape, node_feat_dtype in node_feat: self.__create_graph_node_feat_holders( -- GitLab