提交 a7496317 编写于 作者: S SunAhong1993

fix the tf

上级 e725aa37
...@@ -74,7 +74,7 @@ class DygraphTFBatchNormFuser(FuseBase): ...@@ -74,7 +74,7 @@ class DygraphTFBatchNormFuser(FuseBase):
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
pattern.add_layer( pattern.add_layer(
...@@ -131,7 +131,7 @@ class DygraphTFBatchNormFuser(FuseBase): ...@@ -131,7 +131,7 @@ class DygraphTFBatchNormFuser(FuseBase):
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
pattern.add_layer( pattern.add_layer(
...@@ -180,7 +180,7 @@ class DygraphTFBatchNormFuser(FuseBase): ...@@ -180,7 +180,7 @@ class DygraphTFBatchNormFuser(FuseBase):
if matches[out_layer_id].kernel == "paddle.multiply": if matches[out_layer_id].kernel == "paddle.multiply":
gamma_layer_id = graph.edges_in[out_layer_id][1] gamma_layer_id = graph.edges_in[out_layer_id][1]
gamma_layer = matches[gamma_layer_id] gamma_layer = matches[gamma_layer_id]
if layer.kernel == "fluid.layers.elementwise_sub": if layer.kernel == "paddle.subtract":
in_layer_id = graph.edges_in[layer_id][0] in_layer_id = graph.edges_in[layer_id][0]
beta_layer = matches[in_layer_id] beta_layer = matches[in_layer_id]
in_layer_id = graph.edges_in[layer_id][1] in_layer_id = graph.edges_in[layer_id][1]
......
...@@ -73,7 +73,7 @@ class StaticTFBatchNormFuser(FuseBase): ...@@ -73,7 +73,7 @@ class StaticTFBatchNormFuser(FuseBase):
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
pattern.add_layer( pattern.add_layer(
...@@ -130,7 +130,7 @@ class StaticTFBatchNormFuser(FuseBase): ...@@ -130,7 +130,7 @@ class StaticTFBatchNormFuser(FuseBase):
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
pattern.add_layer( pattern.add_layer(
...@@ -179,7 +179,7 @@ class StaticTFBatchNormFuser(FuseBase): ...@@ -179,7 +179,7 @@ class StaticTFBatchNormFuser(FuseBase):
if matches[out_layer_id].kernel == "paddle.multiply": if matches[out_layer_id].kernel == "paddle.multiply":
gamma_layer_id = graph.edges_in[out_layer_id][1] gamma_layer_id = graph.edges_in[out_layer_id][1]
gamma_layer = matches[gamma_layer_id] gamma_layer = matches[gamma_layer_id]
if layer.kernel == "fluid.layers.elementwise_sub": if layer.kernel == "paddle.subtract":
in_layer_id = graph.edges_in[layer_id][0] in_layer_id = graph.edges_in[layer_id][0]
beta_layer = matches[in_layer_id] beta_layer = matches[in_layer_id]
in_layer_id = graph.edges_in[layer_id][1] in_layer_id = graph.edges_in[layer_id][1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册