提交 707ece69 编写于 作者: Y yelrose

Remove Place in GraphWrapper

上级 1dfe4882
...@@ -65,7 +65,6 @@ def main(args): ...@@ -65,7 +65,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper( gw = pgl.graph_wrapper.GraphWrapper(
name="graph", name="graph",
place=place,
node_feat=dataset.graph.node_feat_info()) node_feat=dataset.graph.node_feat_info())
output = pgl.layers.gcn(gw, output = pgl.layers.gcn(gw,
......
...@@ -170,7 +170,7 @@ def main(args): ...@@ -170,7 +170,7 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper( graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", fluid.CPUPlace(), node_feat=[('feats', [None, 602], np.dtype('float32'))]) "sub_graph", node_feat=[('feats', [None, 602], np.dtype('float32'))])
model_loss, model_acc = build_graph_model( model_loss, model_acc = build_graph_model(
graph_wrapper, graph_wrapper,
num_class=data["num_class"], num_class=data["num_class"],
......
...@@ -204,8 +204,8 @@ def main(args): ...@@ -204,8 +204,8 @@ def main(args):
graph_wrapper = pgl.graph_wrapper.GraphWrapper( graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", "sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info()) node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model( model_loss, model_acc = build_graph_model(
graph_wrapper, graph_wrapper,
num_class=data["num_class"], num_class=data["num_class"],
......
...@@ -231,7 +231,6 @@ def main(args): ...@@ -231,7 +231,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper( graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", "sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info()) node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model( model_loss, model_acc = build_graph_model(
......
...@@ -227,7 +227,6 @@ def main(args): ...@@ -227,7 +227,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper( graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", "sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info()) node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model( model_loss, model_acc = build_graph_model(
......
...@@ -49,7 +49,6 @@ def main(args): ...@@ -49,7 +49,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper( gw = pgl.graph_wrapper.GraphWrapper(
"gw", "gw",
place,
node_feat=[('norm', [None, 1], "float32")], node_feat=[('norm', [None, 1], "float32")],
edge_feat=[('weights', [None, 1], "float32")]) edge_feat=[('weights', [None, 1], "float32")])
......
...@@ -88,7 +88,7 @@ def build_graph_model(args): ...@@ -88,7 +88,7 @@ def build_graph_model(args):
graph_wrappers.append( graph_wrappers.append(
pgl.graph_wrapper.GraphWrapper( pgl.graph_wrapper.GraphWrapper(
"layer_0", fluid.CPUPlace(), node_feat=node_feature_info)) "layer_0", node_feat=node_feature_info))
#edge_feat=[("f", [None, 1], "float32")])) #edge_feat=[("f", [None, 1], "float32")]))
num_embed = args.num_nodes num_embed = args.num_nodes
......
...@@ -516,7 +516,6 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -516,7 +516,6 @@ class GraphWrapper(BaseGraphWrapper):
}) })
graph_wrapper = GraphWrapper(name="graph", graph_wrapper = GraphWrapper(name="graph",
place=place,
node_feat=graph.node_feat_info(), node_feat=graph.node_feat_info(),
edge_feat=graph.edge_feat_info()) edge_feat=graph.edge_feat_info())
...@@ -531,12 +530,11 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -531,12 +530,11 @@ class GraphWrapper(BaseGraphWrapper):
ret = exe.run(fetch_list=[...], feed=feed_dict ) ret = exe.run(fetch_list=[...], feed=feed_dict )
""" """
def __init__(self, name, place, node_feat=[], edge_feat=[]): def __init__(self, name, node_feat=[], edge_feat=[], **kwargs):
super(GraphWrapper, self).__init__() super(GraphWrapper, self).__init__()
# collect holders for PyReader # collect holders for PyReader
self._data_name_prefix = name self._data_name_prefix = name
self._holder_list = [] self._holder_list = []
self._place = place
self.__create_graph_attr_holders() self.__create_graph_attr_holders()
for node_feat_name, node_feat_shape, node_feat_dtype in node_feat: for node_feat_name, node_feat_shape, node_feat_dtype in node_feat:
self.__create_graph_node_feat_holders( self.__create_graph_node_feat_holders(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册