diff --git a/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py b/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py index b2a8a679336964cdc0734a49db379a0dc3b8cb28..22c59d79c3c6e9860293881dfa065d55ac69cc6b 100644 --- a/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py +++ b/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py @@ -101,7 +101,7 @@ class DygraphTransposeElimination(FuseBase): if _graph.layers[out].outputs[ouput_index] in _graph.outputs: can_be_optimized = False break - if _graph.layers[out].attrs.get('keepdim', False): + if not _graph.layers[out].attrs.get('keepdim', False): can_be_optimized = False break propagate_layers.append(out) @@ -148,7 +148,7 @@ class DygraphTransposeElimination(FuseBase): if _graph.layers[out].outputs[output_index] in _graph.outputs: can_be_optimized = False break - if _graph.layers[out].attrs.get('keepdim', + if not _graph.layers[out].attrs.get('keepdim', False): can_be_optimized = False break @@ -219,7 +219,7 @@ class DygraphTransposeElimination(FuseBase): if _graph.layers[ipt].outputs[output_index] in _graph.outputs: can_be_optimized = False break - if _graph.layers[ipt].attrs.get('keepdim', + if not _graph.layers[ipt].attrs.get('keepdim', False): can_be_optimized = False break @@ -252,7 +252,7 @@ class DygraphTransposeElimination(FuseBase): continue for l in transpose_layers: - _graph.delete_layer(l) + _graph.del_layer(l) optimized_transpose_layers.extend(transpose_layers) optimized_reduce_layers.extend(reduce_layers) @@ -268,21 +268,22 @@ class DygraphTransposeElimination(FuseBase): while strip_transpose(opt_graph): pass - for layer_id in list(set(optimized_transpose_layers)): - self.delete_layer_with_associated(graph, layer_id) + graph.del_layer(layer_id) for layer_id in list(set(optimized_reduce_layers)): - dim = graph.layers[layer_id].attrs.get('dim', None) + dim = graph.layers[layer_id].attrs.get('axis', None) if dim is not None: for i in range(len(dim)): dim[i] = [0, 2, 3, 1][dim[i]] - graph.layers[layer_id].attrs['dim'] = dim + graph.layers[layer_id].attrs['axis'] = dim for layer_id in list(set(optimized_concat_layers)): axis = graph.layers[layer_id].attrs.get('axis', 0) graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] for layer_id in list(set(optimized_elementwise_layers)): axis = graph.layers[layer_id].attrs.get('axis', -1) graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] + if graph.layers[layer_id].kernel == "paddle.add": + graph.layers[layer_id].kernel = "fluid.layers.elementwise_add" current_transpose_num = self.get_transpose_num(graph) print( diff --git a/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py b/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py index 9dd0727f789455ade114aa2e3efdfca479b55888..f3af6b08db3a7e9f817a47b26db389b6e279bf4c 100644 --- a/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py @@ -24,6 +24,7 @@ class DygraphTFBatchNormFuser(FuseBase): def __init__(self): self.bn_index = 0 super(DygraphTFBatchNormFuser, self).__init__(graph_type="dygraph") + self.patterns = list() def build_pattern(self): """ 描述需要替换的batchnorm图结构。 @@ -34,57 +35,111 @@ class DygraphTFBatchNormFuser(FuseBase): def gen_name(id): return "x" + str(id) - self.pattern.add_layer( + pattern = PaddleGraph(graph_type="dygraph") + pattern.add_layer( "self.create_parameter", inputs={}, outputs=[gen_name(0)]) - self.pattern.add_layer( + pattern.add_layer( "paddle.full", inputs={}, outputs=[gen_name(1)], shape=[1]) - self.pattern.add_layer( + pattern.add_layer( "paddle.add", inputs={"x": gen_name(0), "y": gen_name(1)}, outputs=[gen_name(2)]) - self.pattern.add_layer( + pattern.add_layer( "paddle.rsqrt", inputs={"x": gen_name(2)}, outputs=[gen_name(3)]) - self.pattern.add_layer( + pattern.add_layer( "self.create_parameter", inputs={}, outputs=[gen_name(4)]) - self.pattern.add_layer( + pattern.add_layer( "paddle.multiply", inputs={"x": gen_name(3), "y": gen_name(4)}, outputs=[gen_name(5)]) - self.pattern.add_layer( + pattern.add_layer( "self.create_parameter", inputs={}, outputs=[gen_name(6)]) - self.pattern.add_layer( + pattern.add_layer( "paddle.multiply", inputs={"x": gen_name(6), "y": gen_name(5)}, outputs=[gen_name(7)]) - self.pattern.add_layer( + pattern.add_layer( "self.create_parameter", inputs={}, outputs=[gen_name(8)]) - self.pattern.add_layer( + pattern.add_layer( "fluid.layers.elementwise_sub", inputs={"x": gen_name(8), "y": gen_name(7)}, outputs=[gen_name(9)]) - self.pattern.add_layer( + pattern.add_layer( "paddle.multiply", inputs={"x": "bn-input-0", "y": gen_name(5)}, outputs=[gen_name(10)]) - self.pattern.add_layer( + pattern.add_layer( "paddle.add", inputs={"x": gen_name(10), "y": gen_name(9)}, outputs=[gen_name(11)]) - self.pattern.build(inputs={"input-0": "bn-input-0", }) + pattern.build(inputs={"input-0": "bn-input-0", }) + self.patterns.append(pattern) + pattern = PaddleGraph(graph_type="dygraph") + pattern.add_layer( + "self.create_parameter", + inputs={}, + outputs=[gen_name(0)]) + pattern.add_layer( + "paddle.full", + inputs={}, + outputs=[gen_name(1)], + shape=[1]) + pattern.add_layer( + "paddle.add", + inputs={"x": gen_name(0), "y": gen_name(1)}, + outputs=[gen_name(2)]) + pattern.add_layer( + "paddle.rsqrt", + inputs={"x": gen_name(2)}, + outputs=[gen_name(3)]) + pattern.add_layer( + "self.create_parameter", + inputs={}, + outputs=[gen_name(4)]) + pattern.add_layer( + "paddle.multiply", + inputs={"x": gen_name(3), "y": gen_name(4)}, + outputs=[gen_name(5)]) + pattern.add_layer( + "paddle.multiply", + inputs={"x": "bn-input-0", "y": gen_name(5)}, + outputs=[gen_name(10)]) + pattern.add_layer( + "self.create_parameter", + inputs={}, + outputs=[gen_name(6)]) + pattern.add_layer( + "paddle.multiply", + inputs={"x": gen_name(6), "y": gen_name(5)}, + outputs=[gen_name(7)]) + pattern.add_layer( + "self.create_parameter", + inputs={}, + outputs=[gen_name(8)]) + pattern.add_layer( + "fluid.layers.elementwise_sub", + inputs={"x": gen_name(8), "y": gen_name(7)}, + outputs=[gen_name(9)]) + pattern.add_layer( + "paddle.add", + inputs={"x": gen_name(10), "y": gen_name(9)}, + outputs=[gen_name(11)]) + pattern.build(inputs={"input-0": "bn-input-0", }) + self.patterns.append(pattern) def insert_new_layer(self, graph, parameters, matches): new_layers, last_layer_id = self.gen_new_layer(matches, parameters, graph)