提交 ea0259a1 编写于 作者: S SunAhong1993

Fix TF bugs in DygraphTransposeElimination and DygraphTFBatchNormFuser

上级 f7cc20f9
...@@ -101,7 +101,7 @@ class DygraphTransposeElimination(FuseBase): ...@@ -101,7 +101,7 @@ class DygraphTransposeElimination(FuseBase):
if _graph.layers[out].outputs[ouput_index] in _graph.outputs: if _graph.layers[out].outputs[ouput_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if _graph.layers[out].attrs.get('keepdim', False): if not _graph.layers[out].attrs.get('keepdim', False):
can_be_optimized = False can_be_optimized = False
break break
propagate_layers.append(out) propagate_layers.append(out)
...@@ -148,7 +148,7 @@ class DygraphTransposeElimination(FuseBase): ...@@ -148,7 +148,7 @@ class DygraphTransposeElimination(FuseBase):
if _graph.layers[out].outputs[output_index] in _graph.outputs: if _graph.layers[out].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if _graph.layers[out].attrs.get('keepdim', if not _graph.layers[out].attrs.get('keepdim',
False): False):
can_be_optimized = False can_be_optimized = False
break break
...@@ -219,7 +219,7 @@ class DygraphTransposeElimination(FuseBase): ...@@ -219,7 +219,7 @@ class DygraphTransposeElimination(FuseBase):
if _graph.layers[ipt].outputs[output_index] in _graph.outputs: if _graph.layers[ipt].outputs[output_index] in _graph.outputs:
can_be_optimized = False can_be_optimized = False
break break
if _graph.layers[ipt].attrs.get('keepdim', if not _graph.layers[ipt].attrs.get('keepdim',
False): False):
can_be_optimized = False can_be_optimized = False
break break
...@@ -252,7 +252,7 @@ class DygraphTransposeElimination(FuseBase): ...@@ -252,7 +252,7 @@ class DygraphTransposeElimination(FuseBase):
continue continue
for l in transpose_layers: for l in transpose_layers:
_graph.delete_layer(l) _graph.del_layer(l)
optimized_transpose_layers.extend(transpose_layers) optimized_transpose_layers.extend(transpose_layers)
optimized_reduce_layers.extend(reduce_layers) optimized_reduce_layers.extend(reduce_layers)
...@@ -268,21 +268,22 @@ class DygraphTransposeElimination(FuseBase): ...@@ -268,21 +268,22 @@ class DygraphTransposeElimination(FuseBase):
while strip_transpose(opt_graph): while strip_transpose(opt_graph):
pass pass
for layer_id in list(set(optimized_transpose_layers)): for layer_id in list(set(optimized_transpose_layers)):
self.delete_layer_with_associated(graph, layer_id) graph.del_layer(layer_id)
for layer_id in list(set(optimized_reduce_layers)): for layer_id in list(set(optimized_reduce_layers)):
dim = graph.layers[layer_id].attrs.get('dim', None) dim = graph.layers[layer_id].attrs.get('axis', None)
if dim is not None: if dim is not None:
for i in range(len(dim)): for i in range(len(dim)):
dim[i] = [0, 2, 3, 1][dim[i]] dim[i] = [0, 2, 3, 1][dim[i]]
graph.layers[layer_id].attrs['dim'] = dim graph.layers[layer_id].attrs['axis'] = dim
for layer_id in list(set(optimized_concat_layers)): for layer_id in list(set(optimized_concat_layers)):
axis = graph.layers[layer_id].attrs.get('axis', 0) axis = graph.layers[layer_id].attrs.get('axis', 0)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
for layer_id in list(set(optimized_elementwise_layers)): for layer_id in list(set(optimized_elementwise_layers)):
axis = graph.layers[layer_id].attrs.get('axis', -1) axis = graph.layers[layer_id].attrs.get('axis', -1)
graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis] graph.layers[layer_id].attrs['axis'] = [0, 2, 3, 1][axis]
if graph.layers[layer_id].kernel == "paddle.add":
graph.layers[layer_id].kernel = "fluid.layers.elementwise_add"
current_transpose_num = self.get_transpose_num(graph) current_transpose_num = self.get_transpose_num(graph)
print( print(
......
...@@ -24,6 +24,7 @@ class DygraphTFBatchNormFuser(FuseBase): ...@@ -24,6 +24,7 @@ class DygraphTFBatchNormFuser(FuseBase):
def __init__(self): def __init__(self):
self.bn_index = 0 self.bn_index = 0
super(DygraphTFBatchNormFuser, self).__init__(graph_type="dygraph") super(DygraphTFBatchNormFuser, self).__init__(graph_type="dygraph")
self.patterns = list()
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的batchnorm图结构。 """ 描述需要替换的batchnorm图结构。
...@@ -34,57 +35,111 @@ class DygraphTFBatchNormFuser(FuseBase): ...@@ -34,57 +35,111 @@ class DygraphTFBatchNormFuser(FuseBase):
def gen_name(id): def gen_name(id):
return "x" + str(id) return "x" + str(id)
self.pattern.add_layer( pattern = PaddleGraph(graph_type="dygraph")
pattern.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
outputs=[gen_name(0)]) outputs=[gen_name(0)])
self.pattern.add_layer( pattern.add_layer(
"paddle.full", "paddle.full",
inputs={}, inputs={},
outputs=[gen_name(1)], outputs=[gen_name(1)],
shape=[1]) shape=[1])
self.pattern.add_layer( pattern.add_layer(
"paddle.add", "paddle.add",
inputs={"x": gen_name(0), "y": gen_name(1)}, inputs={"x": gen_name(0), "y": gen_name(1)},
outputs=[gen_name(2)]) outputs=[gen_name(2)])
self.pattern.add_layer( pattern.add_layer(
"paddle.rsqrt", "paddle.rsqrt",
inputs={"x": gen_name(2)}, inputs={"x": gen_name(2)},
outputs=[gen_name(3)]) outputs=[gen_name(3)])
self.pattern.add_layer( pattern.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
outputs=[gen_name(4)]) outputs=[gen_name(4)])
self.pattern.add_layer( pattern.add_layer(
"paddle.multiply", "paddle.multiply",
inputs={"x": gen_name(3), "y": gen_name(4)}, inputs={"x": gen_name(3), "y": gen_name(4)},
outputs=[gen_name(5)]) outputs=[gen_name(5)])
self.pattern.add_layer( pattern.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
outputs=[gen_name(6)]) outputs=[gen_name(6)])
self.pattern.add_layer( pattern.add_layer(
"paddle.multiply", "paddle.multiply",
inputs={"x": gen_name(6), "y": gen_name(5)}, inputs={"x": gen_name(6), "y": gen_name(5)},
outputs=[gen_name(7)]) outputs=[gen_name(7)])
self.pattern.add_layer( pattern.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
self.pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "fluid.layers.elementwise_sub",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
self.pattern.add_layer( pattern.add_layer(
"paddle.multiply", "paddle.multiply",
inputs={"x": "bn-input-0", "y": gen_name(5)}, inputs={"x": "bn-input-0", "y": gen_name(5)},
outputs=[gen_name(10)]) outputs=[gen_name(10)])
self.pattern.add_layer( pattern.add_layer(
"paddle.add", "paddle.add",
inputs={"x": gen_name(10), "y": gen_name(9)}, inputs={"x": gen_name(10), "y": gen_name(9)},
outputs=[gen_name(11)]) outputs=[gen_name(11)])
self.pattern.build(inputs={"input-0": "bn-input-0", }) pattern.build(inputs={"input-0": "bn-input-0", })
self.patterns.append(pattern)
pattern = PaddleGraph(graph_type="dygraph")
pattern.add_layer(
"self.create_parameter",
inputs={},
outputs=[gen_name(0)])
pattern.add_layer(
"paddle.full",
inputs={},
outputs=[gen_name(1)],
shape=[1])
pattern.add_layer(
"paddle.add",
inputs={"x": gen_name(0), "y": gen_name(1)},
outputs=[gen_name(2)])
pattern.add_layer(
"paddle.rsqrt",
inputs={"x": gen_name(2)},
outputs=[gen_name(3)])
pattern.add_layer(
"self.create_parameter",
inputs={},
outputs=[gen_name(4)])
pattern.add_layer(
"paddle.multiply",
inputs={"x": gen_name(3), "y": gen_name(4)},
outputs=[gen_name(5)])
pattern.add_layer(
"paddle.multiply",
inputs={"x": "bn-input-0", "y": gen_name(5)},
outputs=[gen_name(10)])
pattern.add_layer(
"self.create_parameter",
inputs={},
outputs=[gen_name(6)])
pattern.add_layer(
"paddle.multiply",
inputs={"x": gen_name(6), "y": gen_name(5)},
outputs=[gen_name(7)])
pattern.add_layer(
"self.create_parameter",
inputs={},
outputs=[gen_name(8)])
pattern.add_layer(
"fluid.layers.elementwise_sub",
inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)])
pattern.add_layer(
"paddle.add",
inputs={"x": gen_name(10), "y": gen_name(9)},
outputs=[gen_name(11)])
pattern.build(inputs={"input-0": "bn-input-0", })
self.patterns.append(pattern)
def insert_new_layer(self, graph, parameters, matches): def insert_new_layer(self, graph, parameters, matches):
new_layers, last_layer_id = self.gen_new_layer(matches, parameters, graph) new_layers, last_layer_id = self.gen_new_layer(matches, parameters, graph)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册