未验证 提交 4c160be2 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine pass by removing CommOpt, CalcOpt, ParallelOpt (#37206)

上级 70b7c7ed
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .pass_base import CPPPassWrapper, register_pass from .pass_base import PassType, CPPPassWrapper, register_pass
@register_pass("fuse_elewise_add_act") @register_pass("fuse_elewise_add_act")
...@@ -23,3 +23,6 @@ class FuseElementwiseAddActPass(CPPPassWrapper): ...@@ -23,3 +23,6 @@ class FuseElementwiseAddActPass(CPPPassWrapper):
@property @property
def cpp_name(self): def cpp_name(self):
return "fuse_elewise_add_act_pass" return "fuse_elewise_add_act_pass"
def _type(self):
return PassType.FUSION_OPT
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from paddle.framework import core from paddle.framework import core
from paddle.fluid import unique_name from paddle.fluid import unique_name
from .pass_base import CommOptPass, register_pass from .pass_base import PassBase, PassType, register_pass
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
...@@ -329,7 +329,7 @@ def insert_fuse_all_reduce_by_memory_size(block, groups, max_memory_size): ...@@ -329,7 +329,7 @@ def insert_fuse_all_reduce_by_memory_size(block, groups, max_memory_size):
@register_pass("fuse_all_reduce") @register_pass("fuse_all_reduce")
class FuseAllReducePass(CommOptPass): class FuseAllReducePass(PassBase):
def __init__(self): def __init__(self):
super(FuseAllReducePass, self).__init__() super(FuseAllReducePass, self).__init__()
self.set_attr("max_memory_size", -1) self.set_attr("max_memory_size", -1)
...@@ -341,6 +341,9 @@ class FuseAllReducePass(CommOptPass): ...@@ -341,6 +341,9 @@ class FuseAllReducePass(CommOptPass):
def _check_conflict(self, other_pass): def _check_conflict(self, other_pass):
return True return True
def _type(self):
return PassType.COMM_OPT
# NOTE: why FuseAllReducePass can override apply_single_impl instead of # NOTE: why FuseAllReducePass can override apply_single_impl instead of
# apply_impl? AllReduce is a collective operation, so the program of each # apply_impl? AllReduce is a collective operation, so the program of each
# rank inside the same communication group should have the same # rank inside the same communication group should have the same
......
...@@ -40,9 +40,20 @@ class PassContext: ...@@ -40,9 +40,20 @@ class PassContext:
del self._applied_passes[-1] del self._applied_passes[-1]
class PassType:
UNKNOWN = 0
COMM_OPT = 1
CALC_OPT = 2
PARALLEL_OPT = 3
FUSION_OPT = 4
class PassBase(ABC): class PassBase(ABC):
_REGISTERED_PASSES = {} _REGISTERED_PASSES = {}
_COMMON_RULES = [] _COMMON_RULES = []
# TODO(zengjinle): add white/black list
name = None
@staticmethod @staticmethod
def _register(pass_name, pass_class): def _register(pass_name, pass_class):
...@@ -67,6 +78,9 @@ class PassBase(ABC): ...@@ -67,6 +78,9 @@ class PassBase(ABC):
def _check_conflict(self, other_pass): def _check_conflict(self, other_pass):
pass pass
def _type(self):
return PassType.UNKNOWN
def _check_conflict_including_common_rules(self, other_pass): def _check_conflict_including_common_rules(self, other_pass):
return self._check_conflict(other_pass) and all( return self._check_conflict(other_pass) and all(
[r(other_pass, self) for r in PassBase._COMMON_RULES]) [r(other_pass, self) for r in PassBase._COMMON_RULES])
...@@ -142,40 +156,18 @@ class CPPPassWrapper(PassBase): ...@@ -142,40 +156,18 @@ class CPPPassWrapper(PassBase):
self._attrs, self.cpp_attr_types) self._attrs, self.cpp_attr_types)
# Like AutoParallel/HybridParallel, etc. def _fusion_opt_last_rule(pass_before, pass_after):
class ParallelOptPass(PassBase): if pass_before._type() == PassType.FUSION_OPT and pass_after._type(
def __init__(self): ) != PassType.FUSION_OPT:
super(ParallelOptPass, self).__init__() return False
else:
# Like AMP, Recompute, etc.
class CalcOptPass(PassBase):
def __init__(self):
super(CalcOptPass, self).__init__()
# Like FuseAllReduce, FuseGradientMerge, etc.
class CommOptPass(PassBase):
def __init__(self):
super(CommOptPass, self).__init__()
def _make_pass_order_rule(pass_class_before, pass_class_after):
def impl(pass_obj_before, pass_obj_after):
if isinstance(pass_obj_before, pass_class_after) \
and isinstance(pass_obj_after, pass_class_before):
return False
return True return True
return impl
PassBase._COMMON_RULES = [ PassBase._COMMON_RULES = [
_make_pass_order_rule(CalcOptPass, CommOptPass), _fusion_opt_last_rule,
_make_pass_order_rule(ParallelOptPass, CPPPassWrapper),
_make_pass_order_rule(CalcOptPass, CPPPassWrapper),
_make_pass_order_rule(CommOptPass, CPPPassWrapper),
lambda pass_before, pass_after: type(pass_before) != type(pass_after), lambda pass_before, pass_after: type(pass_before) != type(pass_after),
# Add more common rules here
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册