diff --git a/AUTHORS.md b/AUTHORS.md
index 4a904a51ac5eeacd8a7f08d624e0fd683c457584..8932ba2268212d34b700807047e391a2e887f267 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -7,4 +7,3 @@
 | Macrobull | Nai-Rui Luo |
 | Channingss | Ling-Chi Chen |
 | mamingjie-China | Ming-Jie Ma |
-
diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index 84bafc0f0467944406343f565076a16e5bb57e1f..e22f69ba6fcf556df2da2c087edde091d07789d4 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -78,6 +78,8 @@ def tf2paddle(model_path,
               define_input_shape=False):
     # check tensorflow installation and version
     try:
+        import os
+        os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
         import tensorflow as tf
         version = tf.__version__
         if version >= '2.0.0' or version < '1.0.0':
@@ -109,6 +111,9 @@ def tf2paddle(model_path,
     optimizer = TFOptimizer(mapper)
     optimizer.delete_redundance_code()
     optimizer.strip_graph()
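+    # new optimizer passes (defined in tf_optimizer.py below): fold activations
+    # and bias adds into the preceding layers, then drop cancelling transposes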
+    optimizer.merge_activation()
+    optimizer.merge_bias()
+    optimizer.remove_transpose()
     mapper.save_inference_model(save_dir)
diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py
index c337df4ccc58318c53a5d7bb6b7d3e593d3c9c55..4115013548519975c377a4072d032b4d7e98afbb 100644
--- a/x2paddle/decoder/tf_decoder.py
+++ b/x2paddle/decoder/tf_decoder.py
@@ -15,7 +15,6 @@
 from x2paddle.core.graph import GraphNode, Graph
 from x2paddle.core.fluid_code import FluidCode
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.platform import gfile
 from tensorflow.core.framework import attr_value_pb2
 import tensorflow as tf
 import copy as cp
@@ -140,7 +139,7 @@ class TFGraph(Graph):
             raise Exception("Node[{}] not in graph".format(node_name))
         inputs = self.node_map[node_name].inputs
         outputs = self.node_map[node_name].outputs
-        assert len(inputs) == 1
+        # assert len(inputs) == 1
         input_node = self.node_map[inputs[0]]
         idx = input_node.outputs.index(node_name)
         del input_node.outputs[idx]
@@ -205,18 +204,28 @@ class TFGraph(Graph):
 
 class TFDecoder(object):
     def __init__(self, pb_model, data_format="NHWC", define_input_shape=False):
-        self.sess = tf.Session()
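+        # TF 2.x exposes the 1.x graph API under tf.compat.v1; fall back for TF 1.x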
".format(i + 1)) node = self.graph.get_node(node_name) op = node.layer_type if op in self.directly_map_ops: @@ -107,11 +110,13 @@ class TFOpMapper(OpMapper): else: unsupported_ops.add(op) if len(unsupported_ops) > 0: - print("=========={} Ops are not supported yet======".format( - len(unsupported_ops))) + sys.stderr.write( + "=========={} Ops are not supported yet======\n".format( + len(unsupported_ops))) for op in unsupported_ops: - print("========== {} ==========".format(op)) + sys.stderr.write("========== {} ==========\n".format(op)) sys.exit(-1) + sys.stderr.write('\nDone!\n') def add_omit_nodes(self, in_node_name, out_node_name): in_node = self.graph.get_node(in_node_name) @@ -144,6 +149,10 @@ class TFOpMapper(OpMapper): y = self.graph.get_node(node.layer.input[1], copy=True) x_shape = x.out_shapes[0] y_shape = y.out_shapes[0] + if len(x_shape) == 0: + x_shape = [1] + if len(y_shape) == 0: + y_shape = [1] # incomplement broadcasting support for paddle x_input = x y_input = y @@ -237,6 +246,9 @@ class TFOpMapper(OpMapper): 'name': string(node.layer_name), 'append_batch_size': False } + if shape[0] < 0: + self.batch_node = node + node.fluid_code.add_layer("data", inputs=None, output=node, @@ -285,17 +297,28 @@ class TFOpMapper(OpMapper): perm = perm.value.tolist() if perm == [0, 3, 1, 2] and input.data_format == "NHWC": - node.fluid_code.add_layer("assign", - inputs=input, - output=node, - param_attr=None) + # node.fluid_code.add_layer("assign", + # inputs=input, + # output=node, + # param_attr=None) + input_name = input.layer_name + if hasattr(input, "index"): + input_name = input_name + "[{}]".format(input.index) + node.fluid_code.add_layer("{} = {}").format(node.layer_name, + input_name) node.tf_data_format = "NCHW" self.graph.data_format_propagation(node) elif perm == [0, 2, 3, 1] and input.tf_data_format == "NCHW": - node.fluid_code.add_layer("assign", - inputs=input, - output=node, - param_attr=None) + input_name = input.layer_name + if hasattr(input, "index"): + input_name = input_name + "[{}]".format(input.index) + node.fluid_code.add_layer("{} = {}").format(node.layer_name, + input_name) + # + # node.fluid_code.add_layer("assign", + # inputs=input, + # output=node, + # param_attr=None) node.tf_data_format = "NHWC" self.graph.data_format_propagation(node) elif len(input.out_shapes[0]) > 4: @@ -564,6 +587,20 @@ class TFOpMapper(OpMapper): new_param += (node.layer_name + "[{}]".format(i) + ", ") new_param = new_param.strip(", ") + "]" attr = {"shape": new_param} + + if len(input.out_shapes[0]) == 4 and node.tf_data_format == "NHWC": + if len(attr["shape"]) < 3: + perm = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=perm) + node.fluid_code.add_layer("reshape", + inputs=node, + output=node, + param_attr=attr) + return + if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC": input_shape = self.decoder.infer_tensor(input).shape if input_shape[1] == attr["shape"][1]: @@ -860,17 +897,32 @@ class TFOpMapper(OpMapper): size = [size[i] for i in [0, 3, 1, 2]] begin = [begin[i] for i in [0, 3, 1, 2]] - attr = {"shape": size, "offsets": begin} - node.fluid_code.add_layer("crop", + for i in range(len(size)): + if size[i] < 0: + size[i] = 99999999 + else: + size[i] = size[i] + begin[i] + + attr = { + "axes": [i for i in range(len(size))], + "starts": begin, + "ends": size + } + node.fluid_code.add_layer("slice", inputs=input, output=node, param_attr=attr) def Conv2DBackpropInput(self, node): - input = 
+        self.batch_node = None
         self.weights = dict()
         self.omit_nodes = list()
         self.used_custom_layers = dict()
@@ -86,9 +88,10 @@
                 idx = self.graph.input_nodes.index(name)
                 del self.graph.input_nodes[idx]
 
-        print("Total nodes: {}".format(len(self.graph.topo_sort)))
+        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
         unsupported_ops = set()
-        for node_name in self.graph.topo_sort:
+        for i, node_name in enumerate(self.graph.topo_sort):
+            sys.stderr.write("\rConverting node {} ... ".format(i + 1))
             node = self.graph.get_node(node_name)
             op = node.layer_type
             if op in self.directly_map_ops:
@@ -107,11 +110,13 @@
             else:
                 unsupported_ops.add(op)
         if len(unsupported_ops) > 0:
-            print("=========={} Ops are not supported yet======".format(
-                len(unsupported_ops)))
+            sys.stderr.write(
+                "=========={} Ops are not supported yet======\n".format(
+                    len(unsupported_ops)))
             for op in unsupported_ops:
-                print("========== {} ==========".format(op))
+                sys.stderr.write("========== {} ==========\n".format(op))
             sys.exit(-1)
+        sys.stderr.write('\nDone!\n')
 
     def add_omit_nodes(self, in_node_name, out_node_name):
         in_node = self.graph.get_node(in_node_name)
@@ -144,6 +149,10 @@ class TFOpMapper(OpMapper):
         y = self.graph.get_node(node.layer.input[1], copy=True)
         x_shape = x.out_shapes[0]
         y_shape = y.out_shapes[0]
+        if len(x_shape) == 0:
+            x_shape = [1]
+        if len(y_shape) == 0:
+            y_shape = [1]
         # incomplement broadcasting support for paddle
         x_input = x
         y_input = y
@@ -237,6 +246,9 @@
             'name': string(node.layer_name),
             'append_batch_size': False
         }
+        if shape[0] < 0:
+            self.batch_node = node
+
         node.fluid_code.add_layer("data",
                                   inputs=None,
                                   output=node,
@@ -285,17 +297,28 @@
             perm = perm.value.tolist()
 
         if perm == [0, 3, 1, 2] and input.data_format == "NHWC":
-            node.fluid_code.add_layer("assign",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=None)
+            input_name = input.layer_name
+            if hasattr(input, "index"):
+                input_name = input_name + "[{}]".format(input.index)
+            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
+                                                       input_name))
             node.tf_data_format = "NCHW"
             self.graph.data_format_propagation(node)
         elif perm == [0, 2, 3, 1] and input.tf_data_format == "NCHW":
-            node.fluid_code.add_layer("assign",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=None)
+            input_name = input.layer_name
+            if hasattr(input, "index"):
+                input_name = input_name + "[{}]".format(input.index)
+            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
+                                                       input_name))
             node.tf_data_format = "NHWC"
             self.graph.data_format_propagation(node)
         elif len(input.out_shapes[0]) > 4:
@@ -564,6 +587,20 @@
                 new_param += (node.layer_name + "[{}]".format(i) + ", ")
             new_param = new_param.strip(", ") + "]"
             attr = {"shape": new_param}
+
+        if len(input.out_shapes[0]) == 4 and node.tf_data_format == "NHWC":
+            if len(attr["shape"]) < 3:
+                perm = {"perm": [0, 2, 3, 1]}
+                node.fluid_code.add_layer("transpose",
+                                          inputs=input,
+                                          output=node,
+                                          param_attr=perm)
+                node.fluid_code.add_layer("reshape",
+                                          inputs=node,
+                                          output=node,
+                                          param_attr=attr)
+                return
+
         if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC":
             input_shape = self.decoder.infer_tensor(input).shape
             if input_shape[1] == attr["shape"][1]:
@@ -860,17 +897,32 @@
             size = [size[i] for i in [0, 3, 1, 2]]
             begin = [begin[i] for i in [0, 3, 1, 2]]
 
-        attr = {"shape": size, "offsets": begin}
-        node.fluid_code.add_layer("crop",
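+        # convert TF's (begin, size) slice arguments to fluid's (starts, ends);
+        # a negative size means "to the end of the axis", approximated here by
+        # an end index large enough to be clamped to the dimension size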
+        for i in range(len(size)):
+            if size[i] < 0:
+                size[i] = 99999999
+            else:
+                size[i] = size[i] + begin[i]
+
+        attr = {
+            "axes": [i for i in range(len(size))],
+            "starts": begin,
+            "ends": size
+        }
+        node.fluid_code.add_layer("slice",
                                   inputs=input,
                                   output=node,
                                   param_attr=attr)
 
     def Conv2DBackpropInput(self, node):
-        input = self.graph.get_node(node.layer.input[0], copy=True)
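+        # inputs of TF's Conv2DBackpropInput are (output_shape, filter, input)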
+        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
+        input = self.graph.get_node(node.layer.input[2], copy=True)
+        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
+        self.add_omit_nodes(kernel.layer_name, node.layer_name)
+        self.add_omit_nodes(out_shape.layer_name, node.layer_name)
+
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
             in_shape = self.decoder.infer_tensor(input).shape
@@ -878,14 +930,14 @@
         if k_size.count(-1) > 2:
             k_size = self.decoder.infer_tensor(kernel).shape
 
+        pad_mode = node.get_attr("padding").decode()
         strides = node.get_attr("strides")
         dilations = node.get_attr("dilations")
         data_format = node.get_attr("data_format").decode()
-        pad_mode = node.get_attr("padding").decode()
         channel_first = data_format == "NCHW"
+
         self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
             kernel.value, (3, 2, 0, 1))
-
         if not channel_first:
             in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
             strides = [strides[i] for i in [0, 3, 1, 2]]
@@ -906,6 +958,7 @@
                                       output=node,
                                       param_attr=attr)
             input = node
+
         attr = {
             "bias_attr": False,
             "param_attr": string(kernel.layer_name),
@@ -915,11 +968,10 @@
             "dilation": dilations[2:4],
             "padding": padding
         }
-        node.fluid_code.add_layer(
-            "conv2d_transpose",
-            inputs=input if channel_first and pad_mode != "SAME" else node,
-            output=node,
-            param_attr=attr)
+        node.fluid_code.add_layer("conv2d_transpose",
+                                  inputs=input,
+                                  output=node,
+                                  param_attr=attr)
 
     def Max(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -960,18 +1012,19 @@
                                   output=node,
                                   param_attr=attr)
 
-    def FloorDiv(self, node):
-        x = self.graph.get_node(node.layer.input[0], copy=True)
-        y = self.graph.get_node(node.layer.input[1], copy=True)
-        inputs = {'x': x, 'y': y}
-        node.fluid_code.add_layer("elementwise_div",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=None)
-        node.fluid_code.add_layer("floor",
-                                  inputs=node,
-                                  output=node,
-                                  param_attr=None)
 
     def Split(self, node):
         dim = self.graph.get_node(node.layer.input[0], copy=True)
@@ -1082,3 +1135,34 @@
                                   inputs=input,
                                   output=node,
                                   param_attr=attr)
+
+    def GreaterEqual(self, node):
+        x = self.graph.get_node(node.layer.input[0], copy=True)
+        y = self.graph.get_node(node.layer.input[1], copy=True)
+        inputs = {"x": x, "y": y}
+        node.fluid_code.add_layer("greater_equal",
+                                  inputs=inputs,
+                                  output=node,
+                                  param_attr=None)
+
+    def RandomUniform(self, node):
+        shape = self.graph.get_node(node.layer.input[0], copy=True)
+        self.add_omit_nodes(shape.layer_name, node.layer_name)
+        if shape.layer_type == "Const":
+            shape = shape.value.tolist()
+        else:
+            shape = self.decoder.infer_shape_tensor(shape)
+        if node.tf_data_format == "NHWC" and len(shape) == 4:
+            shape = [shape[i] for i in [0, 3, 1, 2]]
+        attr = {"shape": shape, "min": 0.0, "max": 0.9999}
+        if shape[0] < 0:
+            input = self.batch_node
+            node.fluid_code.add_layer("uniform_random_batch_size_like",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=attr)
+        else:
+            node.fluid_code.add_layer("uniform_random",
+                                      inputs=None,
+                                      output=node,
+                                      param_attr=attr)
diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py
index 8d5be31fc9d11fdbbd6e2dfb968707f0f48129a4..eccbb44d8b2c45401bcdff48323d4c5f3283c9b6 100644
--- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py
+++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py
@@ -56,6 +56,7 @@ class TFOpMapperNHWC(OpMapper):
         self.decoder = decoder
         self.graph = decoder.tf_graph
         self.weights = dict()
+        self.batch_node = None
         self.omit_nodes = list()
         self.used_custom_layers = dict()
 
@@ -68,8 +69,9 @@ class TFOpMapperNHWC(OpMapper):
             del self.graph.input_nodes[idx]
 
         unsupported_ops = set()
-        print("Total nodes: {}".format(len(self.graph.topo_sort)))
-        for node_name in self.graph.topo_sort:
+        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
+        for i, node_name in enumerate(self.graph.topo_sort):
+            sys.stderr.write("\rConverting node {} ... ".format(i + 1))
             node = self.graph.get_node(node_name)
             op = node.layer_type
             if op in self.directly_map_ops:
@@ -94,6 +96,7 @@ class TFOpMapperNHWC(OpMapper):
             for op in unsupported_ops:
                 print("========== {} ============".format(op))
             sys.exit(-1)
+        sys.stderr.write("\nDone\n")
 
     def add_omit_nodes(self, in_node_name, out_node_name):
         in_node = self.graph.get_node(in_node_name)
@@ -126,6 +129,10 @@ class TFOpMapperNHWC(OpMapper):
         y = self.graph.get_node(node.layer.input[1], copy=True)
         x_shape = x.out_shapes[0]
         y_shape = y.out_shapes[0]
+        if len(x_shape) == 0:
+            x_shape = [1]
+        if len(y_shape) == 0:
+            y_shape = [1]
         # incomplement broadcasting support for paddle
         x_input = x
         y_input = y
@@ -199,6 +206,8 @@ class TFOpMapperNHWC(OpMapper):
             'name': string(node.layer_name),
             'append_batch_size': False
         }
+        if shape[0] < 0:
+            self.batch_node = node
         node.fluid_code.add_layer("data",
                                   inputs=None,
                                   output=node,
@@ -823,7 +832,6 @@ class TFOpMapperNHWC(OpMapper):
                                       inputs=input,
                                       output=node,
                                       param_attr=attr)
-        print(node.layer.name)
         if len(new_axes) > 0:
             attr = {"axes": new_axes}
             node.fluid_code.add_layer("unsqueeze",
@@ -857,17 +865,32 @@ class TFOpMapperNHWC(OpMapper):
         else:
             size = self.decoder.infer_tensor(size).tolist()
 
-        attr = {"shape": size, "offsets": begin}
-        node.fluid_code.add_layer("crop",
+        for i in range(len(size)):
+            if size[i] < 0:
+                size[i] = 99999999
+            else:
+                size[i] = size[i] + begin[i]
+
+        attr = {
+            "axes": [i for i in range(len(size))],
+            "starts": begin,
+            "ends": size
+        }
+
+        node.fluid_code.add_layer("slice",
                                   inputs=input,
                                   output=node,
                                   param_attr=attr)
 
     def Conv2DBackpropInput(self, node):
-        input = self.graph.get_node(node.layer.input[0], copy=True)
+        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
+        input = self.graph.get_node(node.layer.input[2], copy=True)
+        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
+        self.add_omit_nodes(kernel.layer_name, node.layer_name)
+        self.add_omit_nodes(out_shape.layer_name, node.layer_name)
 
         in_shape = input.out_shapes[0]
         if in_shape.count(-1) > 2:
@@ -876,14 +899,14 @@ class TFOpMapperNHWC(OpMapper):
         if k_size.count(-1) > 2:
             k_size = self.decoder.infer_tensor(kernel).shape
 
+        pad_mode = node.get_attr("padding").decode()
         strides = node.get_attr("strides")
         dilations = node.get_attr("dilations")
         data_format = node.get_attr("data_format").decode()
-        pad_mode = node.get_attr("padding").decode()
         channel_first = data_format == "NCHW"
+
         self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
             kernel.value, (3, 2, 0, 1))
-
         if not channel_first:
             in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
             strides = [strides[i] for i in [0, 3, 1, 2]]
@@ -894,6 +917,9 @@ class TFOpMapperNHWC(OpMapper):
                                       output=node,
                                       param_attr=attr)
             input = node
+        else:
+            self.data_format_propagation(node)
+
         padding = 0
         if pad_mode == "SAME":
             pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
@@ -907,6 +933,7 @@ class TFOpMapperNHWC(OpMapper):
                                       output=node,
                                       param_attr=attr)
             input = node
+
         attr = {
             "bias_attr": False,
             "param_attr": string(kernel.layer_name),
@@ -920,6 +947,7 @@ class TFOpMapperNHWC(OpMapper):
                                   inputs=input,
                                   output=node,
                                   param_attr=attr)
+
         if not channel_first:
             attr = {"perm": [0, 2, 3, 1]}
             node.fluid_code.add_layer("transpose",
@@ -1062,3 +1090,32 @@ class TFOpMapperNHWC(OpMapper):
                                   inputs=node,
                                   output=node,
                                   param_attr=attr)
+
+    def GreaterEqual(self, node):
+        x = self.graph.get_node(node.layer.input[0], copy=True)
+        y = self.graph.get_node(node.layer.input[1], copy=True)
+        inputs = {"x": x, "y": y}
+        node.fluid_code.add_layer("greater_equal",
+                                  inputs=inputs,
+                                  output=node,
+                                  param_attr=None)
+
+    def RandomUniform(self, node):
+        shape = self.graph.get_node(node.layer.input[0], copy=True)
+        self.add_omit_nodes(shape.layer_name, node.layer_name)
+        if shape.layer_type == "Const":
+            shape = shape.value.tolist()
+        else:
+            shape = self.decoder.infer_shape_tensor(shape)
+        attr = {"shape": shape, "min": 0.0, "max": 0.9999}
+        if shape[0] < 0:
+            input = self.batch_node
+            node.fluid_code.add_layer("uniform_random_batch_size_like",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=attr)
+        else:
+            node.fluid_code.add_layer("uniform_random",
+                                      inputs=None,
+                                      output=node,
+                                      param_attr=attr)
diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py
index cba381e95750f78bbf34b6e27b2227462dfbce5a..6c49a64ab56b69a1c0ef426e496f21ef46fd9c25 100644
--- a/x2paddle/optimizer/tf_optimizer.py
+++ b/x2paddle/optimizer/tf_optimizer.py
@@ -26,10 +26,12 @@ class TFOptimizer(object):
     }
     layers_with_act = [
         'Conv2D', 'BiasAdd', 'DepthwiseConv2dNative', 'Conv2DBackpropInput',
-        'FusedBatchNorm'
+        'FusedBatchNorm', 'conv2d', 'elementwise_add', 'conv2d_transpose',
+        'batch_norm'
     ]
     layers_with_bias = [
-        'Conv2D', 'DepthwiseConv2dNative', 'Conv2DBackpropInput'
+        'Conv2D', 'DepthwiseConv2dNative', 'Conv2DBackpropInput', 'conv2d',
+        'conv2d_transpose'
    ]
 
     def __init__(self, op_mapper):
@@ -129,7 +131,12 @@
                 continue
             if len(input.outputs) != 1:
                 continue
-            input.fluid_code.layers[-1].param_attr['act'] = string(
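+            # the layer list may end with transposes added for data-format
+            # handling, so locate the first fusible layer instead of layers[-1]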
+            index = -1
+            for i in range(len(input.fluid_code.layers)):
+                if input.fluid_code.layers[i].op in self.layers_with_act:
+                    index = i
+                    break
+            input.fluid_code.layers[index].param_attr['act'] = string(
                 self.activation_ops[node.layer_type])
             input.fluid_code.layers[-1].output = node.fluid_code.layers[
                 0].output
@@ -153,45 +160,70 @@
                 if 'act' in node.fluid_code.layers[-1].param_attr:
                     bias_with_act = True
                 layer_with_act = False
+                index = -1
+                for i in range(len(input.fluid_code.layers)):
+                    if input.fluid_code.layers[i].op in self.layers_with_bias:
+                        index = i
+                        break
                 if 'act' in input.fluid_code.layers[
-                        -1].param_attr and input.fluid_code.layers[
-                            -1].param_attr['act'] is not None:
+                        index].param_attr and input.fluid_code.layers[
+                            index].param_attr['act'] is not None:
                     layer_with_act = True
 
                 if bias_with_act and layer_with_act:
                     continue
-                if not input.fluid_code.layers[-1].param_attr['bias_attr']:
+                if not input.fluid_code.layers[index].param_attr['bias_attr']:
                     bias_name = node.inputs[1]
-                    input.fluid_code.layers[-1].param_attr[
+                    input.fluid_code.layers[index].param_attr[
                         'bias_attr'] = string(bias_name)
                     input.fluid_code.layers[-1].output = node.fluid_code.layers[
                         0].output
                     if bias_with_act:
-                        input.fluid_code.layers[-1].param_attr[
+                        input.fluid_code.layers[index].param_attr[
                             'act'] = node.fluid_code.layers[-1].param_attr[
                                 'act']
                     node.fluid_code.clear()
+                    self.graph.remove_node(node.layer_name)
+
+    def remove_transpose(self):
+        optimize_ops = [
+            'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative',
+            'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor',
+            'ResizeBilinear'
+        ]
+        for node_name in self.graph.topo_sort:
+            node = self.graph.get_node(node_name)
+            if node is None:
+                continue
+            if node.layer_type not in optimize_ops:
+                continue
+            if node.fluid_code.layers[
+                    -1].op != "transpose" or node.fluid_code.layers[
+                        -1].param_attr["perm"] != [0, 2, 3, 1]:
+                continue
+            output_names = node.outputs
+            can_be_removed = True
+            for out_name in output_names:
+                out_node = self.graph.get_node(out_name)
+                if out_node.layer_type == "BiasAdd":
+                    can_be_removed = True
+                if out_node.fluid_code.layers[
+                        0].op != "transpose" or out_node.fluid_code.layers[
+                            0].param_attr["perm"] != [0, 3, 1, 2]:
+                    can_be_removed = False
+                    break
+
+            if can_be_removed and len(output_names) > 0:
+                last_out = node.fluid_code.layers[-1].inputs
+                del node.fluid_code.layers[-1]
+                for out_name in output_names:
+                    out_node = self.graph.get_node(out_name)
+                    if out_node.layer_type == "BiasAdd":
+                        del out_node.fluid_code.layers[0]
+                        out_node.fluid_code.layers[0].inputs['x'] = last_out
-#    def remove_transpose(self):
-#        optimize_ops = ['Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative', 'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor', 'ResizeBilinear']
-#        for node_name in self.graph.topo_sort:
-#            node = self.graph.get_node(node_name)
-#            if node.layer_type not in optimize_ops:
-#                continue
-#            if node.fluid_code.layers[-1].op != "transpose" or node.fluid_code.layers[-1].param_attr["perm"] != [0, 2, 3, 1]:
-#                continue
-#            output_names = node.outputs
-#            can_be_removed = True
-#            for out_name in outputs_names:
-#                out_node = self.graph.get_node(out_name)
-#                if out_node.fluid_code.layers[0].op != "transpose" or out_node.fluid_code.layers[-1].param_attr["perm"] != [0, 3, 1, 2]:
-#                    can_be_removed = False
-#                    break
-#            if can_be_removed and len(output_names) > 0:
-#                last_out = node.fluid_code.layers[-1].inputs
-#                del node.fluid_code.layers[-1]
-#                for out_name in outputs_names:
-#                    out_node = self.graph.get_node(out_name)
-#                    del out_node.fluid_code.layers[0]
-#                    out_node.fluid_code.layers[0].inputs = last_out
+                    else:
+                        del out_node.fluid_code.layers[0]
+                        out_node.fluid_code.layers[0].inputs = last_out