diff --git a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py index 37136a12b2fa6b2c36e9382561f6b501c46edd69..e30bef79e204692e66975a7638effddf8bcded34 100644 --- a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py @@ -1065,6 +1065,24 @@ class TFOpMapper(OpMapper): ], num_or_sections=num_split, axis=dim) + + def SplitV(self, node): + input = self.graph.get_input_node(node, 0) + size_splits = self.graph.get_input_node(node, 1) + assert size_splits.layer_type == "Const", "size_splits of SplitV OP should be Const" + size_splits = size_splits.value.tolist() + dim = self.graph.get_input_node(node, 2) + assert dim.layer_type == "Const", "dim of SplitV OP should be Const" + dim = dim.value + + self.paddle_graph.add_layer( + kernel="paddle.split", + inputs={"x": input.name}, + outputs=[ + "{}_p{}".format(node.layer_name, i) for i in range(len(size_splits)) + ], + num_or_sections=size_splits, + axis=dim) def Slice(self, node): input = self.graph.get_input_node(node, 0) diff --git a/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py index 338b15e9ed03bbd973de175fddcc72aedb1b2745..7df97dc6dda557102a1b1cd8e7b68a0ef5963b7d 100644 --- a/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py +++ b/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py @@ -1042,6 +1042,24 @@ class TFOpMapper(OpMapper): ], num_or_sections=num_split, axis=dim) + + def SplitV(self, node): + input = self.graph.get_input_node(node, 0) + size_splits = self.graph.get_input_node(node, 1) + assert size_splits.layer_type == "Const", "size_splits of SplitV OP should be Const" + size_splits = size_splits.value.tolist() + dim = self.graph.get_input_node(node, 2) + assert dim.layer_type == "Const", "dim of SplitV OP should be Const" + dim = dim.value + + self.paddle_graph.add_layer( + kernel="paddle.split", + inputs={"x": input.name}, + outputs=[ + "{}_p{}".format(node.layer_name, i) for i in range(len(size_splits)) + ], + num_or_sections=size_splits, + axis=dim) def Slice(self, node): input = self.graph.get_input_node(node, 0) diff --git a/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py b/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py index f3af6b08db3a7e9f817a47b26db389b6e279bf4c..6a53b1db29959a8cf7088347647db092b24f458c 100644 --- a/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py @@ -74,7 +74,7 @@ class DygraphTFBatchNormFuser(FuseBase): inputs={}, outputs=[gen_name(8)]) pattern.add_layer( - "fluid.layers.elementwise_sub", + "paddle.subtract", inputs={"x": gen_name(8), "y": gen_name(7)}, outputs=[gen_name(9)]) pattern.add_layer( @@ -131,7 +131,7 @@ class DygraphTFBatchNormFuser(FuseBase): inputs={}, outputs=[gen_name(8)]) pattern.add_layer( - "fluid.layers.elementwise_sub", + "paddle.subtract", inputs={"x": gen_name(8), "y": gen_name(7)}, outputs=[gen_name(9)]) pattern.add_layer( @@ -180,7 +180,7 @@ class DygraphTFBatchNormFuser(FuseBase): if matches[out_layer_id].kernel == "paddle.multiply": gamma_layer_id = graph.edges_in[out_layer_id][1] gamma_layer = matches[gamma_layer_id] - if layer.kernel == "fluid.layers.elementwise_sub": + if layer.kernel == "paddle.subtract": in_layer_id = graph.edges_in[layer_id][0] beta_layer = matches[in_layer_id] in_layer_id = graph.edges_in[layer_id][1] diff --git a/x2paddle/optimizer/fusion/static/tf_batchnorm_fuser.py b/x2paddle/optimizer/fusion/static/tf_batchnorm_fuser.py index 796556bb96908be48fb4eca654c054b821da575f..1299b34d7664c10cc078679605932719ae0d9d11 100644 --- a/x2paddle/optimizer/fusion/static/tf_batchnorm_fuser.py +++ b/x2paddle/optimizer/fusion/static/tf_batchnorm_fuser.py @@ -73,7 +73,7 @@ class StaticTFBatchNormFuser(FuseBase): inputs={}, outputs=[gen_name(8)]) pattern.add_layer( - "fluid.layers.elementwise_sub", + "paddle.subtract", inputs={"x": gen_name(8), "y": gen_name(7)}, outputs=[gen_name(9)]) pattern.add_layer( @@ -130,7 +130,7 @@ class StaticTFBatchNormFuser(FuseBase): inputs={}, outputs=[gen_name(8)]) pattern.add_layer( - "fluid.layers.elementwise_sub", + "paddle.subtract", inputs={"x": gen_name(8), "y": gen_name(7)}, outputs=[gen_name(9)]) pattern.add_layer( @@ -179,7 +179,7 @@ class StaticTFBatchNormFuser(FuseBase): if matches[out_layer_id].kernel == "paddle.multiply": gamma_layer_id = graph.edges_in[out_layer_id][1] gamma_layer = matches[gamma_layer_id] - if layer.kernel == "fluid.layers.elementwise_sub": + if layer.kernel == "paddle.subtract": in_layer_id = graph.edges_in[layer_id][0] beta_layer = matches[in_layer_id] in_layer_id = graph.edges_in[layer_id][1]