From bbe1b9d8fd8d2ba8342a5472645c5ef459b07bde Mon Sep 17 00:00:00 2001 From: yelrose <270018958@qq.com> Date: Tue, 14 Apr 2020 12:15:53 +0800 Subject: [PATCH] Remove Place in GraphWrapper --- docs/source/quick_start/md/quick_start.md | 1 - .../quick_start/md/quick_start_for_heterGraph.md | 1 - examples/GATNE/model.py | 1 - examples/dgi/train.py | 1 - examples/distribute_graphsage/train.py | 2 +- examples/graphsage/train.py | 2 +- examples/graphsage/train_multi.py | 1 - examples/graphsage/train_scale.py | 1 - examples/stgcn/main.py | 1 - examples/unsup_graphsage/train.py | 2 +- ogb_examples/graphproppred/main_pgl.py | 1 - ogb_examples/linkproppred/main_pgl.py | 1 - pgl/graph_wrapper.py | 7 +------ pgl/heter_graph_wrapper.py | 10 +--------- 14 files changed, 5 insertions(+), 27 deletions(-) diff --git a/docs/source/quick_start/md/quick_start.md b/docs/source/quick_start/md/quick_start.md index 7df4f48..6c4fa5d 100644 --- a/docs/source/quick_start/md/quick_start.md +++ b/docs/source/quick_start/md/quick_start.md @@ -53,7 +53,6 @@ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() # use GraphWrapper as a container for graph data to construct a graph neural network gw = pgl.graph_wrapper.GraphWrapper(name='graph', - place = place, node_feat=g.node_feat_info()) ``` diff --git a/docs/source/quick_start/md/quick_start_for_heterGraph.md b/docs/source/quick_start/md/quick_start_for_heterGraph.md index 1be9cb2..37ac0de 100644 --- a/docs/source/quick_start/md/quick_start_for_heterGraph.md +++ b/docs/source/quick_start/md/quick_start_for_heterGraph.md @@ -77,7 +77,6 @@ place = fluid.CPUPlace() # create a GraphWrapper as a container for graph data gw = heter_graph_wrapper.HeterGraphWrapper(name='heter_graph', - place = place, edge_types = g.edge_types_info(), node_feat=g.node_feat_info(), edge_feat=g.edge_feat_info()) diff --git a/examples/GATNE/model.py b/examples/GATNE/model.py index b193849..18f83c8 100644 --- a/examples/GATNE/model.py +++ b/examples/GATNE/model.py @@ -53,7 +53,6 @@ class GATNE(object): self.gw = heter_graph_wrapper.HeterGraphWrapper( name="heter_graph", - place=place, edge_types=self.graph.edge_types_info(), node_feat=self.graph.node_feat_info(), edge_feat=self.graph.edge_feat_info()) diff --git a/examples/dgi/train.py b/examples/dgi/train.py index 6774209..a23e4e7 100644 --- a/examples/dgi/train.py +++ b/examples/dgi/train.py @@ -65,7 +65,6 @@ def main(args): with fluid.program_guard(train_program, startup_program): gw = pgl.graph_wrapper.GraphWrapper( name="graph", - place=place, node_feat=dataset.graph.node_feat_info()) output = pgl.layers.gcn(gw, diff --git a/examples/distribute_graphsage/train.py b/examples/distribute_graphsage/train.py index fa52e3e..4faafdd 100644 --- a/examples/distribute_graphsage/train.py +++ b/examples/distribute_graphsage/train.py @@ -170,7 +170,7 @@ def main(args): with fluid.program_guard(train_program, startup_program): graph_wrapper = pgl.graph_wrapper.GraphWrapper( - "sub_graph", fluid.CPUPlace(), node_feat=[('feats', [None, 602], np.dtype('float32'))]) + "sub_graph", node_feat=[('feats', [None, 602], np.dtype('float32'))]) model_loss, model_acc = build_graph_model( graph_wrapper, num_class=data["num_class"], diff --git a/examples/graphsage/train.py b/examples/graphsage/train.py index da20f6e..463e0b6 100644 --- a/examples/graphsage/train.py +++ b/examples/graphsage/train.py @@ -204,8 +204,8 @@ def main(args): graph_wrapper = pgl.graph_wrapper.GraphWrapper( "sub_graph", - fluid.CPUPlace(), node_feat=data['graph'].node_feat_info()) + model_loss, model_acc = build_graph_model( graph_wrapper, num_class=data["num_class"], diff --git a/examples/graphsage/train_multi.py b/examples/graphsage/train_multi.py index eda3a34..1f8fe69 100644 --- a/examples/graphsage/train_multi.py +++ b/examples/graphsage/train_multi.py @@ -231,7 +231,6 @@ def main(args): with fluid.program_guard(train_program, startup_program): graph_wrapper = pgl.graph_wrapper.GraphWrapper( "sub_graph", - fluid.CPUPlace(), node_feat=data['graph'].node_feat_info()) model_loss, model_acc = build_graph_model( diff --git a/examples/graphsage/train_scale.py b/examples/graphsage/train_scale.py index f0625d0..c6fce99 100644 --- a/examples/graphsage/train_scale.py +++ b/examples/graphsage/train_scale.py @@ -227,7 +227,6 @@ def main(args): with fluid.program_guard(train_program, startup_program): graph_wrapper = pgl.graph_wrapper.GraphWrapper( "sub_graph", - fluid.CPUPlace(), node_feat=data['graph'].node_feat_info()) model_loss, model_acc = build_graph_model( diff --git a/examples/stgcn/main.py b/examples/stgcn/main.py index 6be8df9..26adb6a 100644 --- a/examples/stgcn/main.py +++ b/examples/stgcn/main.py @@ -49,7 +49,6 @@ def main(args): with fluid.program_guard(train_program, startup_program): gw = pgl.graph_wrapper.GraphWrapper( "gw", - place, node_feat=[('norm', [None, 1], "float32")], edge_feat=[('weights', [None, 1], "float32")]) diff --git a/examples/unsup_graphsage/train.py b/examples/unsup_graphsage/train.py index a53ffdc..cc7351b 100644 --- a/examples/unsup_graphsage/train.py +++ b/examples/unsup_graphsage/train.py @@ -88,7 +88,7 @@ def build_graph_model(args): graph_wrappers.append( pgl.graph_wrapper.GraphWrapper( - "layer_0", fluid.CPUPlace(), node_feat=node_feature_info)) + "layer_0", node_feat=node_feature_info)) #edge_feat=[("f", [None, 1], "float32")])) num_embed = args.num_nodes diff --git a/ogb_examples/graphproppred/main_pgl.py b/ogb_examples/graphproppred/main_pgl.py index ef7c112..1cc505e 100644 --- a/ogb_examples/graphproppred/main_pgl.py +++ b/ogb_examples/graphproppred/main_pgl.py @@ -148,7 +148,6 @@ def main(): with fluid.program_guard(train_program, startup_program): gw = pgl.graph_wrapper.GraphWrapper( "graph", - place=place, node_feat=graph_data.node_feat_info(), edge_feat=graph_data.edge_feat_info()) pred = model.forward(gw) diff --git a/ogb_examples/linkproppred/main_pgl.py b/ogb_examples/linkproppred/main_pgl.py index bb81a24..2f6be61 100644 --- a/ogb_examples/linkproppred/main_pgl.py +++ b/ogb_examples/linkproppred/main_pgl.py @@ -158,7 +158,6 @@ def main(): num_layers=args.num_layers) gw = pgl.graph_wrapper.GraphWrapper( "graph", - place, node_feat=graph_data.node_feat_info(), edge_feat=graph_data.edge_feat_info()) pred, prob, loss = model.forward(gw) diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index 3f30da4..0091764 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -475,9 +475,6 @@ class GraphWrapper(BaseGraphWrapper): Args: name: The graph data prefix - place: fluid.CPUPlace or fluid.CUDAPlace(n) indicating the - device to hold the graph data. - node_feat: A list of tuples that decribe the details of node feature tenosr. Each tuple mush be (name, shape, dtype) and the first dimension of the shape must be set unknown @@ -516,7 +513,6 @@ class GraphWrapper(BaseGraphWrapper): }) graph_wrapper = GraphWrapper(name="graph", - place=place, node_feat=graph.node_feat_info(), edge_feat=graph.edge_feat_info()) @@ -531,12 +527,11 @@ class GraphWrapper(BaseGraphWrapper): ret = exe.run(fetch_list=[...], feed=feed_dict ) """ - def __init__(self, name, place, node_feat=[], edge_feat=[]): + def __init__(self, name, node_feat=[], edge_feat=[], **kwargs): super(GraphWrapper, self).__init__() # collect holders for PyReader self._data_name_prefix = name self._holder_list = [] - self._place = place self.__create_graph_attr_holders() for node_feat_name, node_feat_shape, node_feat_dtype in node_feat: self.__create_graph_node_feat_holders( diff --git a/pgl/heter_graph_wrapper.py b/pgl/heter_graph_wrapper.py index a56bc4e..bd786c7 100644 --- a/pgl/heter_graph_wrapper.py +++ b/pgl/heter_graph_wrapper.py @@ -44,9 +44,6 @@ class HeterGraphWrapper(object): Args: name: The heterogeneous graph data prefix - place: fluid.CPUPlace or fluid.CUDAPlace(n) indicating the - device to hold the graph data. - node_feat: A dict of list of tuples that decribe the details of node feature tenosr. Each tuple mush be (name, shape, dtype) and the first dimension of the shape must be set unknown @@ -85,19 +82,15 @@ class HeterGraphWrapper(object): node_feat=node_feat, edge_feat=edges_feat) - place = fluid.CPUPlace() - gw = heter_graph_wrapper.HeterGraphWrapper( name='heter_graph', - place = place, edge_types = g.edge_types_info(), node_feat=g.node_feat_info(), edge_feat=g.edge_feat_info()) """ - def __init__(self, name, place, edge_types, node_feat={}, edge_feat={}): + def __init__(self, name, edge_types, node_feat={}, edge_feat={}, **kwargs): self.__data_name_prefix = name - self._place = place self._edge_types = edge_types self._multi_gw = {} for edge_type in self._edge_types: @@ -114,7 +107,6 @@ class HeterGraphWrapper(object): self._multi_gw[edge_type] = GraphWrapper( name=type_name, - place=self._place, node_feat=n_feat, edge_feat=e_feat) -- GitLab