diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index a3cace47febe0429c52fcaa119229ee684f7ae90..ff03af888e78d4547e5ba68d44486de0758c7d67 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -104,6 +104,7 @@ def tf2paddle(model_path,
     # neccesary optimization
     optimizer.delete_redundance_code()
     # optimizer below is experimental
+    optimizer.optimize_elementwise_op()
     optimizer.merge_activation()
     optimizer.merge_bias()
     optimizer.optimize_sub_graph()
diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py
index 71ee7909a7e39b0bff5aa99f60ba504760f197b9..25f11a99c1c446d3c668e53eed5e7b10683cc66f 100644
--- a/x2paddle/decoder/tf_decoder.py
+++ b/x2paddle/decoder/tf_decoder.py
@@ -60,6 +60,15 @@ class TFGraphNode(GraphNode):
             raise Exception("Dtype[{}] not in dtype_map".format(dtype))
         return self.dtype_map[dtype]
 
+    @property
+    def raw_dtype(self):
+        keys = ['dtype', 'Tidx', 'T', 'DstT']
+        for k in keys:
+            dtype = self.layer.attr[k].type
+            if dtype > 0:
+                break
+        return dtype
+
     @property
     def value(self):
         assert self.layer_type == "Const", "Only Const node has value."
@@ -120,6 +129,7 @@ class TFGraph(Graph):
         # tensorflow graph optimize
         self._remove_isolated_node()
         self._remove_identity_node()
+        self._remove_cast_node()
 
     def get_node(self, node_name, copy=False):
         items = node_name.strip().split(':')
@@ -190,6 +200,27 @@ class TFGraph(Graph):
                 idx = self.output_nodes.index(node_name)
                 self.output_nodes[idx] = input_node.layer_name
 
+    def _remove_cast_node(self):
+        cast_node = list()
+        for node_name, node in self.node_map.items():
+            if node.layer_type == "Cast":
+                input = self.get_node(node.inputs[0])
+                if input.layer_type != "Placeholder" or len(input.outputs) != 1:
+                    continue
+                cast_node.append(node_name)
+
+        for node_name in cast_node:
+            node = self.get_node(node_name)
+            input_node = self.get_node(node.inputs[0])
+            input_node.layer.attr["dtype"].type = node.raw_dtype
+            self.remove_node(node_name)
+
+            self.identity_map[node_name] = input_node.layer_name
+
+            if node_name in self.output_nodes:
+                idx = self.output_nodes.index(node_name)
+                self.output_nodes[idx] = input_node.layer_name
+
     def data_format_propagation(self, node):
         current_node = self.node_map[node.layer_name]
         current_node = node.tf_data_format
diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py
index d9e565a115c17d518de9757540825d391dd1feee..847ebc8b96aab160b1747c1bb73a7fe6ecd4dae0 100644
--- a/x2paddle/op_mapper/tf_op_mapper.py
+++ b/x2paddle/op_mapper/tf_op_mapper.py
@@ -170,7 +170,28 @@ class TFOpMapper(OpMapper):
                 x_shape = y.out_shapes[0]
                 y_shape = x.out_shapes[0]
             else:
-                raise Exception("Unexpected situation happend")
+                if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[
+                        0] == y_shape[-1] and y_shape.count(-1) < 1:
+                    shape = [1, x_shape[0], 1, 1]
+                    attr = {"shape": shape}
+                    node.fluid_code.add_layer("reshape",
+                                              inputs=x_input,
+                                              output="reshape_x",
+                                              param_attr=attr)
+                    if y_shape[0] != 1:
+                        attr = {"expand_times": [y_shape[0], 1, 1, 1]}
+                        node.fluid_code.add_layer("expand",
+                                                  inputs="reshape_x",
+                                                  output="reshape_x",
+                                                  param_attr=attr)
+                    inputs = {"x": "reshape_x", "y": y_input}
+                    node.fluid_code.add_layer(op_type,
+                                              inputs=inputs,
+                                              output=node,
+                                              param_attr=None)
+                    return
+                else:
+                    raise Exception("Unexpected situation happend")
 
         if len(x_shape) == 4 and len(y_shape) == 1:
             if x_input.tf_data_format == "NHWC":
diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py
index 0efbc2f7a166dbc74ecbb897077c891812b3a624..99156844a24c619e15d83be2d9345feba73b7e3e 100644
--- a/x2paddle/optimizer/tf_optimizer.py
+++ b/x2paddle/optimizer/tf_optimizer.py
@@ -16,6 +16,7 @@ from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
 from x2paddle.core.fluid_code import Layer
 from x2paddle.core.util import *
+import six
 import numpy
 import copy as cp
 
@@ -104,6 +105,59 @@ class TFOptimizer(object):
                     del out_node.inputs[index]
             del self.graph.node_map[node_name]
 
+    def optimize_elementwise_op(self):
+        elementwise_ops = [
+            'Sub', 'Add', 'RealDiv', 'Maximum', 'Mul', 'FloorDiv',
+            'GreaterEqual'
+        ]
+        revertable_ops = ['Add', 'Mul']
+        for node_name in self.graph.topo_sort:
+            node = self.graph.get_node(node_name)
+            if node is None:
+                continue
+            if node.layer_type in elementwise_ops:
+                if len(node.fluid_code.layers) != 2:
+                    continue
+                if node.fluid_code.layers[0].op != "expand":
+                    continue
+                expand_out = node.fluid_code.layers[0].output
+                expand_in = node.fluid_code.layers[0].inputs
+                expand_times = node.fluid_code.layers[0].param_attr[
+                    "expand_times"]
+
+                x = node.fluid_code.layers[1].inputs["x"]
+                y = node.fluid_code.layers[1].inputs["y"]
+                if isinstance(
+                        x,
+                        six.string_types) and node.layer_type in revertable_ops:
+                    node.fluid_code.layers[1].inputs["y"] = x
+                    node.fluid_code.layers[1].inputs["x"] = y
+                    x = node.fluid_code.layers[1].inputs["x"]
+                    y = expand_in
+                elif isinstance(y, six.string_types):
+                    y = expand_in
+                else:
+                    continue
+
+                x_shape = x.out_shapes[0]
+                y_shape = y.out_shapes[0]
+                if len(x_shape) != len(y_shape):
+                    continue
+                if len(x_shape) == 4:
+                    x_shape = [x_shape[i] for i in [0, 3, 1, 2]]
+                    y_shape = [y_shape[i] for i in [0, 3, 1, 2]]
+
+                continue_flag = True
+                for i in range(len(x_shape)):
+                    if y_shape[-1 * (i + 1)] == 1 and continue_flag:
+                        expand_times[-1 * (i + 1)] = 1
+                    else:
+                        continue_flag = False
+
+                if expand_times.count(1) == len(expand_times):
+                    node.fluid_code.layers[1].inputs["y"] = expand_in
+                    del node.fluid_code.layers[0]
+
     def merge_activation(self):
         act_nodes = list()
         for node_name in self.graph.topo_sort:
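
For orientation, the core of the new optimize_elementwise_op pass is the trailing-axis scan: an explicit expand feeding an elementwise op can be dropped when every axis it tiles is a trailing size-1 axis of y, because elementwise broadcasting already covers those axes. Below is a minimal standalone sketch of that check; the helper name expand_is_redundant and the example shapes are illustrative only, not part of the patch.

# Illustrative re-statement of the trailing-axis scan above
# (hypothetical helper, not part of the patch).
def expand_is_redundant(y_shape, expand_times):
    times = list(expand_times)
    trailing = True
    # Walk axes from last to first; a tile factor on a trailing
    # size-1 axis is covered by implicit broadcasting, so it can
    # be treated as 1. The first non-1 axis ends the trailing run.
    for i in range(1, len(y_shape) + 1):
        if y_shape[-i] == 1 and trailing:
            times[-i] = 1
        else:
            trailing = False
    # If every remaining factor is 1, the expand layer adds nothing.
    return times.count(1) == len(times)

# y of shape [1, 64, 1, 1] tiled to [1, 64, 224, 224]: only trailing
# 1s are tiled, so the expand is redundant.
assert expand_is_redundant([1, 64, 1, 1], [1, 1, 224, 224])
# Tiling the leading batch axis is not a trailing-1 case: keep it.
assert not expand_is_redundant([1, 64, 1, 1], [8, 1, 224, 224])

Note that the actual pass first permutes 4-D shapes from NHWC to NCHW before scanning, so the sketch assumes shapes are already in the layout used by the emitted fluid code.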