From 3cf5a607a5a1beb8d2d359b7d031b8a0b2bdffab Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Wed, 17 Jul 2019 12:03:13 +0800 Subject: [PATCH] {add codes --- x2paddle/core/fluid_code.py | 31 +++++++++++++++++++++++++++++-- x2paddle/core/graph.py | 17 +++++++++++++---- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/x2paddle/core/fluid_code.py b/x2paddle/core/fluid_code.py index 72cfd1a..65b4a18 100644 --- a/x2paddle/core/fluid_code.py +++ b/x2paddle/core/fluid_code.py @@ -32,8 +32,35 @@ class Layer(object): layer_code = layer_code + "fluid.layers." + self.op + "(" - for key, tensor in self.inputs.items(): - layer_code = layer_code + key + "={}, ".format(tensor) + if isinstance(self.inputs, list): + in_list = "[" + for input in self.inputs: + assert isinstance( + input, GraphNode), "Type of input should be GraphNode" + if hasattr(input, "index"): + in_list += (input.layer_name + "[{}]".format(input.index) + + ", ") + else: + in_list += (input.layer_name + ", ") + inlist = in_list.strip(", ") + "], " + elif isinstance(self.inputs, dict): + for key, input in self.inputs.items(): + assert isinstance( + input, GraphNode), "Type of input should be GraphNode" + if hasattr(input, "index"): + layer_code = layer_code + key + "={}, ".format( + input.layer_name + "[{}]".format(input.index)) + else: + layer_code = layer_code + key + "={}, ".format( + input.layer_name) + elif isinstance(self.inputs, GraphNode): + if hasattr(self.inputs, "index"): + layer_code += (self.inputs.layer_name + + "[{}]".format(self.inputs.index) + ", ") + else: + layer_code += (self.inputs.layer_name + ", ") + else: + raise Exception("Unknown type of inputs.") for key, value in self.param_attr.items(): layer_code = layer_code + key + "={}, ".format(value) diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index d4fec45..86aefd2 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +from copy import deepcopy class GraphNode(object): @@ -72,16 +73,24 @@ class Graph(object): self.topo_sort.append(node) idx += 1 - def get_node(self, name): + def get_node(self, name, copy=False): if name not in self.node_map: if name.split(':')[0] in self.node_map: name_prefix, idx = name.split(':') - self.node_map[name_prefix].index = int(idx) - return self.node_map[name_prefix] + if copy: + node = deepcopy(self.node_map[name_prefix]) + else: + node = self.node_map[name_prefix] + node.index = int(idx) + return node else: raise Exception("Graph doesn't have node [%s]." % name) else: - return self.node_map[name] + if copy: + node = deepcopy(self.node_map[name]) + else: + node = self.node_map[name] + return node def connect(self, src, dst): if dst not in self.node_map: -- GitLab