diff --git a/x2paddle/core/fluid_code.py b/x2paddle/core/fluid_code.py index 5424abbd3ac4c8fad0f5d8487c7b640551df4d82..0d428f823d418ae64889cde9eb36f5d334a242f8 100644 --- a/x2paddle/core/fluid_code.py +++ b/x2paddle/core/fluid_code.py @@ -17,18 +17,46 @@ class Layer(object): def __init__(self): self.op = None self.param_attr = dict() - self.input = None + self.inputs = dict() self.output = None - self.str_code = None def get_code(self): - if self.str_code is not None: - return self.str_code + layer_code = "" + if self.output is not None: + layer_code = self.output + " = " + + layer_code = layer_code + "fluid.layers." + self.op + "(" + + for key, tensor in self.inputs.items(): + layer_code = layer_code + key + "=" + tensor + ", " + + for key, value in self.param_attr.items(): + layer_code = layer_code + key + "=" + value + ", " + layer_code = layer_code.strip(", ") + return layer_code += ")" class FluidCode(object): def __init__(self): - self.codes = list() + self.layers = list() - def add_layer(self, op, input, output, param_attr=None): - + def add_layer(self, op, inputs, output, param_attr=None): + layer = Layer() + layer.op = op + layer.inputs = inputs + layer.output = output + if param_attr is not None: + layer.param_attr = param_attr + self.layers.append(layer) + + def add_note(self, note): + # note should be string + self.layers.append(note) + + def gen_codes(self): + codes = list() + for layer in self.layers: + if isinstance(layer, Layer): + codes.append(layer.get_code()) + elif isinstance(layer, str): + codes.append(layer) diff --git a/x2paddle/core/graph.py b/x2paddle/core/graph.py index 3f28e90dd5124935f781dcfba7f20fa4fc4a8d66..21b9bef83e96f270951d9dd3f13315abd681b1b3 100644 --- a/x2paddle/core/graph.py +++ b/x2paddle/core/graph.py @@ -62,14 +62,16 @@ class Graph(object): num_inputs[name] = len(node.inputs) self.topo_sort = self.input_nodes[:] - for idx in range(len(self.topo_sort)): + while idx < len(self.topo_sort): current_node = self.node_map[self.topo_sort[idx]] for node in current_node.outputs: num_inputs[node.layer_name] -= 1 if num_inputs[node.layer_name] == 0: self.topo_sort.append(node.layer_name) + idx += 1 + for i, tmp in enumerate(self.topo_sort): - print(tmp) + print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs) def get_node(self, name): if name not in self.node_map: diff --git a/x2paddle/emitter/tf_emitter.py b/x2paddle/emitter/tf_emitter.py index 3968cd112489ad584f07fd624aec8f38d7e6a2ef..0e343aed57b9ec412675ed4dc7fc8b000a03fc13 100644 --- a/x2paddle/emitter/tf_emitter.py +++ b/x2paddle/emitter/tf_emitter.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from x2paddle.parser import TFGraph - +from x2paddle.parser.tf_parser import TFGraph +from x2paddle.core.emitter import Emitter +class TFEmitter(Emitter): + def __init__(self): + super(TFEmitter, self diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py index 0d45a6200dede038fc9d76d411a8e24ce7c73237..9d2775ce239e8d6516410409f2cc5a15e0a6fec9 100644 --- a/x2paddle/optimizer/tf_optimizer.py +++ b/x2paddle/optimizer/tf_optimizer.py @@ -13,6 +13,18 @@ # limitations under the License. # TODO useless node remove +from x2paddle.parser.tf_parser import TFGraph + + +class TFGraphOptimizer(object): + def __init__(self): + print("Not Implement") + self.useless_op = [ + 'NoOp'] + + def remove_useless_node(self, graph): + for name, node in graph.node_map.items(): + if node.layer_type in self.useless_op: # TODO identity node remove