提交 bbe1b9d8 编写于 作者: Y yelrose

Remove Place in GraphWrapper

上级 1dfe4882
...@@ -53,7 +53,6 @@ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() ...@@ -53,7 +53,6 @@ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
# use GraphWrapper as a container for graph data to construct a graph neural network # use GraphWrapper as a container for graph data to construct a graph neural network
gw = pgl.graph_wrapper.GraphWrapper(name='graph', gw = pgl.graph_wrapper.GraphWrapper(name='graph',
place = place,
node_feat=g.node_feat_info()) node_feat=g.node_feat_info())
``` ```
......
...@@ -77,7 +77,6 @@ place = fluid.CPUPlace() ...@@ -77,7 +77,6 @@ place = fluid.CPUPlace()
# create a GraphWrapper as a container for graph data # create a GraphWrapper as a container for graph data
gw = heter_graph_wrapper.HeterGraphWrapper(name='heter_graph', gw = heter_graph_wrapper.HeterGraphWrapper(name='heter_graph',
place = place,
edge_types = g.edge_types_info(), edge_types = g.edge_types_info(),
node_feat=g.node_feat_info(), node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info()) edge_feat=g.edge_feat_info())
......
...@@ -53,7 +53,6 @@ class GATNE(object): ...@@ -53,7 +53,6 @@ class GATNE(object):
self.gw = heter_graph_wrapper.HeterGraphWrapper( self.gw = heter_graph_wrapper.HeterGraphWrapper(
name="heter_graph", name="heter_graph",
place=place,
edge_types=self.graph.edge_types_info(), edge_types=self.graph.edge_types_info(),
node_feat=self.graph.node_feat_info(), node_feat=self.graph.node_feat_info(),
edge_feat=self.graph.edge_feat_info()) edge_feat=self.graph.edge_feat_info())
......
...@@ -65,7 +65,6 @@ def main(args): ...@@ -65,7 +65,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper( gw = pgl.graph_wrapper.GraphWrapper(
name="graph", name="graph",
place=place,
node_feat=dataset.graph.node_feat_info()) node_feat=dataset.graph.node_feat_info())
output = pgl.layers.gcn(gw, output = pgl.layers.gcn(gw,
......
...@@ -170,7 +170,7 @@ def main(args): ...@@ -170,7 +170,7 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper( graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", fluid.CPUPlace(), node_feat=[('feats', [None, 602], np.dtype('float32'))]) "sub_graph", node_feat=[('feats', [None, 602], np.dtype('float32'))])
model_loss, model_acc = build_graph_model( model_loss, model_acc = build_graph_model(
graph_wrapper, graph_wrapper,
num_class=data["num_class"], num_class=data["num_class"],
......
...@@ -204,8 +204,8 @@ def main(args): ...@@ -204,8 +204,8 @@ def main(args):
graph_wrapper = pgl.graph_wrapper.GraphWrapper( graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", "sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info()) node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model( model_loss, model_acc = build_graph_model(
graph_wrapper, graph_wrapper,
num_class=data["num_class"], num_class=data["num_class"],
......
...@@ -231,7 +231,6 @@ def main(args): ...@@ -231,7 +231,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper( graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", "sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info()) node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model( model_loss, model_acc = build_graph_model(
......
...@@ -227,7 +227,6 @@ def main(args): ...@@ -227,7 +227,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper( graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", "sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info()) node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model( model_loss, model_acc = build_graph_model(
......
...@@ -49,7 +49,6 @@ def main(args): ...@@ -49,7 +49,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper( gw = pgl.graph_wrapper.GraphWrapper(
"gw", "gw",
place,
node_feat=[('norm', [None, 1], "float32")], node_feat=[('norm', [None, 1], "float32")],
edge_feat=[('weights', [None, 1], "float32")]) edge_feat=[('weights', [None, 1], "float32")])
......
...@@ -88,7 +88,7 @@ def build_graph_model(args): ...@@ -88,7 +88,7 @@ def build_graph_model(args):
graph_wrappers.append( graph_wrappers.append(
pgl.graph_wrapper.GraphWrapper( pgl.graph_wrapper.GraphWrapper(
"layer_0", fluid.CPUPlace(), node_feat=node_feature_info)) "layer_0", node_feat=node_feature_info))
#edge_feat=[("f", [None, 1], "float32")])) #edge_feat=[("f", [None, 1], "float32")]))
num_embed = args.num_nodes num_embed = args.num_nodes
......
...@@ -148,7 +148,6 @@ def main(): ...@@ -148,7 +148,6 @@ def main():
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper( gw = pgl.graph_wrapper.GraphWrapper(
"graph", "graph",
place=place,
node_feat=graph_data.node_feat_info(), node_feat=graph_data.node_feat_info(),
edge_feat=graph_data.edge_feat_info()) edge_feat=graph_data.edge_feat_info())
pred = model.forward(gw) pred = model.forward(gw)
......
...@@ -158,7 +158,6 @@ def main(): ...@@ -158,7 +158,6 @@ def main():
num_layers=args.num_layers) num_layers=args.num_layers)
gw = pgl.graph_wrapper.GraphWrapper( gw = pgl.graph_wrapper.GraphWrapper(
"graph", "graph",
place,
node_feat=graph_data.node_feat_info(), node_feat=graph_data.node_feat_info(),
edge_feat=graph_data.edge_feat_info()) edge_feat=graph_data.edge_feat_info())
pred, prob, loss = model.forward(gw) pred, prob, loss = model.forward(gw)
......
...@@ -475,9 +475,6 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -475,9 +475,6 @@ class GraphWrapper(BaseGraphWrapper):
Args: Args:
name: The graph data prefix name: The graph data prefix
place: fluid.CPUPlace or fluid.CUDAPlace(n) indicating the
device to hold the graph data.
node_feat: A list of tuples that decribe the details of node node_feat: A list of tuples that decribe the details of node
feature tenosr. Each tuple mush be (name, shape, dtype) feature tenosr. Each tuple mush be (name, shape, dtype)
and the first dimension of the shape must be set unknown and the first dimension of the shape must be set unknown
...@@ -516,7 +513,6 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -516,7 +513,6 @@ class GraphWrapper(BaseGraphWrapper):
}) })
graph_wrapper = GraphWrapper(name="graph", graph_wrapper = GraphWrapper(name="graph",
place=place,
node_feat=graph.node_feat_info(), node_feat=graph.node_feat_info(),
edge_feat=graph.edge_feat_info()) edge_feat=graph.edge_feat_info())
...@@ -531,12 +527,11 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -531,12 +527,11 @@ class GraphWrapper(BaseGraphWrapper):
ret = exe.run(fetch_list=[...], feed=feed_dict ) ret = exe.run(fetch_list=[...], feed=feed_dict )
""" """
def __init__(self, name, place, node_feat=[], edge_feat=[]): def __init__(self, name, node_feat=[], edge_feat=[], **kwargs):
super(GraphWrapper, self).__init__() super(GraphWrapper, self).__init__()
# collect holders for PyReader # collect holders for PyReader
self._data_name_prefix = name self._data_name_prefix = name
self._holder_list = [] self._holder_list = []
self._place = place
self.__create_graph_attr_holders() self.__create_graph_attr_holders()
for node_feat_name, node_feat_shape, node_feat_dtype in node_feat: for node_feat_name, node_feat_shape, node_feat_dtype in node_feat:
self.__create_graph_node_feat_holders( self.__create_graph_node_feat_holders(
......
...@@ -44,9 +44,6 @@ class HeterGraphWrapper(object): ...@@ -44,9 +44,6 @@ class HeterGraphWrapper(object):
Args: Args:
name: The heterogeneous graph data prefix name: The heterogeneous graph data prefix
place: fluid.CPUPlace or fluid.CUDAPlace(n) indicating the
device to hold the graph data.
node_feat: A dict of list of tuples that decribe the details of node node_feat: A dict of list of tuples that decribe the details of node
feature tenosr. Each tuple mush be (name, shape, dtype) feature tenosr. Each tuple mush be (name, shape, dtype)
and the first dimension of the shape must be set unknown and the first dimension of the shape must be set unknown
...@@ -85,19 +82,15 @@ class HeterGraphWrapper(object): ...@@ -85,19 +82,15 @@ class HeterGraphWrapper(object):
node_feat=node_feat, node_feat=node_feat,
edge_feat=edges_feat) edge_feat=edges_feat)
place = fluid.CPUPlace()
gw = heter_graph_wrapper.HeterGraphWrapper( gw = heter_graph_wrapper.HeterGraphWrapper(
name='heter_graph', name='heter_graph',
place = place,
edge_types = g.edge_types_info(), edge_types = g.edge_types_info(),
node_feat=g.node_feat_info(), node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info()) edge_feat=g.edge_feat_info())
""" """
def __init__(self, name, place, edge_types, node_feat={}, edge_feat={}): def __init__(self, name, edge_types, node_feat={}, edge_feat={}, **kwargs):
self.__data_name_prefix = name self.__data_name_prefix = name
self._place = place
self._edge_types = edge_types self._edge_types = edge_types
self._multi_gw = {} self._multi_gw = {}
for edge_type in self._edge_types: for edge_type in self._edge_types:
...@@ -114,7 +107,6 @@ class HeterGraphWrapper(object): ...@@ -114,7 +107,6 @@ class HeterGraphWrapper(object):
self._multi_gw[edge_type] = GraphWrapper( self._multi_gw[edge_type] = GraphWrapper(
name=type_name, name=type_name,
place=self._place,
node_feat=n_feat, node_feat=n_feat,
edge_feat=e_feat) edge_feat=e_feat)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册