未验证 提交 4c160be2 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine pass by removing CommOpt, CalcOpt, ParallelOpt (#37206)

上级 70b7c7ed
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .pass_base import CPPPassWrapper, register_pass
from .pass_base import PassType, CPPPassWrapper, register_pass
@register_pass("fuse_elewise_add_act")
......@@ -23,3 +23,6 @@ class FuseElementwiseAddActPass(CPPPassWrapper):
@property
def cpp_name(self):
return "fuse_elewise_add_act_pass"
def _type(self):
return PassType.FUSION_OPT
......@@ -14,7 +14,7 @@
from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import CommOptPass, register_pass
from .pass_base import PassBase, PassType, register_pass
from collections import OrderedDict
import numpy as np
......@@ -329,7 +329,7 @@ def insert_fuse_all_reduce_by_memory_size(block, groups, max_memory_size):
@register_pass("fuse_all_reduce")
class FuseAllReducePass(CommOptPass):
class FuseAllReducePass(PassBase):
def __init__(self):
super(FuseAllReducePass, self).__init__()
self.set_attr("max_memory_size", -1)
......@@ -341,6 +341,9 @@ class FuseAllReducePass(CommOptPass):
def _check_conflict(self, other_pass):
return True
def _type(self):
return PassType.COMM_OPT
# NOTE: why FuseAllReducePass can override apply_single_impl instead of
# apply_impl? AllReduce is a collective operation, so the program of each
# rank inside the same communication group should have the same
......
......@@ -40,9 +40,20 @@ class PassContext:
del self._applied_passes[-1]
class PassType:
UNKNOWN = 0
COMM_OPT = 1
CALC_OPT = 2
PARALLEL_OPT = 3
FUSION_OPT = 4
class PassBase(ABC):
_REGISTERED_PASSES = {}
_COMMON_RULES = []
# TODO(zengjinle): add white/black list
name = None
@staticmethod
def _register(pass_name, pass_class):
......@@ -67,6 +78,9 @@ class PassBase(ABC):
def _check_conflict(self, other_pass):
pass
def _type(self):
return PassType.UNKNOWN
def _check_conflict_including_common_rules(self, other_pass):
return self._check_conflict(other_pass) and all(
[r(other_pass, self) for r in PassBase._COMMON_RULES])
......@@ -142,40 +156,18 @@ class CPPPassWrapper(PassBase):
self._attrs, self.cpp_attr_types)
# Like AutoParallel/HybridParallel, etc.
class ParallelOptPass(PassBase):
def __init__(self):
super(ParallelOptPass, self).__init__()
# Like AMP, Recompute, etc.
class CalcOptPass(PassBase):
def __init__(self):
super(CalcOptPass, self).__init__()
# Like FuseAllReduce, FuseGradientMerge, etc.
class CommOptPass(PassBase):
def __init__(self):
super(CommOptPass, self).__init__()
def _make_pass_order_rule(pass_class_before, pass_class_after):
def impl(pass_obj_before, pass_obj_after):
if isinstance(pass_obj_before, pass_class_after) \
and isinstance(pass_obj_after, pass_class_before):
def _fusion_opt_last_rule(pass_before, pass_after):
if pass_before._type() == PassType.FUSION_OPT and pass_after._type(
) != PassType.FUSION_OPT:
return False
else:
return True
return impl
PassBase._COMMON_RULES = [
_make_pass_order_rule(CalcOptPass, CommOptPass),
_make_pass_order_rule(ParallelOptPass, CPPPassWrapper),
_make_pass_order_rule(CalcOptPass, CPPPassWrapper),
_make_pass_order_rule(CommOptPass, CPPPassWrapper),
_fusion_opt_last_rule,
lambda pass_before, pass_after: type(pass_before) != type(pass_after),
# Add more common rules here
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册