Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
934ee6a8
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
934ee6a8
编写于
11月 13, 2020
作者:
S
SunAhong1993
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
renam
上级
3786c539
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
80 addition
and
80 deletion
+80
-80
x2paddle/optimizer/fusion/dygraph/__init__.py
x2paddle/optimizer/fusion/dygraph/__init__.py
+20
-20
x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuse_pass.py
...dle/optimizer/fusion/dygraph/adaptive_pool2d_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuser.py
x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuse_pass.py
x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuser.py
x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/bn_scale_fuse_pass.py
x2paddle/optimizer/fusion/dygraph/bn_scale_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/bn_scale_fuser.py
x2paddle/optimizer/fusion/dygraph/bn_scale_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/constant_fuse_pass.py
x2paddle/optimizer/fusion/dygraph/constant_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/constant_fuser.py
x2paddle/optimizer/fusion/dygraph/constant_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/conv2d_add_fuse_pass.py
x2paddle/optimizer/fusion/dygraph/conv2d_add_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/conv2d_add_fuser.py
x2paddle/optimizer/fusion/dygraph/conv2d_add_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/dropout_fuse_pass.py
x2paddle/optimizer/fusion/dygraph/dropout_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/dropout_fuser.py
x2paddle/optimizer/fusion/dygraph/dropout_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/fc_fuse_pass.py
x2paddle/optimizer/fusion/dygraph/fc_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/fc_fuser.py
x2paddle/optimizer/fusion/dygraph/fc_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuse_pass.py
...ptimizer/fusion/dygraph/interpolate_bilinear_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py
...le/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/reshape_fuse_pass.py
x2paddle/optimizer/fusion/dygraph/reshape_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/reshape_fuser.py
x2paddle/optimizer/fusion/dygraph/reshape_fuser.py
+2
-2
x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuse_pass.py
x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuse_pass.py
+4
-4
x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py
x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py
+2
-2
未找到文件。
x2paddle/optimizer/fusion/dygraph/__init__.py
浏览文件 @
934ee6a8
...
...
@@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.adaptive_pool2d_fuser
import
Dygraph
_
AdaptivePool2dFuser
from
.adaptive_pool2d_fuse_pass
import
Dygraph
_
AdaptivePool2dFusePass
from
.batchnorm2d_fuser
import
Dygraph
_
BatchNorm2dFuser
from
.batchnorm2d_fuse_pass
import
Dygraph
_
BatchNorm2dFusePass
from
.bn_scale_fuser
import
Dygraph
_
BNScaleFuser
from
.bn_scale_fuse_pass
import
Dygraph
_
BNScaleFusePass
from
.constant_fuser
import
Dygraph
_
ConstantFuser
from
.constant_fuse_pass
import
Dygraph
_
ConstantFusePass
from
.conv2d_add_fuser
import
Dygraph
_Conv2D_
AddFuser
from
.conv2d_add_fuse_pass
import
Dygraph
_Conv2D_
AddFusePass
from
.dropout_fuser
import
Dygraph
_
DropoutFuser
from
.dropout_fuse_pass
import
Dygraph
_
DropoutFusePass
from
.fc_fuser
import
Dygraph
_
FcFuser
from
.fc_fuse_pass
import
Dygraph
_
FcFusePass
from
.interpolate_bilinear_fuser
import
Dygraph
_
InterpolateBilinearFuser
from
.interpolate_bilinear_fuse_pass
import
Dygraph
_
InterpolateBilinearFusePass
from
.reshape_fuser
import
Dygraph
_
ReshapeFuser
from
.reshape_fuse_pass
import
Dygraph
_
ReshapeFusePass
from
.tf_batchnorm_fuser
import
Dygraph
_TF_
BatchNormFuser
from
.tf_batchnorm_fuse_pass
import
Dygraph
_TF_
BatchNormFusePass
from
.adaptive_pool2d_fuser
import
DygraphAdaptivePool2dFuser
from
.adaptive_pool2d_fuse_pass
import
DygraphAdaptivePool2dFusePass
from
.batchnorm2d_fuser
import
DygraphBatchNorm2dFuser
from
.batchnorm2d_fuse_pass
import
DygraphBatchNorm2dFusePass
from
.bn_scale_fuser
import
DygraphBNScaleFuser
from
.bn_scale_fuse_pass
import
DygraphBNScaleFusePass
from
.constant_fuser
import
DygraphConstantFuser
from
.constant_fuse_pass
import
DygraphConstantFusePass
from
.conv2d_add_fuser
import
Dygraph
Conv2D
AddFuser
from
.conv2d_add_fuse_pass
import
Dygraph
Conv2D
AddFusePass
from
.dropout_fuser
import
DygraphDropoutFuser
from
.dropout_fuse_pass
import
DygraphDropoutFusePass
from
.fc_fuser
import
DygraphFcFuser
from
.fc_fuse_pass
import
DygraphFcFusePass
from
.interpolate_bilinear_fuser
import
DygraphInterpolateBilinearFuser
from
.interpolate_bilinear_fuse_pass
import
DygraphInterpolateBilinearFusePass
from
.reshape_fuser
import
DygraphReshapeFuser
from
.reshape_fuse_pass
import
DygraphReshapeFusePass
from
.tf_batchnorm_fuser
import
Dygraph
TF
BatchNormFuser
from
.tf_batchnorm_fuse_pass
import
Dygraph
TF
BatchNormFusePass
x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_
AdaptivePool2dFuser
from
x2paddle.optimizer.fusion.dygraph
import
DygraphAdaptivePool2dFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_
AdaptivePool2dFusePass
(
Pass
):
class
DygraphAdaptivePool2dFusePass
(
Pass
):
name
=
"dygraph_adaptive_pool2d_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_
AdaptivePool2dFuser
()
fuser
=
DygraphAdaptivePool2dFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"topo"
)
# 用于注册
adaptive_pool2d_fuse_pass
=
Dygraph
_
AdaptivePool2dFusePass
()
adaptive_pool2d_fuse_pass
=
DygraphAdaptivePool2dFusePass
()
x2paddle/optimizer/fusion/dygraph/adaptive_pool2d_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_
AdaptivePool2dFuser
(
FuseBase
):
class
DygraphAdaptivePool2dFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
Dygraph
_
AdaptivePool2dFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
DygraphAdaptivePool2dFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
def
build_pattern
(
self
):
""" 描述需要替换的adaptive pool2d图结构。
...
...
x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_
BatchNorm2dFuser
from
x2paddle.optimizer.fusion.dygraph
import
DygraphBatchNorm2dFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_
BatchNorm2dFusePass
(
Pass
):
class
DygraphBatchNorm2dFusePass
(
Pass
):
name
=
"dygraph_batchnorm2d_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_
BatchNorm2dFuser
()
fuser
=
DygraphBatchNorm2dFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"topo"
)
# 用于注册
batchnorm2d_fuse_pass
=
Dygraph
_
BatchNorm2dFusePass
()
batchnorm2d_fuse_pass
=
DygraphBatchNorm2dFusePass
()
x2paddle/optimizer/fusion/dygraph/batchnorm2d_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_
BatchNorm2dFuser
(
FuseBase
):
class
DygraphBatchNorm2dFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
Dygraph
_
BatchNorm2dFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
DygraphBatchNorm2dFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
def
build_pattern
(
self
):
""" 描述需要替换的batchnorm2d图结构。
...
...
x2paddle/optimizer/fusion/dygraph/bn_scale_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_
BNScaleFuser
from
x2paddle.optimizer.fusion.dygraph
import
DygraphBNScaleFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_
BNScaleFusePass
(
Pass
):
class
DygraphBNScaleFusePass
(
Pass
):
name
=
"dygraph_bn_scale_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_
BNScaleFuser
()
fuser
=
DygraphBNScaleFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"topo"
)
# 用于注册
bn_scale_fuse_pass
=
Dygraph
_
BNScaleFusePass
()
bn_scale_fuse_pass
=
DygraphBNScaleFusePass
()
x2paddle/optimizer/fusion/dygraph/bn_scale_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_
BNScaleFuser
(
FuseBase
):
class
DygraphBNScaleFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
Dygraph
_
BNScaleFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
DygraphBNScaleFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
def
build_pattern
(
self
):
""" 描述需要替换的batchnorm2d图结构。
...
...
x2paddle/optimizer/fusion/dygraph/constant_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_
ConstantFuser
from
x2paddle.optimizer.fusion.dygraph
import
DygraphConstantFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_
ConstantFusePass
(
Pass
):
class
DygraphConstantFusePass
(
Pass
):
name
=
"dygraph_constant_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_
ConstantFuser
()
fuser
=
DygraphConstantFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"topo"
)
# 用于注册
constant_fuse_pass
=
Dygraph
_
ConstantFuser
()
constant_fuse_pass
=
DygraphConstantFuser
()
x2paddle/optimizer/fusion/dygraph/constant_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_
ConstantFuser
(
FuseBase
):
class
DygraphConstantFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
Dygraph
_
ConstantFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
DygraphConstantFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
def
build_pattern
(
self
):
""" 描述需要替换的constant图结构。
...
...
x2paddle/optimizer/fusion/dygraph/conv2d_add_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_Conv2D_
AddFuser
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
Conv2D
AddFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_Conv2D_
AddFusePass
(
Pass
):
class
Dygraph
Conv2D
AddFusePass
(
Pass
):
name
=
"dygraph_conv2d_add_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_Conv2D_
AddFuser
()
fuser
=
Dygraph
Conv2D
AddFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"edge"
)
# 用于注册
dygraph_conv2d_add_fuse_pass
=
Dygraph_Conv2D_AddFusePass
()
\ No newline at end of file
dygraph_conv2d_add_fuse_pass
=
DygraphConv2DAddFusePass
()
\ No newline at end of file
x2paddle/optimizer/fusion/dygraph/conv2d_add_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -19,9 +19,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_Conv2D_
AddFuser
(
FuseBase
):
class
Dygraph
Conv2D
AddFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
Dygraph
_Conv2D_
AddFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
Dygraph
Conv2D
AddFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
self
.
patterns
=
list
()
def
build_pattern
(
self
):
...
...
x2paddle/optimizer/fusion/dygraph/dropout_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_
DropoutFuser
from
x2paddle.optimizer.fusion.dygraph
import
DygraphDropoutFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_
DropoutFusePass
(
Pass
):
class
DygraphDropoutFusePass
(
Pass
):
name
=
"dygraph_dropout_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_
DropoutFuser
()
fuser
=
DygraphDropoutFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"topo"
)
# 用于注册
dropout_fuse_pass
=
Dygraph
_
DropoutFuser
()
dropout_fuse_pass
=
DygraphDropoutFuser
()
x2paddle/optimizer/fusion/dygraph/dropout_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_
DropoutFuser
(
FuseBase
):
class
DygraphDropoutFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
Dygraph
_
DropoutFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
DygraphDropoutFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
def
build_pattern
(
self
):
""" 描述需要替换的constant图结构。
...
...
x2paddle/optimizer/fusion/dygraph/fc_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_
FcFuser
from
x2paddle.optimizer.fusion.dygraph
import
DygraphFcFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_
FcFusePass
(
Pass
):
class
DygraphFcFusePass
(
Pass
):
name
=
"dygraph_fc_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_
FcFuser
()
fuser
=
DygraphFcFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"topo"
)
# 用于注册
fc_fuse_pass
=
Dygraph
_
FcFusePass
()
fc_fuse_pass
=
DygraphFcFusePass
()
x2paddle/optimizer/fusion/dygraph/fc_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -18,10 +18,10 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_
FcFuser
(
FuseBase
):
class
DygraphFcFuser
(
FuseBase
):
def
__init__
(
self
):
self
.
linear_index
=
0
super
(
Dygraph
_
FcFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
DygraphFcFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
def
build_pattern
(
self
):
""" 描述需要替换的fc图结构。
...
...
x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_
InterpolateBilinearFuser
from
x2paddle.optimizer.fusion.dygraph
import
DygraphInterpolateBilinearFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_
InterpolateBilinearFusePass
(
Pass
):
class
DygraphInterpolateBilinearFusePass
(
Pass
):
name
=
"dygraph_interpolate_bilinear_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_
InterpolateBilinearFuser
()
fuser
=
DygraphInterpolateBilinearFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"topo"
)
# 用于注册
interpolate_bilinear_fuse_pass
=
Dygraph
_
InterpolateBilinearFusePass
()
interpolate_bilinear_fuse_pass
=
DygraphInterpolateBilinearFusePass
()
x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_
InterpolateBilinearFuser
(
FuseBase
):
class
DygraphInterpolateBilinearFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
Dygraph
_
InterpolateBilinearFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
DygraphInterpolateBilinearFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
import
torch
torch_version
=
torch
.
__version__
torch_version_part
=
torch_version
.
split
(
"."
)
...
...
x2paddle/optimizer/fusion/dygraph/reshape_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_
ReshapeFuser
from
x2paddle.optimizer.fusion.dygraph
import
DygraphReshapeFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_
ReshapeFusePass
(
Pass
):
class
DygraphReshapeFusePass
(
Pass
):
name
=
"dygraph_reshape_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_
ReshapeFuser
()
fuser
=
DygraphReshapeFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"edge"
)
# 用于注册
reshape_fuse_pass
=
Dygraph
_
ReshapeFusePass
()
reshape_fuse_pass
=
DygraphReshapeFusePass
()
x2paddle/optimizer/fusion/dygraph/reshape_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -18,9 +18,9 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_
ReshapeFuser
(
FuseBase
):
class
DygraphReshapeFuser
(
FuseBase
):
def
__init__
(
self
):
super
(
Dygraph
_
ReshapeFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
DygraphReshapeFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
def
build_pattern
(
self
):
""" 描述需要替换的reshape图结构。
...
...
x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuse_pass.py
浏览文件 @
934ee6a8
...
...
@@ -13,21 +13,21 @@
# limitations under the License.
from
x2paddle.optimizer.pass_
import
Pass
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
_TF_
BatchNormFuser
from
x2paddle.optimizer.fusion.dygraph
import
Dygraph
TF
BatchNormFuser
from
x2paddle.optimizer.pass_manager
import
pass_register
@
pass_register
class
Dygraph
_TF_
BatchNormFusePass
(
Pass
):
class
Dygraph
TF
BatchNormFusePass
(
Pass
):
name
=
"dygraph_tf_batchnorm_fuse_pass"
def
__init__
(
self
):
Pass
.
__init__
(
self
)
def
apply
(
self
,
graph
):
fuser
=
Dygraph
_TF_
BatchNormFuser
()
fuser
=
Dygraph
TF
BatchNormFuser
()
fuser
.
operate
(
graph
,
match_kind
=
"edge"
)
# 用于注册
dygraph_tf_batchnorm_fuse_pass
=
Dygraph_TF_BatchNormFusePass
()
\ No newline at end of file
dygraph_tf_batchnorm_fuse_pass
=
DygraphTFBatchNormFusePass
()
\ No newline at end of file
x2paddle/optimizer/fusion/dygraph/tf_batchnorm_fuser.py
浏览文件 @
934ee6a8
...
...
@@ -20,10 +20,10 @@ from x2paddle.core.program import PaddleGraph, PaddleLayer
from
x2paddle.core.util
import
*
class
Dygraph
_TF_
BatchNormFuser
(
FuseBase
):
class
Dygraph
TF
BatchNormFuser
(
FuseBase
):
def
__init__
(
self
):
self
.
bn_index
=
0
super
(
Dygraph
_TF_
BatchNormFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
super
(
Dygraph
TF
BatchNormFuser
,
self
).
__init__
(
graph_type
=
"dygraph"
)
def
build_pattern
(
self
):
""" 描述需要替换的batchnorm图结构。
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录