From d248ff64cf91261a569acc6623e51c4aab377db1 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Thu, 1 Aug 2019 11:17:47 +0800 Subject: [PATCH] fix code --- x2paddle/core/fluid_code.py | 44 +++++--- x2paddle/decoder/tf_decoder.py | 5 +- x2paddle/op_mapper/tf_op_mapper.py | 169 +++++++++-------------------- 3 files changed, 82 insertions(+), 136 deletions(-) diff --git a/x2paddle/core/fluid_code.py b/x2paddle/core/fluid_code.py index 66cfb93..ba2f57a 100644 --- a/x2paddle/core/fluid_code.py +++ b/x2paddle/core/fluid_code.py @@ -13,6 +13,7 @@ # limitations under the License. from x2paddle.core.graph import GraphNode +import collections class Layer(object): @@ -36,25 +37,34 @@ class Layer(object): if isinstance(self.inputs, list): in_list = "[" for input in self.inputs: - assert isinstance( - input, GraphNode), "Type of input should be GraphNode" - if hasattr(input, "index"): - in_list += (input.layer_name + "[{}]".format(input.index) + - ", ") + if isinstance(input, GraphNode): + if hasattr(input, "index"): + in_list += (input.layer_name + + "[{}]".format(input.index) + ", ") + else: + in_list += (input.layer_name + ", ") + elif isinstance(input, str): + in_list += (input + ", ") else: - in_list += (input.layer_name + ", ") + raise Exception( + "Element of inputs should GraphNode or String") in_list = in_list.strip(", ") + "], " layer_code += in_list elif isinstance(self.inputs, dict): - for key, input in self.inputs.items(): - assert isinstance( - input, GraphNode), "Type of input should be GraphNode" - if hasattr(input, "index"): - layer_code = layer_code + key + "={}, ".format( - input.layer_name + "[{}]".format(input.index)) + inputs = collections.OrderedDict(self.inputs) + for key, input in inputs.items(): + if isinstance(input, GraphNode): + if hasattr(input, "index"): + layer_code = layer_code + key + "={}, ".format( + input.layer_name + "[{}]".format(input.index)) + else: + layer_code = layer_code + key + "={}, ".format( + input.layer_name) + elif isinstance(input, str): + layer_code = layer_code + key + "={}, ".format(input) else: - layer_code = layer_code + key + "={}, ".format( - input.layer_name) + raise Exception( + "Element of inputs should GraphNode or String") elif isinstance(self.inputs, GraphNode): if hasattr(self.inputs, "index"): layer_code += (self.inputs.layer_name + @@ -66,7 +76,8 @@ class Layer(object): else: raise Exception("Unknown type of inputs.") - for key, value in self.param_attr.items(): + param_attr = collections.OrderedDict(self.param_attr) + for key, value in param_attr.items(): layer_code = layer_code + key + "={}, ".format(value) layer_code = layer_code.strip(", ") @@ -97,7 +108,8 @@ class Layer(object): else: raise Exception("Unknown type of inputs.") - for key, value in self.param_attr.items(): + param_attr = collections.OrderedDict(self.param_attr) + for key, value in param_attr.items(): layer_code = layer_code + key + "={}, ".format(value) layer_code = layer_code.strip(", ") diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index 898ea3b..fe64414 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -176,8 +176,9 @@ class TFDecoder(object): self.sess.graph.as_default() tf.import_graph_def(graph_def, name='', input_map=input_map) - for node in graph_def.node: - print(node.name, node.op, node.input) + +# for node in graph_def.node: +# print(node.name, node.op, node.input) self.sess.run(tf.global_variables_initializer()) diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index ad23138..f0a886f 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -19,6 +19,31 @@ import numpy class TFOpMapper(OpMapper): + + directly_map_ops = { + 'Relu': ['relu'], + 'Relu6': ['relu6'], + 'Shape': ['shape'], + 'Abs': ['abs'], + 'Sigmoid': ['sigmoid'], + 'Exp': ['exp'], + 'Rsqrt': ['rsqrt'], + 'Squeeze': ['squeeze', { + 'squeeze_dims': 'axes' + }], + 'Softmax': ['softmax', { + 'axis': 'axis' + }], + } + elementwise_ops = { + 'Add': 'elementwise_add', + 'RealDiv': 'elementwise_div', + 'BiasAdd': 'elementwise_add', + 'Sub': 'elementwise_sub', + 'Maximum': 'elementwise_max', + 'Mul': 'elementwise_mul' + } + def __init__(self, decoder): super(TFOpMapper, self).__init__() self.decoder = decoder @@ -30,15 +55,20 @@ class TFOpMapper(OpMapper): print("Total nodes: {}".format(len(self.graph.topo_sort))) # check if ops in model are all supported - if not self.op_checker(): - raise Exception("Model are not supported yet.") + # TODO for node_name in self.graph.topo_sort: node = self.graph.get_node(node_name) op = node.layer_type - if hasattr(self, op): + if op in self.directly_map_ops: + self.directly_map(node) + elif op in self.elementwise_ops: + self.elementwise_map(node) + elif hasattr(self, op): func = getattr(self, op) func(node) + else: + raise Exception("OP: [{}] not support yet".format(op)) for i in range(len(self.graph.topo_sort)): node_name = self.graph.topo_sort[i] @@ -47,7 +77,24 @@ class TFOpMapper(OpMapper): node = self.graph.get_node(node_name) self.net_code += node.fluid_code.gen_codes() - def elementwise_operator(self, node, op_type): + def directly_map(self, node): + assert node.layer_type in self.directly_map_ops + op_info = self.directly_map_ops[node.layer_type] + input = self.graph.get_node(node.layer.input[0], copy=True) + attr = dict() + for param in op_info[1:]: + tf_param_name = list(param.keys())[0] + pd_param_name = list(param.values())[0] + tf_param = node.get_attr(tf_param_name) + attr[pd_param_name] = tf_param + node.fluid_code.add_layer(op_info[0], + inputs=input, + output=node, + param_attr=attr) + + def elementwise_map(self, node): + assert node.layer_type in self.elementwise_ops + op_type = self.elementwise_ops[node.layer_type] x = self.graph.get_node(node.layer.input[0], copy=True) y = self.graph.get_node(node.layer.input[1], copy=True) x_shape = x.out_shapes[0] @@ -161,41 +208,6 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) - def RealDiv(self, node): - self.elementwise_operator(node, "elementwise_div") - - def Relu(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("relu", - inputs=input, - output=node, - param_attr=None) - - def Squeeze(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - squeeze_dims = node.get_attr('squeeze_dims') - attr = {'axes': squeeze_dims} - node.fluid_code.add_layer("squeeze", - inputs=input, - output=node, - param_attr=attr) - - def BiasAdd(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - bias = self.graph.get_node(node.layer.input[1], copy=True) - inputs = {'x': input, 'y': bias} - node.fluid_code.add_layer("elementwise_add", - inputs=inputs, - output=node, - param_attr=None) - - def Identity(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("assign", - inputs=input, - output=node, - param_attr=None) - def MaxPool(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) @@ -314,13 +326,6 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) - def Relu6(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("relu6", - inputs=input, - output=node, - param_attr=None) - def FusedBatchNorm(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) gamma = self.graph.get_node(node.layer.input[1], copy=True) @@ -433,13 +438,6 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) - def Shape(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("shape", - inputs=input, - output=node, - param_attr=None) - def Reshape(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) param = self.graph.get_node(node.layer.input[1], copy=True) @@ -474,27 +472,6 @@ class TFOpMapper(OpMapper): inputs=input, output=node, param_attr=attr) - # temporary shape inference fix - - -# if param.layer_type == "Pack": -# shape_slices = list() -# for i in range(len(param.layer.input)): -# slice = self.graph.get_node(param.layer.input[i], copy=True) -# if slice.layer_type == "Const": -# shape_slices.append(slice.value.tolist()) -# else: -# shape_slices.append(0) -# if shape_slices.count(-1) == 0: -# shape_slices[shape_slices.index(0)] = -1 -# attr = {"shape": shape_slices} -# node.fluid_code.add_layer("reshape", -# inputs=node, -# output=node, -# param_attr=attr) - - def Add(self, node): - self.elementwise_operator(node, "elementwise_add") def AvgPool(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) @@ -542,23 +519,6 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) - def Softmax(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("softmax", - inputs=input, - output=node, - param_attr=None) - - def Sigmoid(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("sigmoid", - inputs=input, - output=node, - param_attr=None) - - def Maximum(self, node): - self.elementwise_operator(node, "elementwise_max") - def SplitV(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) num_sections = self.graph.get_node(node.layer.input[1], copy=True) @@ -576,13 +536,6 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) - def Exp(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("exp", - inputs=input, - output=node, - param_attr=None) - def ConcatV2(self, node): inputs = [ self.graph.get_node(name, copy=True) @@ -649,19 +602,6 @@ class TFOpMapper(OpMapper): output=node, param_attr=None) - def Mul(self, node): - self.elementwise_operator(node, "elementwise_mul") - - def Sub(self, node): - self.elementwise_operator(node, "elementwise_sub") - - def Rsqrt(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("rsqrt", - inputs=input, - output=node, - param_attr=None) - def swish_f32(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) node.fluid_code.add_layer("sigmoid", @@ -765,13 +705,6 @@ class TFOpMapper(OpMapper): output=node, param_attr=attr) - def Abs(self, node): - input = self.graph.get_node(node.layer.input[0], copy=True) - node.fluid_code.add_layer("abs", - inputs=input, - output=node, - param_attr=None) - def Conv2DBackpropInput(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) kernel = self.graph.get_node(node.layer.input[1], copy=True) -- GitLab