diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index b52aefbb7c16874f5ad968d6670f65e550ecce0e..76b9ece5bc6c8009d53edcbfc6aa68e5b3917441 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -144,9 +144,9 @@ def tf2paddle(model_path,
     bias_opt = BiasOpt()
     transpose_opt = TransposeOpt()
     batch_norm_opt = BatchNormOpt()
-    bias_opt.run(program)
-    batch_norm_opt.run(program)
-    transpose_opt.run(program)
+    bias_opt.run(mapper.paddle_graph)
+    batch_norm_opt.run(mapper.paddle_graph)
+    transpose_opt.run(mapper.paddle_graph)
     mapper.paddle_graph.gen_model(save_dir)
 
 
diff --git a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py
index 012d9bf217151cc7f90bb8621ff668f0a7278d7e..a1f56f26c3c871bff3908a2baffb522b6b39d742 100644
--- a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py
+++ b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py
@@ -176,11 +176,12 @@ class TFOpMapper(OpMapper):
 
         x_shape = x.out_shapes[0]
         y_shape = y.out_shapes[0]
-        self.paddle_graph.add_layer(
+        layer_id = self.paddle_graph.add_layer(
             kernel=op_type,
             inputs={"x": x.name,
                     "y": y.name},
             outputs=[node.name])
+        self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
 
     def NotEqual(self, node):
         x = self.graph.get_node(node.layer.input[0])
@@ -1241,13 +1242,15 @@ class TFOpMapper(OpMapper):
 
         x_shape = x.out_shapes[0]
         y_shape = y.out_shapes[0]
         layer_id = self.paddle_graph.add_layer(
-            "paddle.fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name])
+            "fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name])
+        self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
         inputs = {"x": node.name, "y": node.name}
         x_shape = node.out_shapes[0]
         y_shape = node.out_shapes[0]
         layer_id = self.paddle_graph.add_layer(
             "paddle.multiply", inputs=inputs, outputs=[node.name])
+        self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
 
     def OneHot(self, node):
         input = self.graph.get_node(node.layer.input[0])
diff --git a/x2paddle/optimizer/elimination/dygraph/__init__.py b/x2paddle/optimizer/elimination/dygraph/__init__.py
index 8c0e3576f3ad115f12bb859bdb1a334f2a201add..84f8c550d7afe865101986f3431017d70fd1c9ba 100644
--- a/x2paddle/optimizer/elimination/dygraph/__init__.py
+++ b/x2paddle/optimizer/elimination/dygraph/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .transpose_elimination import Dygraph_TransposeElimination
-from .transpose_eliminate_pass import Dygraph_TransposeEliminatePass
\ No newline at end of file
+from .transpose_elimination import DygraphTransposeElimination
+from .transpose_eliminate_pass import DygraphTransposeEliminatePass
\ No newline at end of file
diff --git a/x2paddle/optimizer/elimination/dygraph/transpose_eliminate_pass.py b/x2paddle/optimizer/elimination/dygraph/transpose_eliminate_pass.py
index c60f0344ae147cae0797ee16cfe004c3758d3ba0..ea2a7e9ca1e83dcbc601019914cbcbcc53cc552e 100644
--- a/x2paddle/optimizer/elimination/dygraph/transpose_eliminate_pass.py
+++ b/x2paddle/optimizer/elimination/dygraph/transpose_eliminate_pass.py
@@ -13,21 +13,21 @@
 # limitations under the License.
 
 from x2paddle.optimizer.pass_ import Pass
-from x2paddle.optimizer.elimination.dygraph import Dygraph_TransposeElimination
+from x2paddle.optimizer.elimination.dygraph import DygraphTransposeElimination
 from x2paddle.optimizer.pass_manager import pass_register
 
 
 @pass_register
-class Dygraph_TransposeEliminatePass(Pass):
+class DygraphTransposeEliminatePass(Pass):
     name = "transpose_eliminate_pass"
 
     def __init__(self):
         Pass.__init__(self)
 
     def apply(self, graph):
-        fuser = Dygraph_TransposeElimination()
+        fuser = DygraphTransposeElimination()
         fuser.operate(graph)
 
 
 # Used for registration
-transpose_eliminate_pass = Dygraph_TransposeEliminatePass()
\ No newline at end of file
+transpose_eliminate_pass = DygraphTransposeEliminatePass()
\ No newline at end of file
diff --git a/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py b/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py
index b7cbf034ee726efcfcaa1cf9fe1f245ae0e4ce4e..8e5b61c10db2db92c53300398045a911ac7c851a 100644
--- a/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py
+++ b/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py
@@ -13,15 +13,16 @@
 # limitations under the License.
 
 import copy
+import sys
 import numpy as np
 from x2paddle.optimizer.pattern_matcher import FuseBase
 from x2paddle.core.program import PaddleGraph, PaddleLayer
 from x2paddle.core.util import *
 
 
-class Dygraph_TransposeElimination(FuseBase):
+class DygraphTransposeElimination(FuseBase):
     def __init__(self):
-        super(Dygraph_TransposeElimination, self).__init__(graph_type="dygraph")
+        super(DygraphTransposeElimination, self).__init__(graph_type="dygraph")
         self.direct_layers = [
             'paddle.nn.ReLU', 'paddle.nn.ReLU6', 'paddle.abs',
             'paddle.nn.Sigmoid', 'paddle.exp', 'paddle.rsqrt',
@@ -53,6 +54,12 @@ class Dygraph_TransposeElimination(FuseBase):
         optimized_reduce_layers = list()
         optimized_concat_layers = list()
         optimized_elementwise_layers = list()
+
+        def get_index(layer):
+            if layer.kernel.startswith("paddle.nn") and "functional" not in layer.kernel:
+                return 1
+            else:
+                return 0
 
         def strip_transpose(_graph):
             layers = copy.deepcopy(_graph.layers)
@@ -61,7 +68,7 @@ class Dygraph_TransposeElimination(FuseBase):
                     continue
                 scanned_layers.add(layer_id)
                 percent = round(len(scanned_layers) / total_layer_num * 100, 2)
-                print("\rOptimize Transpose Layers...{}%".format(
+                sys.stderr.write("\rOptimize Transpose Layers...{}%".format(
                     percent))
 
                 if layer.kernel != "paddle.transpose":
@@ -84,13 +91,14 @@ class Dygraph_TransposeElimination(FuseBase):
                    elif _graph.layers[out].kernel in self.elementwise_layers:
                        propagate_layers.append(out)
                    elif _graph.layers[out].kernel in self.direct_layers:
-                        ouput_index = 1 if _graph.layers[out].kernel.startswith("paddle.nn.") else 0
+                        ouput_index = get_index(_graph.layers[out])
                        if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
                            can_be_optimized = False
                            break
                        propagate_layers.append(out)
                    elif _graph.layers[out].kernel in self.reduce_layers:
-                        if _graph.layers[out].outputs[0] in _graph.outputs:
+                        ouput_index = get_index(_graph.layers[out])
+                        if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
                            can_be_optimized = False
                            break
                        if _graph.layers[out].attrs.get('keepdim', False):
                            can_be_optimized = False
                            break
@@ -99,7 +107,8 @@ class Dygraph_TransposeElimination(FuseBase):
                        propagate_layers.append(out)
                        reduce_layers.append(out)
                    elif _graph.layers[out].kernel == "paddle.concat":
-                        if _graph.layers[out].outputs[0] in _graph.outputs:
+                        ouput_index = get_index(_graph.layers[out])
+                        if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
                            can_be_optimized = False
                            break
                        propagate_layers.append(out)
@@ -121,20 +130,22 @@ class Dygraph_TransposeElimination(FuseBase):
                                transpose_layers.append(out)
                            elif _graph.layers[
                                    out].kernel in self.elementwise_layers:
-                                if _graph.layers[out].outputs[0] in _graph.outputs:
+                                output_index = get_index(_graph.layers[out])
+                                if _graph.layers[out].outputs[output_index] in _graph.outputs:
                                    can_be_optimized = False
                                    break
                                if out not in visited_layers:
                                    propagate_layers.append(out)
                            elif _graph.layers[out].kernel in self.direct_layers:
-                                ouput_index = 1 if _graph.layers[out].kernel.startswith("paddle.nn.") else 0
-                                if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
+                                output_index = get_index(_graph.layers[out])
+                                if _graph.layers[out].outputs[output_index] in _graph.outputs:
                                    can_be_optimized = False
                                    break
                                if out not in visited_layers:
                                    propagate_layers.append(out)
                            elif _graph.layers[out].kernel in self.reduce_layers:
-                                if _graph.layers[out].outputs[0] in _graph.outputs:
+                                output_index = get_index(_graph.layers[out])
+                                if _graph.layers[out].outputs[output_index] in _graph.outputs:
                                    can_be_optimized = False
                                    break
                                if _graph.layers[out].attrs.get('keepdim',
@@ -145,7 +156,8 @@ class Dygraph_TransposeElimination(FuseBase):
                                propagate_layers.append(out)
                                reduce_layers.append(out)
                            elif _graph.layers[out].kernel == "paddle.concat":
-                                if _graph.layers[out].outputs[0] in _graph.outputs:
+                                output_index = get_index(_graph.layers[out])
+                                if _graph.layers[out].outputs[output_index] in _graph.outputs:
                                    can_be_optimized = False
                                    break
                                if out not in visited_layers:
@@ -162,14 +174,15 @@ class Dygraph_TransposeElimination(FuseBase):
                                        current_id].input_shapes['x']
                                    y_shape = _graph.layers[
                                        current_id].input_shapes['y']
+                                    output_index = get_index(_graph.layers[ipt])
                                    if _graph.layers[ipt].outputs[
-                                            0] == _graph.layers[current_id].inputs[
+                                            output_index] == _graph.layers[current_id].inputs[
                                                'x']:
                                        if len(x_shape) <= 1:
                                            elementwise_layers.append(current_id)
                                            continue
                                    elif _graph.layers[ipt].outputs[
-                                            0] == _graph.layers[current_id].inputs[
+                                            output_index] == _graph.layers[current_id].inputs[
                                                'y']:
                                        if len(y_shape) <= 1:
                                            elementwise_layers.append(current_id)
@@ -181,6 +194,7 @@ class Dygraph_TransposeElimination(FuseBase):
                                except Exception as e:
                                    can_be_optimized = False
                                    break
+                                output_index = get_index(_graph.layers[ipt])
                                if _graph.layers[
                                        ipt].kernel == "paddle.transpose":
                                    if _graph.layers[ipt].attrs["perm"] != [0, 2, 3, 1]:
@@ -190,20 +204,19 @@ class Dygraph_TransposeElimination(FuseBase):
                                        transpose_layers.append(ipt)
                                elif _graph.layers[
                                        ipt].kernel in self.elementwise_layers:
-                                    if _graph.layers[ipt].outputs[0] in _graph.outputs:
+                                    if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
                                        can_be_optimized = False
                                        break
                                    if ipt not in visited_layers:
                                        propagate_layers.append(ipt)
                                elif _graph.layers[ipt].kernel in self.direct_layers:
-                                    ouput_index = 1 if _graph.layers[ipt].kernel.startswith("paddle.nn.") else 0
-                                    if _graph.layers[ipt].outputs[ouput_index] in _graph.outputs:
+                                    if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
                                        can_be_optimized = False
                                        break
                                    if ipt not in visited_layers:
                                        propagate_layers.append(ipt)
                                elif _graph.layers[ipt].kernel in self.reduce_layers:
-                                    if _graph.layers[ipt].outputs[0] in _graph.outputs:
+                                    if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
                                        can_be_optimized = False
                                        break
                                    if _graph.layers[ipt].attrs.get('keepdim',
@@ -214,7 +227,7 @@ class Dygraph_TransposeElimination(FuseBase):
                                        propagate_layers.append(ipt)
                                        reduce_layers.append(ipt)
                                elif _graph.layers[ipt].kernel == "paddle.concat":
-                                    if _graph.layers[ipt].outputs[0] in _graph.outputs:
+                                    if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
                                        can_be_optimized = False
                                        break
                                    if ipt not in visited_layers:
@@ -231,7 +244,8 @@ class Dygraph_TransposeElimination(FuseBase):
                    transpose_layers.append(layer_id)
                transpose_layers = list(set(transpose_layers))
                for l in transpose_layers:
-                    if graph.layers[l].outputs[0] in graph.outputs:
+                    output_index = get_index(graph.layers[l])
+                    if graph.layers[l].outputs[output_index] in graph.outputs:
                        can_be_optimized = False
                        break
                if not can_be_optimized:
@@ -254,6 +268,7 @@ class Dygraph_TransposeElimination(FuseBase):
 
        while strip_transpose(opt_graph):
            pass
+
        for layer_id in list(set(optimized_transpose_layers)):
            self.delete_layer_with_associated(graph, layer_id)
        for layer_id in list(set(optimized_reduce_layers)):