提交 bbe1b9d8 编写于 作者: Y yelrose

Remove Place in GraphWrapper

上级 1dfe4882
......@@ -53,7 +53,6 @@ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
# use GraphWrapper as a container for graph data to construct a graph neural network
gw = pgl.graph_wrapper.GraphWrapper(name='graph',
place = place,
node_feat=g.node_feat_info())
```
......
......@@ -77,7 +77,6 @@ place = fluid.CPUPlace()
# create a GraphWrapper as a container for graph data
gw = heter_graph_wrapper.HeterGraphWrapper(name='heter_graph',
place = place,
edge_types = g.edge_types_info(),
node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info())
......
......@@ -53,7 +53,6 @@ class GATNE(object):
self.gw = heter_graph_wrapper.HeterGraphWrapper(
name="heter_graph",
place=place,
edge_types=self.graph.edge_types_info(),
node_feat=self.graph.node_feat_info(),
edge_feat=self.graph.edge_feat_info())
......
......@@ -65,7 +65,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper(
name="graph",
place=place,
node_feat=dataset.graph.node_feat_info())
output = pgl.layers.gcn(gw,
......
......@@ -170,7 +170,7 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph", fluid.CPUPlace(), node_feat=[('feats', [None, 602], np.dtype('float32'))])
"sub_graph", node_feat=[('feats', [None, 602], np.dtype('float32'))])
model_loss, model_acc = build_graph_model(
graph_wrapper,
num_class=data["num_class"],
......
......@@ -204,8 +204,8 @@ def main(args):
graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model(
graph_wrapper,
num_class=data["num_class"],
......
......@@ -231,7 +231,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model(
......
......@@ -227,7 +227,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
graph_wrapper = pgl.graph_wrapper.GraphWrapper(
"sub_graph",
fluid.CPUPlace(),
node_feat=data['graph'].node_feat_info())
model_loss, model_acc = build_graph_model(
......
......@@ -49,7 +49,6 @@ def main(args):
with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper(
"gw",
place,
node_feat=[('norm', [None, 1], "float32")],
edge_feat=[('weights', [None, 1], "float32")])
......
......@@ -88,7 +88,7 @@ def build_graph_model(args):
graph_wrappers.append(
pgl.graph_wrapper.GraphWrapper(
"layer_0", fluid.CPUPlace(), node_feat=node_feature_info))
"layer_0", node_feat=node_feature_info))
#edge_feat=[("f", [None, 1], "float32")]))
num_embed = args.num_nodes
......
......@@ -148,7 +148,6 @@ def main():
with fluid.program_guard(train_program, startup_program):
gw = pgl.graph_wrapper.GraphWrapper(
"graph",
place=place,
node_feat=graph_data.node_feat_info(),
edge_feat=graph_data.edge_feat_info())
pred = model.forward(gw)
......
......@@ -158,7 +158,6 @@ def main():
num_layers=args.num_layers)
gw = pgl.graph_wrapper.GraphWrapper(
"graph",
place,
node_feat=graph_data.node_feat_info(),
edge_feat=graph_data.edge_feat_info())
pred, prob, loss = model.forward(gw)
......
......@@ -475,9 +475,6 @@ class GraphWrapper(BaseGraphWrapper):
Args:
name: The graph data prefix
place: fluid.CPUPlace or fluid.CUDAPlace(n) indicating the
device to hold the graph data.
node_feat: A list of tuples that decribe the details of node
feature tenosr. Each tuple mush be (name, shape, dtype)
and the first dimension of the shape must be set unknown
......@@ -516,7 +513,6 @@ class GraphWrapper(BaseGraphWrapper):
})
graph_wrapper = GraphWrapper(name="graph",
place=place,
node_feat=graph.node_feat_info(),
edge_feat=graph.edge_feat_info())
......@@ -531,12 +527,11 @@ class GraphWrapper(BaseGraphWrapper):
ret = exe.run(fetch_list=[...], feed=feed_dict )
"""
def __init__(self, name, place, node_feat=[], edge_feat=[]):
def __init__(self, name, node_feat=[], edge_feat=[], **kwargs):
super(GraphWrapper, self).__init__()
# collect holders for PyReader
self._data_name_prefix = name
self._holder_list = []
self._place = place
self.__create_graph_attr_holders()
for node_feat_name, node_feat_shape, node_feat_dtype in node_feat:
self.__create_graph_node_feat_holders(
......
......@@ -44,9 +44,6 @@ class HeterGraphWrapper(object):
Args:
name: The heterogeneous graph data prefix
place: fluid.CPUPlace or fluid.CUDAPlace(n) indicating the
device to hold the graph data.
node_feat: A dict of list of tuples that decribe the details of node
feature tenosr. Each tuple mush be (name, shape, dtype)
and the first dimension of the shape must be set unknown
......@@ -85,19 +82,15 @@ class HeterGraphWrapper(object):
node_feat=node_feat,
edge_feat=edges_feat)
place = fluid.CPUPlace()
gw = heter_graph_wrapper.HeterGraphWrapper(
name='heter_graph',
place = place,
edge_types = g.edge_types_info(),
node_feat=g.node_feat_info(),
edge_feat=g.edge_feat_info())
"""
def __init__(self, name, place, edge_types, node_feat={}, edge_feat={}):
def __init__(self, name, edge_types, node_feat={}, edge_feat={}, **kwargs):
self.__data_name_prefix = name
self._place = place
self._edge_types = edge_types
self._multi_gw = {}
for edge_type in self._edge_types:
......@@ -114,7 +107,6 @@ class HeterGraphWrapper(object):
self._multi_gw[edge_type] = GraphWrapper(
name=type_name,
place=self._place,
node_feat=n_feat,
edge_feat=e_feat)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册