提交 934ee6a8 编写于 作者: S SunAhong1993

renam

上级 3786c539
...@@ -12,23 +12,23 @@ ...@@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .adaptive_pool2d_fuser import Dygraph_AdaptivePool2dFuser from .adaptive_pool2d_fuser import DygraphAdaptivePool2dFuser
from .adaptive_pool2d_fuse_pass import Dygraph_AdaptivePool2dFusePass from .adaptive_pool2d_fuse_pass import DygraphAdaptivePool2dFusePass
from .batchnorm2d_fuser import Dygraph_BatchNorm2dFuser from .batchnorm2d_fuser import DygraphBatchNorm2dFuser
from .batchnorm2d_fuse_pass import Dygraph_BatchNorm2dFusePass from .batchnorm2d_fuse_pass import DygraphBatchNorm2dFusePass
from .bn_scale_fuser import Dygraph_BNScaleFuser from .bn_scale_fuser import DygraphBNScaleFuser
from .bn_scale_fuse_pass import Dygraph_BNScaleFusePass from .bn_scale_fuse_pass import DygraphBNScaleFusePass
from .constant_fuser import Dygraph_ConstantFuser from .constant_fuser import DygraphConstantFuser
from .constant_fuse_pass import Dygraph_ConstantFusePass from .constant_fuse_pass import DygraphConstantFusePass
from .conv2d_add_fuser import Dygraph_Conv2D_AddFuser from .conv2d_add_fuser import DygraphConv2DAddFuser
from .conv2d_add_fuse_pass import Dygraph_Conv2D_AddFusePass from .conv2d_add_fuse_pass import DygraphConv2DAddFusePass
from .dropout_fuser import Dygraph_DropoutFuser from .dropout_fuser import DygraphDropoutFuser
from .dropout_fuse_pass import Dygraph_DropoutFusePass from .dropout_fuse_pass import DygraphDropoutFusePass
from .fc_fuser import Dygraph_FcFuser from .fc_fuser import DygraphFcFuser
from .fc_fuse_pass import Dygraph_FcFusePass from .fc_fuse_pass import DygraphFcFusePass
from .interpolate_bilinear_fuser import Dygraph_InterpolateBilinearFuser from .interpolate_bilinear_fuser import DygraphInterpolateBilinearFuser
from .interpolate_bilinear_fuse_pass import Dygraph_InterpolateBilinearFusePass from .interpolate_bilinear_fuse_pass import DygraphInterpolateBilinearFusePass
from .reshape_fuser import Dygraph_ReshapeFuser from .reshape_fuser import DygraphReshapeFuser
from .reshape_fuse_pass import Dygraph_ReshapeFusePass from .reshape_fuse_pass import DygraphReshapeFusePass
from .tf_batchnorm_fuser import Dygraph_TF_BatchNormFuser from .tf_batchnorm_fuser import DygraphTFBatchNormFuser
from .tf_batchnorm_fuse_pass import Dygraph_TF_BatchNormFusePass from .tf_batchnorm_fuse_pass import DygraphTFBatchNormFusePass
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_AdaptivePool2dFuser from x2paddle.optimizer.fusion.dygraph import DygraphAdaptivePool2dFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_AdaptivePool2dFusePass(Pass): class DygraphAdaptivePool2dFusePass(Pass):
name = "dygraph_adaptive_pool2d_fuse_pass" name = "dygraph_adaptive_pool2d_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_AdaptivePool2dFuser() fuser = DygraphAdaptivePool2dFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
adaptive_pool2d_fuse_pass = Dygraph_AdaptivePool2dFusePass() adaptive_pool2d_fuse_pass = DygraphAdaptivePool2dFusePass()
...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_AdaptivePool2dFuser(FuseBase): class DygraphAdaptivePool2dFuser(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_AdaptivePool2dFuser, self).__init__(graph_type="dygraph") super(DygraphAdaptivePool2dFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的adaptive pool2d图结构。 """ 描述需要替换的adaptive pool2d图结构。
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_BatchNorm2dFuser from x2paddle.optimizer.fusion.dygraph import DygraphBatchNorm2dFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_BatchNorm2dFusePass(Pass): class DygraphBatchNorm2dFusePass(Pass):
name = "dygraph_batchnorm2d_fuse_pass" name = "dygraph_batchnorm2d_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_BatchNorm2dFuser() fuser = DygraphBatchNorm2dFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
batchnorm2d_fuse_pass = Dygraph_BatchNorm2dFusePass() batchnorm2d_fuse_pass = DygraphBatchNorm2dFusePass()
...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_BatchNorm2dFuser(FuseBase): class DygraphBatchNorm2dFuser(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_BatchNorm2dFuser, self).__init__(graph_type="dygraph") super(DygraphBatchNorm2dFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。 """ 描述需要替换的batchnorm2d图结构。
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_BNScaleFuser from x2paddle.optimizer.fusion.dygraph import DygraphBNScaleFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_BNScaleFusePass(Pass): class DygraphBNScaleFusePass(Pass):
name = "dygraph_bn_scale_fuse_pass" name = "dygraph_bn_scale_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_BNScaleFuser() fuser = DygraphBNScaleFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
bn_scale_fuse_pass = Dygraph_BNScaleFusePass() bn_scale_fuse_pass = DygraphBNScaleFusePass()
...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_BNScaleFuser(FuseBase): class DygraphBNScaleFuser(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_BNScaleFuser, self).__init__(graph_type="dygraph") super(DygraphBNScaleFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的batchnorm2d图结构。 """ 描述需要替换的batchnorm2d图结构。
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_ConstantFuser from x2paddle.optimizer.fusion.dygraph import DygraphConstantFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_ConstantFusePass(Pass): class DygraphConstantFusePass(Pass):
name = "dygraph_constant_fuse_pass" name = "dygraph_constant_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_ConstantFuser() fuser = DygraphConstantFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
constant_fuse_pass = Dygraph_ConstantFuser() constant_fuse_pass = DygraphConstantFuser()
...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_ConstantFuser(FuseBase): class DygraphConstantFuser(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_ConstantFuser, self).__init__(graph_type="dygraph") super(DygraphConstantFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的constant图结构。 """ 描述需要替换的constant图结构。
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_Conv2D_AddFuser from x2paddle.optimizer.fusion.dygraph import DygraphConv2DAddFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_Conv2D_AddFusePass(Pass): class DygraphConv2DAddFusePass(Pass):
name = "dygraph_conv2d_add_fuse_pass" name = "dygraph_conv2d_add_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_Conv2D_AddFuser() fuser = DygraphConv2DAddFuser()
fuser.operate(graph, match_kind="edge") fuser.operate(graph, match_kind="edge")
# 用于注册 # 用于注册
dygraph_conv2d_add_fuse_pass = Dygraph_Conv2D_AddFusePass() dygraph_conv2d_add_fuse_pass = DygraphConv2DAddFusePass()
\ No newline at end of file \ No newline at end of file
...@@ -19,9 +19,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -19,9 +19,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_Conv2D_AddFuser(FuseBase): class DygraphConv2DAddFuser(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_Conv2D_AddFuser, self).__init__(graph_type="dygraph") super(DygraphConv2DAddFuser, self).__init__(graph_type="dygraph")
self.patterns = list() self.patterns = list()
def build_pattern(self): def build_pattern(self):
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_DropoutFuser from x2paddle.optimizer.fusion.dygraph import DygraphDropoutFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_DropoutFusePass(Pass): class DygraphDropoutFusePass(Pass):
name = "dygraph_dropout_fuse_pass" name = "dygraph_dropout_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_DropoutFuser() fuser = DygraphDropoutFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
dropout_fuse_pass = Dygraph_DropoutFuser() dropout_fuse_pass = DygraphDropoutFuser()
...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_DropoutFuser(FuseBase): class DygraphDropoutFuser(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_DropoutFuser, self).__init__(graph_type="dygraph") super(DygraphDropoutFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的constant图结构。 """ 描述需要替换的constant图结构。
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_FcFuser from x2paddle.optimizer.fusion.dygraph import DygraphFcFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_FcFusePass(Pass): class DygraphFcFusePass(Pass):
name = "dygraph_fc_fuse_pass" name = "dygraph_fc_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_FcFuser() fuser = DygraphFcFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
fc_fuse_pass = Dygraph_FcFusePass() fc_fuse_pass = DygraphFcFusePass()
...@@ -18,10 +18,10 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -18,10 +18,10 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_FcFuser(FuseBase): class DygraphFcFuser(FuseBase):
def __init__(self): def __init__(self):
self.linear_index = 0 self.linear_index = 0
super(Dygraph_FcFuser, self).__init__(graph_type="dygraph") super(DygraphFcFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的fc图结构。 """ 描述需要替换的fc图结构。
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_InterpolateBilinearFuser from x2paddle.optimizer.fusion.dygraph import DygraphInterpolateBilinearFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_InterpolateBilinearFusePass(Pass): class DygraphInterpolateBilinearFusePass(Pass):
name = "dygraph_interpolate_bilinear_fuse_pass" name = "dygraph_interpolate_bilinear_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_InterpolateBilinearFuser() fuser = DygraphInterpolateBilinearFuser()
fuser.operate(graph, match_kind="topo") fuser.operate(graph, match_kind="topo")
# 用于注册 # 用于注册
interpolate_bilinear_fuse_pass = Dygraph_InterpolateBilinearFusePass() interpolate_bilinear_fuse_pass = DygraphInterpolateBilinearFusePass()
...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_InterpolateBilinearFuser(FuseBase): class DygraphInterpolateBilinearFuser(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_InterpolateBilinearFuser, self).__init__(graph_type="dygraph") super(DygraphInterpolateBilinearFuser, self).__init__(graph_type="dygraph")
import torch import torch
torch_version = torch.__version__ torch_version = torch.__version__
torch_version_part = torch_version.split(".") torch_version_part = torch_version.split(".")
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_ReshapeFuser from x2paddle.optimizer.fusion.dygraph import DygraphReshapeFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_ReshapeFusePass(Pass): class DygraphReshapeFusePass(Pass):
name = "dygraph_reshape_fuse_pass" name = "dygraph_reshape_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_ReshapeFuser() fuser = DygraphReshapeFuser()
fuser.operate(graph, match_kind="edge") fuser.operate(graph, match_kind="edge")
# 用于注册 # 用于注册
reshape_fuse_pass = Dygraph_ReshapeFusePass() reshape_fuse_pass = DygraphReshapeFusePass()
...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_ReshapeFuser(FuseBase): class DygraphReshapeFuser(FuseBase):
def __init__(self): def __init__(self):
super(Dygraph_ReshapeFuser, self).__init__(graph_type="dygraph") super(DygraphReshapeFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的reshape图结构。 """ 描述需要替换的reshape图结构。
......
...@@ -13,21 +13,21 @@ ...@@ -13,21 +13,21 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.pass_ import Pass from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import Dygraph_TF_BatchNormFuser from x2paddle.optimizer.fusion.dygraph import DygraphTFBatchNormFuser
from x2paddle.optimizer.pass_manager import pass_register from x2paddle.optimizer.pass_manager import pass_register
@pass_register @pass_register
class Dygraph_TF_BatchNormFusePass(Pass): class DygraphTFBatchNormFusePass(Pass):
name = "dygraph_tf_batchnorm_fuse_pass" name = "dygraph_tf_batchnorm_fuse_pass"
def __init__(self): def __init__(self):
Pass.__init__(self) Pass.__init__(self)
def apply(self, graph): def apply(self, graph):
fuser = Dygraph_TF_BatchNormFuser() fuser = DygraphTFBatchNormFuser()
fuser.operate(graph, match_kind="edge") fuser.operate(graph, match_kind="edge")
# 用于注册 # 用于注册
dygraph_tf_batchnorm_fuse_pass = Dygraph_TF_BatchNormFusePass() dygraph_tf_batchnorm_fuse_pass = DygraphTFBatchNormFusePass()
\ No newline at end of file \ No newline at end of file
...@@ -20,10 +20,10 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer ...@@ -20,10 +20,10 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import * from x2paddle.core.util import *
class Dygraph_TF_BatchNormFuser(FuseBase): class DygraphTFBatchNormFuser(FuseBase):
def __init__(self): def __init__(self):
self.bn_index = 0 self.bn_index = 0
super(Dygraph_TF_BatchNormFuser, self).__init__(graph_type="dygraph") super(DygraphTFBatchNormFuser, self).__init__(graph_type="dygraph")
def build_pattern(self): def build_pattern(self):
""" 描述需要替换的batchnorm图结构。 """ 描述需要替换的batchnorm图结构。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册