Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4c160be2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4c160be2
编写于
11月 16, 2021
作者:
Z
Zeng Jinle
提交者:
GitHub
11月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine pass by removing CommOpt, CalcOpt, ParallelOpt (#37206)
上级
70b7c7ed
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
30 addition
and
32 deletion
+30
-32
python/paddle/distributed/passes/cpp_pass.py
python/paddle/distributed/passes/cpp_pass.py
+4
-1
python/paddle/distributed/passes/fuse_all_reduce.py
python/paddle/distributed/passes/fuse_all_reduce.py
+5
-2
python/paddle/distributed/passes/pass_base.py
python/paddle/distributed/passes/pass_base.py
+21
-29
未找到文件。
python/paddle/distributed/passes/cpp_pass.py
浏览文件 @
4c160be2
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.pass_base
import
CPPPassWrapper
,
register_pass
from
.pass_base
import
PassType
,
CPPPassWrapper
,
register_pass
@
register_pass
(
"fuse_elewise_add_act"
)
@
register_pass
(
"fuse_elewise_add_act"
)
...
@@ -23,3 +23,6 @@ class FuseElementwiseAddActPass(CPPPassWrapper):
...
@@ -23,3 +23,6 @@ class FuseElementwiseAddActPass(CPPPassWrapper):
@
property
@
property
def
cpp_name
(
self
):
def
cpp_name
(
self
):
return
"fuse_elewise_add_act_pass"
return
"fuse_elewise_add_act_pass"
def
_type
(
self
):
return
PassType
.
FUSION_OPT
python/paddle/distributed/passes/fuse_all_reduce.py
浏览文件 @
4c160be2
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
from
paddle.framework
import
core
from
paddle.framework
import
core
from
paddle.fluid
import
unique_name
from
paddle.fluid
import
unique_name
from
.pass_base
import
CommOptPass
,
register_pass
from
.pass_base
import
PassBase
,
PassType
,
register_pass
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
numpy
as
np
import
numpy
as
np
...
@@ -329,7 +329,7 @@ def insert_fuse_all_reduce_by_memory_size(block, groups, max_memory_size):
...
@@ -329,7 +329,7 @@ def insert_fuse_all_reduce_by_memory_size(block, groups, max_memory_size):
@
register_pass
(
"fuse_all_reduce"
)
@
register_pass
(
"fuse_all_reduce"
)
class
FuseAllReducePass
(
CommOptPass
):
class
FuseAllReducePass
(
PassBase
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
FuseAllReducePass
,
self
).
__init__
()
super
(
FuseAllReducePass
,
self
).
__init__
()
self
.
set_attr
(
"max_memory_size"
,
-
1
)
self
.
set_attr
(
"max_memory_size"
,
-
1
)
...
@@ -341,6 +341,9 @@ class FuseAllReducePass(CommOptPass):
...
@@ -341,6 +341,9 @@ class FuseAllReducePass(CommOptPass):
def
_check_conflict
(
self
,
other_pass
):
def
_check_conflict
(
self
,
other_pass
):
return
True
return
True
def
_type
(
self
):
return
PassType
.
COMM_OPT
# NOTE: why FuseAllReducePass can override apply_single_impl instead of
# NOTE: why FuseAllReducePass can override apply_single_impl instead of
# apply_impl? AllReduce is a collective operation, so the program of each
# apply_impl? AllReduce is a collective operation, so the program of each
# rank inside the same communication group should have the same
# rank inside the same communication group should have the same
...
...
python/paddle/distributed/passes/pass_base.py
浏览文件 @
4c160be2
...
@@ -40,9 +40,20 @@ class PassContext:
...
@@ -40,9 +40,20 @@ class PassContext:
del
self
.
_applied_passes
[
-
1
]
del
self
.
_applied_passes
[
-
1
]
class
PassType
:
UNKNOWN
=
0
COMM_OPT
=
1
CALC_OPT
=
2
PARALLEL_OPT
=
3
FUSION_OPT
=
4
class
PassBase
(
ABC
):
class
PassBase
(
ABC
):
_REGISTERED_PASSES
=
{}
_REGISTERED_PASSES
=
{}
_COMMON_RULES
=
[]
_COMMON_RULES
=
[]
# TODO(zengjinle): add white/black list
name
=
None
@
staticmethod
@
staticmethod
def
_register
(
pass_name
,
pass_class
):
def
_register
(
pass_name
,
pass_class
):
...
@@ -67,6 +78,9 @@ class PassBase(ABC):
...
@@ -67,6 +78,9 @@ class PassBase(ABC):
def
_check_conflict
(
self
,
other_pass
):
def
_check_conflict
(
self
,
other_pass
):
pass
pass
def
_type
(
self
):
return
PassType
.
UNKNOWN
def
_check_conflict_including_common_rules
(
self
,
other_pass
):
def
_check_conflict_including_common_rules
(
self
,
other_pass
):
return
self
.
_check_conflict
(
other_pass
)
and
all
(
return
self
.
_check_conflict
(
other_pass
)
and
all
(
[
r
(
other_pass
,
self
)
for
r
in
PassBase
.
_COMMON_RULES
])
[
r
(
other_pass
,
self
)
for
r
in
PassBase
.
_COMMON_RULES
])
...
@@ -142,40 +156,18 @@ class CPPPassWrapper(PassBase):
...
@@ -142,40 +156,18 @@ class CPPPassWrapper(PassBase):
self
.
_attrs
,
self
.
cpp_attr_types
)
self
.
_attrs
,
self
.
cpp_attr_types
)
# Like AutoParallel/HybridParallel, etc.
def
_fusion_opt_last_rule
(
pass_before
,
pass_after
):
class
ParallelOptPass
(
PassBase
):
if
pass_before
.
_type
()
==
PassType
.
FUSION_OPT
and
pass_after
.
_type
(
def
__init__
(
self
):
)
!=
PassType
.
FUSION_OPT
:
super
(
ParallelOptPass
,
self
).
__init__
()
return
False
else
:
# Like AMP, Recompute, etc.
class
CalcOptPass
(
PassBase
):
def
__init__
(
self
):
super
(
CalcOptPass
,
self
).
__init__
()
# Like FuseAllReduce, FuseGradientMerge, etc.
class
CommOptPass
(
PassBase
):
def
__init__
(
self
):
super
(
CommOptPass
,
self
).
__init__
()
def
_make_pass_order_rule
(
pass_class_before
,
pass_class_after
):
def
impl
(
pass_obj_before
,
pass_obj_after
):
if
isinstance
(
pass_obj_before
,
pass_class_after
)
\
and
isinstance
(
pass_obj_after
,
pass_class_before
):
return
False
return
True
return
True
return
impl
PassBase
.
_COMMON_RULES
=
[
PassBase
.
_COMMON_RULES
=
[
_make_pass_order_rule
(
CalcOptPass
,
CommOptPass
),
_fusion_opt_last_rule
,
_make_pass_order_rule
(
ParallelOptPass
,
CPPPassWrapper
),
_make_pass_order_rule
(
CalcOptPass
,
CPPPassWrapper
),
_make_pass_order_rule
(
CommOptPass
,
CPPPassWrapper
),
lambda
pass_before
,
pass_after
:
type
(
pass_before
)
!=
type
(
pass_after
),
lambda
pass_before
,
pass_after
:
type
(
pass_before
)
!=
type
(
pass_after
),
# Add more common rules here
]
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录