未验证 提交 adfd0c08 编写于 作者: J Jason 提交者: GitHub

Merge pull request #490 from SunAhong1993/tf

fix the tf bn fuser
...@@ -1065,6 +1065,24 @@ class TFOpMapper(OpMapper): ...@@ -1065,6 +1065,24 @@ class TFOpMapper(OpMapper):
], ],
num_or_sections=num_split, num_or_sections=num_split,
axis=dim) axis=dim)
def SplitV(self, node):
input = self.graph.get_input_node(node, 0)
size_splits = self.graph.get_input_node(node, 1)
assert size_splits.layer_type == "Const", "size_splits of SplitV OP should be Const"
size_splits = size_splits.value.tolist()
dim = self.graph.get_input_node(node, 2)
assert dim.layer_type == "Const", "dim of SplitV OP should be Const"
dim = dim.value
self.paddle_graph.add_layer(
kernel="paddle.split",
inputs={"x": input.name},
outputs=[
"{}_p{}".format(node.layer_name, i) for i in range(len(size_splits))
],
num_or_sections=size_splits,
axis=dim)
def Slice(self, node): def Slice(self, node):
input = self.graph.get_input_node(node, 0) input = self.graph.get_input_node(node, 0)
......
...@@ -1042,6 +1042,24 @@ class TFOpMapper(OpMapper): ...@@ -1042,6 +1042,24 @@ class TFOpMapper(OpMapper):
], ],
num_or_sections=num_split, num_or_sections=num_split,
axis=dim) axis=dim)
def SplitV(self, node):
input = self.graph.get_input_node(node, 0)
size_splits = self.graph.get_input_node(node, 1)
assert size_splits.layer_type == "Const", "size_splits of SplitV OP should be Const"
size_splits = size_splits.value.tolist()
dim = self.graph.get_input_node(node, 2)
assert dim.layer_type == "Const", "dim of SplitV OP should be Const"
dim = dim.value
self.paddle_graph.add_layer(
kernel="paddle.split",
inputs={"x": input.name},
outputs=[
"{}_p{}".format(node.layer_name, i) for i in range(len(size_splits))
],
num_or_sections=size_splits,
axis=dim)
def Slice(self, node): def Slice(self, node):
input = self.graph.get_input_node(node, 0) input = self.graph.get_input_node(node, 0)
......
...@@ -74,7 +74,7 @@ class DygraphTFBatchNormFuser(FuseBase): ...@@ -74,7 +74,7 @@ class DygraphTFBatchNormFuser(FuseBase):
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
pattern.add_layer( pattern.add_layer(
...@@ -131,7 +131,7 @@ class DygraphTFBatchNormFuser(FuseBase): ...@@ -131,7 +131,7 @@ class DygraphTFBatchNormFuser(FuseBase):
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
pattern.add_layer( pattern.add_layer(
...@@ -180,7 +180,7 @@ class DygraphTFBatchNormFuser(FuseBase): ...@@ -180,7 +180,7 @@ class DygraphTFBatchNormFuser(FuseBase):
if matches[out_layer_id].kernel == "paddle.multiply": if matches[out_layer_id].kernel == "paddle.multiply":
gamma_layer_id = graph.edges_in[out_layer_id][1] gamma_layer_id = graph.edges_in[out_layer_id][1]
gamma_layer = matches[gamma_layer_id] gamma_layer = matches[gamma_layer_id]
if layer.kernel == "fluid.layers.elementwise_sub": if layer.kernel == "paddle.subtract":
in_layer_id = graph.edges_in[layer_id][0] in_layer_id = graph.edges_in[layer_id][0]
beta_layer = matches[in_layer_id] beta_layer = matches[in_layer_id]
in_layer_id = graph.edges_in[layer_id][1] in_layer_id = graph.edges_in[layer_id][1]
......
...@@ -73,7 +73,7 @@ class StaticTFBatchNormFuser(FuseBase): ...@@ -73,7 +73,7 @@ class StaticTFBatchNormFuser(FuseBase):
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
pattern.add_layer( pattern.add_layer(
...@@ -130,7 +130,7 @@ class StaticTFBatchNormFuser(FuseBase): ...@@ -130,7 +130,7 @@ class StaticTFBatchNormFuser(FuseBase):
inputs={}, inputs={},
outputs=[gen_name(8)]) outputs=[gen_name(8)])
pattern.add_layer( pattern.add_layer(
"fluid.layers.elementwise_sub", "paddle.subtract",
inputs={"x": gen_name(8), "y": gen_name(7)}, inputs={"x": gen_name(8), "y": gen_name(7)},
outputs=[gen_name(9)]) outputs=[gen_name(9)])
pattern.add_layer( pattern.add_layer(
...@@ -179,7 +179,7 @@ class StaticTFBatchNormFuser(FuseBase): ...@@ -179,7 +179,7 @@ class StaticTFBatchNormFuser(FuseBase):
if matches[out_layer_id].kernel == "paddle.multiply": if matches[out_layer_id].kernel == "paddle.multiply":
gamma_layer_id = graph.edges_in[out_layer_id][1] gamma_layer_id = graph.edges_in[out_layer_id][1]
gamma_layer = matches[gamma_layer_id] gamma_layer = matches[gamma_layer_id]
if layer.kernel == "fluid.layers.elementwise_sub": if layer.kernel == "paddle.subtract":
in_layer_id = graph.edges_in[layer_id][0] in_layer_id = graph.edges_in[layer_id][0]
beta_layer = matches[in_layer_id] beta_layer = matches[in_layer_id]
in_layer_id = graph.edges_in[layer_id][1] in_layer_id = graph.edges_in[layer_id][1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册