diff --git a/x2paddle/optimizer/fusion/dygraph/__init__.py b/x2paddle/optimizer/fusion/dygraph/__init__.py index 4bbcf483e4a30b73b780da37d852c1579153f6e5..86079ec6ebd2fb511a37788f45c5fa95d75fd399 100644 --- a/x2paddle/optimizer/fusion/dygraph/__init__.py +++ b/x2paddle/optimizer/fusion/dygraph/__init__.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .adaptive_pool2d_fuser import Dygraph_AdaptivePool2dFuser -from .adaptive_pool2d_fuse_pass import Dygraph_AdaptivePool2dFusePass -from .batchnorm2d_fuser import Dygraph_BatchNorm2dFuser -from .batchnorm2d_fuse_pass import Dygraph_BatchNorm2dFusePass -from .bn_scale_fuser import Dygraph_BNScaleFuser -from .bn_scale_fuse_pass import Dygraph_BNScaleFusePass -from .constant_fuser import Dygraph_ConstantFuser -from .constant_fuse_pass import Dygraph_ConstantFusePass -from .conv2d_add_fuser import Dygraph_Conv2D_AddFuser -from .conv2d_add_fuse_pass import Dygraph_Conv2D_AddFusePass -from .dropout_fuser import Dygraph_DropoutFuser -from .dropout_fuse_pass import Dygraph_DropoutFusePass -from .fc_fuser import Dygraph_FcFuser -from .fc_fuse_pass import Dygraph_FcFusePass -from .interpolate_bilinear_fuser import Dygraph_InterpolateBilinearFuser -from .interpolate_bilinear_fuse_pass import Dygraph_InterpolateBilinearFusePass -from .reshape_fuser import Dygraph_ReshapeFuser -from .reshape_fuse_pass import Dygraph_ReshapeFusePass -from .tf_batchnorm_fuser import Dygraph_TF_BatchNormFuser -from .tf_batchnorm_fuse_pass import Dygraph_TF_BatchNormFusePass +from .adaptive_pool2d_fuser import DygraphAdaptivePool2dFuser +from .adaptive_pool2d_fuse_pass import DygraphAdaptivePool2dFusePass +from .batchnorm2d_fuser import DygraphBatchNorm2dFuser +from .batchnorm2d_fuse_pass import DygraphBatchNorm2dFusePass +from .bn_scale_fuser import DygraphBNScaleFuser +from .bn_scale_fuse_pass import DygraphBNScaleFusePass +from .constant_fuser import DygraphConstantFuser +from .constant_fuse_pass import DygraphConstantFusePass +from .conv2d_add_fuser import DygraphConv2DAddFuser +from .conv2d_add_fuse_pass import DygraphConv2DAddFusePass +from .dropout_fuser import DygraphDropoutFuser +from .dropout_fuse_pass import DygraphDropoutFusePass +from .fc_fuser import DygraphFcFuser +from .fc_fuse_pass import DygraphFcFusePass +from .interpolate_bilinear_fuser import DygraphInterpolateBilinearFuser +from .interpolate_bilinear_fuse_pass import DygraphInterpolateBilinearFusePass +from .reshape_fuser import DygraphReshapeFuser +from .reshape_fuse_pass import DygraphReshapeFusePass +from .tf_batchnorm_fuser import DygraphTFBatchNormFuser +from .tf_batchnorm_fuse_pass import DygraphTFBatchNormFusePass diff --git a/x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuse_pass.py index 14e50fd6c26c736b82a3bb78de46ce870b83e715..f47874e2ce8a17ec88292d61eb333ded0a8a2b16 100644 --- a/x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_AdaptivePool2dFuser +from x2paddle.optimizer.fusion.dygraph import DygraphAdaptivePool2dFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_AdaptivePool2dFusePass(Pass): +class DygraphAdaptivePool2dFusePass(Pass): name = "dygraph_adaptive_pool2d_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_AdaptivePool2dFuser() + fuser = DygraphAdaptivePool2dFuser() fuser.operate(graph, match_kind="topo") # 用于注册 -adaptive_pool2d_fuse_pass = Dygraph_AdaptivePool2dFusePass() +adaptive_pool2d_fuse_pass = DygraphAdaptivePool2dFusePass() diff --git a/x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuser.py b/x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuser.py index 95fe1b599ad3fe4efe7d02846eee47de8c5afe4b..12d8ecda423d2ffcc788dc80cfb9a68e1d0cdeea 100644 --- a/x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuser.py @@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_AdaptivePool2dFuser(FuseBase): +class DygraphAdaptivePool2dFuser(FuseBase): def __init__(self): - super(Dygraph_AdaptivePool2dFuser, self).__init__(graph_type="dygraph") + super(DygraphAdaptivePool2dFuser, self).__init__(graph_type="dygraph") def build_pattern(self): """ 描述需要替换的adaptive pool2d图结构。 diff --git a/x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuse_pass.py index 2f738213a83fdaba1e2efebf821a34d6215b2a85..7184d81d71554440ee7265fac18d70dce1a24d6b 100644 --- a/x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_BatchNorm2dFuser +from x2paddle.optimizer.fusion.dygraph import DygraphBatchNorm2dFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_BatchNorm2dFusePass(Pass): +class DygraphBatchNorm2dFusePass(Pass): name = "dygraph_batchnorm2d_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_BatchNorm2dFuser() + fuser = DygraphBatchNorm2dFuser() fuser.operate(graph, match_kind="topo") # 用于注册 -batchnorm2d_fuse_pass = Dygraph_BatchNorm2dFusePass() +batchnorm2d_fuse_pass = DygraphBatchNorm2dFusePass() diff --git a/x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuser.py b/x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuser.py index 3de0bb9c88eccded5e3b1ae4c1a6a147c8e4ae1f..24dd27c3c3bf3e1b05d604e34e85f9039e4d5994 100644 --- a/x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuser.py @@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_BatchNorm2dFuser(FuseBase): +class DygraphBatchNorm2dFuser(FuseBase): def __init__(self): - super(Dygraph_BatchNorm2dFuser, self).__init__(graph_type="dygraph") + super(DygraphBatchNorm2dFuser, self).__init__(graph_type="dygraph") def build_pattern(self): """ 描述需要替换的batchnorm2d图结构。 diff --git a/x2paddle/optimizer/fusion/dygraph/bn_scale_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/bn_scale_fuse_pass.py index 72b3d90a00b3c29a9a2dcdffb04164b9a19aa003..29a2b4fe3091efd864a33706a24e192326fbbe7d 100644 --- a/x2paddle/optimizer/fusion/dygraph/bn_scale_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/bn_scale_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_BNScaleFuser +from x2paddle.optimizer.fusion.dygraph import DygraphBNScaleFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_BNScaleFusePass(Pass): +class DygraphBNScaleFusePass(Pass): name = "dygraph_bn_scale_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_BNScaleFuser() + fuser = DygraphBNScaleFuser() fuser.operate(graph, match_kind="topo") # 用于注册 -bn_scale_fuse_pass = Dygraph_BNScaleFusePass() +bn_scale_fuse_pass = DygraphBNScaleFusePass() diff --git a/x2paddle/optimizer/fusion/dygraph/bn_scale_fuser.py b/x2paddle/optimizer/fusion/dygraph/bn_scale_fuser.py index 19214b89f2fca3a1e21afd6d7a0e73cc2faa4cdb..5b093d1b6b40871637f169dc858dd16d8e51a413 100644 --- a/x2paddle/optimizer/fusion/dygraph/bn_scale_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/bn_scale_fuser.py @@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_BNScaleFuser(FuseBase): +class DygraphBNScaleFuser(FuseBase): def __init__(self): - super(Dygraph_BNScaleFuser, self).__init__(graph_type="dygraph") + super(DygraphBNScaleFuser, self).__init__(graph_type="dygraph") def build_pattern(self): """ 描述需要替换的batchnorm2d图结构。 diff --git a/x2paddle/optimizer/fusion/dygraph/constant_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/constant_fuse_pass.py index c708f4026328b1e854c7e829e853ba1ce83d5a74..46fef74b4d706a4e884b1c25a14a9edc32428013 100644 --- a/x2paddle/optimizer/fusion/dygraph/constant_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/constant_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_ConstantFuser +from x2paddle.optimizer.fusion.dygraph import DygraphConstantFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_ConstantFusePass(Pass): +class DygraphConstantFusePass(Pass): name = "dygraph_constant_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_ConstantFuser() + fuser = DygraphConstantFuser() fuser.operate(graph, match_kind="topo") # 用于注册 -constant_fuse_pass = Dygraph_ConstantFuser() +constant_fuse_pass = DygraphConstantFuser() diff --git a/x2paddle/optimizer/fusion/dygraph/constant_fuser.py b/x2paddle/optimizer/fusion/dygraph/constant_fuser.py index fa4bd05454cc2f4e39279607e061e95b7ef74920..904ff83a5995ccdbc79a5f2d4f3cd73f7c4021c9 100644 --- a/x2paddle/optimizer/fusion/dygraph/constant_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/constant_fuser.py @@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_ConstantFuser(FuseBase): +class DygraphConstantFuser(FuseBase): def __init__(self): - super(Dygraph_ConstantFuser, self).__init__(graph_type="dygraph") + super(DygraphConstantFuser, self).__init__(graph_type="dygraph") def build_pattern(self): """ 描述需要替换的constant图结构。 diff --git a/x2paddle/optimizer/fusion/dygraph/conv2d_add_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/conv2d_add_fuse_pass.py index 8d796e94f1e328355657301b99def2d6ff603287..10b2663dd216ae48b315c4a3a5059fb26f0ddcd4 100644 --- a/x2paddle/optimizer/fusion/dygraph/conv2d_add_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/conv2d_add_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_Conv2D_AddFuser +from x2paddle.optimizer.fusion.dygraph import DygraphConv2DAddFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_Conv2D_AddFusePass(Pass): +class DygraphConv2DAddFusePass(Pass): name = "dygraph_conv2d_add_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_Conv2D_AddFuser() + fuser = DygraphConv2DAddFuser() fuser.operate(graph, match_kind="edge") # 用于注册 -dygraph_conv2d_add_fuse_pass = Dygraph_Conv2D_AddFusePass() \ No newline at end of file +dygraph_conv2d_add_fuse_pass = DygraphConv2DAddFusePass() \ No newline at end of file diff --git a/x2paddle/optimizer/fusion/dygraph/conv2d_add_fuser.py b/x2paddle/optimizer/fusion/dygraph/conv2d_add_fuser.py index 6d34621d64887f51964f5b18a3b2df115d857e1e..c7e715b45c69db9083f90b10d7dad5b3ae8072dd 100644 --- a/x2paddle/optimizer/fusion/dygraph/conv2d_add_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/conv2d_add_fuser.py @@ -19,9 +19,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_Conv2D_AddFuser(FuseBase): +class DygraphConv2DAddFuser(FuseBase): def __init__(self): - super(Dygraph_Conv2D_AddFuser, self).__init__(graph_type="dygraph") + super(DygraphConv2DAddFuser, self).__init__(graph_type="dygraph") self.patterns = list() def build_pattern(self): diff --git a/x2paddle/optimizer/fusion/dygraph/dropout_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/dropout_fuse_pass.py index c0954441a8c60ad13189e6e6810d3c90f215dbea..778b3c90424ed394e3d4759e39e1ee68a754ae20 100644 --- a/x2paddle/optimizer/fusion/dygraph/dropout_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/dropout_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_DropoutFuser +from x2paddle.optimizer.fusion.dygraph import DygraphDropoutFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_DropoutFusePass(Pass): +class DygraphDropoutFusePass(Pass): name = "dygraph_dropout_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_DropoutFuser() + fuser = DygraphDropoutFuser() fuser.operate(graph, match_kind="topo") # 用于注册 -dropout_fuse_pass = Dygraph_DropoutFuser() +dropout_fuse_pass = DygraphDropoutFuser() diff --git a/x2paddle/optimizer/fusion/dygraph/dropout_fuser.py b/x2paddle/optimizer/fusion/dygraph/dropout_fuser.py index 6d41f676628ed7ac1226546c631a98b8a5d9be85..1457b934e06aab4d92d18150266e0daaf3d1bf20 100644 --- a/x2paddle/optimizer/fusion/dygraph/dropout_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/dropout_fuser.py @@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_DropoutFuser(FuseBase): +class DygraphDropoutFuser(FuseBase): def __init__(self): - super(Dygraph_DropoutFuser, self).__init__(graph_type="dygraph") + super(DygraphDropoutFuser, self).__init__(graph_type="dygraph") def build_pattern(self): """ 描述需要替换的constant图结构。 diff --git a/x2paddle/optimizer/fusion/dygraph/fc_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/fc_fuse_pass.py index 54cce21c96738759b427b6e09efb0dbf47ecc6f6..8b89a7f62635107a8d6c88eef7e6b13b23ef9826 100644 --- a/x2paddle/optimizer/fusion/dygraph/fc_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/fc_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_FcFuser +from x2paddle.optimizer.fusion.dygraph import DygraphFcFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_FcFusePass(Pass): +class DygraphFcFusePass(Pass): name = "dygraph_fc_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_FcFuser() + fuser = DygraphFcFuser() fuser.operate(graph, match_kind="topo") # 用于注册 -fc_fuse_pass = Dygraph_FcFusePass() +fc_fuse_pass = DygraphFcFusePass() diff --git a/x2paddle/optimizer/fusion/dygraph/fc_fuser.py b/x2paddle/optimizer/fusion/dygraph/fc_fuser.py index 9c521428195b21ae573911396bd55269ee5ffe2c..b171bc28d2574beeda3b68b5c5195f601fb6da15 100644 --- a/x2paddle/optimizer/fusion/dygraph/fc_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/fc_fuser.py @@ -18,10 +18,10 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_FcFuser(FuseBase): +class DygraphFcFuser(FuseBase): def __init__(self): self.linear_index = 0 - super(Dygraph_FcFuser, self).__init__(graph_type="dygraph") + super(DygraphFcFuser, self).__init__(graph_type="dygraph") def build_pattern(self): """ 描述需要替换的fc图结构。 diff --git a/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuse_pass.py index cc2138c7dce9adaff176cda1f47bbefa7b3edb6b..c9eb7dfe0752fc6aec0a43d99755fad2b2affc18 100644 --- a/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_InterpolateBilinearFuser +from x2paddle.optimizer.fusion.dygraph import DygraphInterpolateBilinearFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_InterpolateBilinearFusePass(Pass): +class DygraphInterpolateBilinearFusePass(Pass): name = "dygraph_interpolate_bilinear_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_InterpolateBilinearFuser() + fuser = DygraphInterpolateBilinearFuser() fuser.operate(graph, match_kind="topo") # 用于注册 -interpolate_bilinear_fuse_pass = Dygraph_InterpolateBilinearFusePass() +interpolate_bilinear_fuse_pass = DygraphInterpolateBilinearFusePass() diff --git a/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py b/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py index 896db6bbfb6d24335eaa432e121856316801c3a7..358eca83284299432424d77e79c84dfac084d1f8 100644 --- a/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py @@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_InterpolateBilinearFuser(FuseBase): +class DygraphInterpolateBilinearFuser(FuseBase): def __init__(self): - super(Dygraph_InterpolateBilinearFuser, self).__init__(graph_type="dygraph") + super(DygraphInterpolateBilinearFuser, self).__init__(graph_type="dygraph") import torch torch_version = torch.__version__ torch_version_part = torch_version.split(".") diff --git a/x2paddle/optimizer/fusion/dygraph/reshape_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/reshape_fuse_pass.py index b240fae3da24980e6480271b3eac660bb38734cd..28cfda047c11d3b47ab10f1861cd769f37de267a 100644 --- a/x2paddle/optimizer/fusion/dygraph/reshape_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/reshape_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_ReshapeFuser +from x2paddle.optimizer.fusion.dygraph import DygraphReshapeFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_ReshapeFusePass(Pass): +class DygraphReshapeFusePass(Pass): name = "dygraph_reshape_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_ReshapeFuser() + fuser = DygraphReshapeFuser() fuser.operate(graph, match_kind="edge") # 用于注册 -reshape_fuse_pass = Dygraph_ReshapeFusePass() +reshape_fuse_pass = DygraphReshapeFusePass() diff --git a/x2paddle/optimizer/fusion/dygraph/reshape_fuser.py b/x2paddle/optimizer/fusion/dygraph/reshape_fuser.py index 357b3d5dfb5896b3030dab6c63c3e5d904053548..a5a68258da941a5da302051055b22d3eb8a65f90 100644 --- a/x2paddle/optimizer/fusion/dygraph/reshape_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/reshape_fuser.py @@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_ReshapeFuser(FuseBase): +class DygraphReshapeFuser(FuseBase): def __init__(self): - super(Dygraph_ReshapeFuser, self).__init__(graph_type="dygraph") + super(DygraphReshapeFuser, self).__init__(graph_type="dygraph") def build_pattern(self): """ 描述需要替换的reshape图结构。 diff --git a/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuse_pass.py index 52265fb598bab2ed2f15a9ac999e7e2374426349..5bfe144cebb691df4fd84c40714f82a8323195df 100644 --- a/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuse_pass.py +++ b/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuse_pass.py @@ -13,21 +13,21 @@ # limitations under the License. from x2paddle.optimizer.pass_ import Pass -from x2paddle.optimizer.fusion.dygraph import Dygraph_TF_BatchNormFuser +from x2paddle.optimizer.fusion.dygraph import DygraphTFBatchNormFuser from x2paddle.optimizer.pass_manager import pass_register @pass_register -class Dygraph_TF_BatchNormFusePass(Pass): +class DygraphTFBatchNormFusePass(Pass): name = "dygraph_tf_batchnorm_fuse_pass" def __init__(self): Pass.__init__(self) def apply(self, graph): - fuser = Dygraph_TF_BatchNormFuser() + fuser = DygraphTFBatchNormFuser() fuser.operate(graph, match_kind="edge") # 用于注册 -dygraph_tf_batchnorm_fuse_pass = Dygraph_TF_BatchNormFusePass() \ No newline at end of file +dygraph_tf_batchnorm_fuse_pass = DygraphTFBatchNormFusePass() \ No newline at end of file diff --git a/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py b/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py index e7391217c73401bd855ce1f523087f24bfd0253a..9dd0727f789455ade114aa2e3efdfca479b55888 100644 --- a/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py @@ -20,10 +20,10 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer from x2paddle.core.util import * -class Dygraph_TF_BatchNormFuser(FuseBase): +class DygraphTFBatchNormFuser(FuseBase): def __init__(self): self.bn_index = 0 - super(Dygraph_TF_BatchNormFuser, self).__init__(graph_type="dygraph") + super(DygraphTFBatchNormFuser, self).__init__(graph_type="dygraph") def build_pattern(self): """ 描述需要替换的batchnorm图结构。