提交 a7496317 编写于 作者: S SunAhong1993

fix the tf

上级 e725aa37
......@@ -74,7 +74,7 @@ class DygraphTFBatchNormFuser(FuseBase):
inputs={},
outputs=[gen_name(8)])
pattern.add_layer(
"fluid.layers.elementwise_sub",
"paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)])
pattern.add_layer(
......@@ -131,7 +131,7 @@ class DygraphTFBatchNormFuser(FuseBase):
inputs={},
outputs=[gen_name(8)])
pattern.add_layer(
"fluid.layers.elementwise_sub",
"paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)])
pattern.add_layer(
......@@ -180,7 +180,7 @@ class DygraphTFBatchNormFuser(FuseBase):
if matches[out_layer_id].kernel == "paddle.multiply":
gamma_layer_id = graph.edges_in[out_layer_id][1]
gamma_layer = matches[gamma_layer_id]
if layer.kernel == "fluid.layers.elementwise_sub":
if layer.kernel == "paddle.subtract":
in_layer_id = graph.edges_in[layer_id][0]
beta_layer = matches[in_layer_id]
in_layer_id = graph.edges_in[layer_id][1]
......
......@@ -73,7 +73,7 @@ class StaticTFBatchNormFuser(FuseBase):
inputs={},
outputs=[gen_name(8)])
pattern.add_layer(
"fluid.layers.elementwise_sub",
"paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)])
pattern.add_layer(
......@@ -130,7 +130,7 @@ class StaticTFBatchNormFuser(FuseBase):
inputs={},
outputs=[gen_name(8)])
pattern.add_layer(
"fluid.layers.elementwise_sub",
"paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)])
pattern.add_layer(
......@@ -179,7 +179,7 @@ class StaticTFBatchNormFuser(FuseBase):
if matches[out_layer_id].kernel == "paddle.multiply":
gamma_layer_id = graph.edges_in[out_layer_id][1]
gamma_layer = matches[gamma_layer_id]
if layer.kernel == "fluid.layers.elementwise_sub":
if layer.kernel == "paddle.subtract":
in_layer_id = graph.edges_in[layer_id][0]
beta_layer = matches[in_layer_id]
in_layer_id = graph.edges_in[layer_id][1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册