From 4c160be23b331ae7dd4be368b9291d5fb5b801ce Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 16 Nov 2021 15:45:44 +0800 Subject: [PATCH] refine pass by removing CommOpt, CalcOpt, ParallelOpt (#37206) --- python/paddle/distributed/passes/cpp_pass.py | 5 +- .../distributed/passes/fuse_all_reduce.py | 7 ++- python/paddle/distributed/passes/pass_base.py | 50 ++++++++----------- 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/python/paddle/distributed/passes/cpp_pass.py b/python/paddle/distributed/passes/cpp_pass.py index 5dd50b95345..fe6ef74bd85 100644 --- a/python/paddle/distributed/passes/cpp_pass.py +++ b/python/paddle/distributed/passes/cpp_pass.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .pass_base import CPPPassWrapper, register_pass +from .pass_base import PassType, CPPPassWrapper, register_pass @register_pass("fuse_elewise_add_act") @@ -23,3 +23,6 @@ class FuseElementwiseAddActPass(CPPPassWrapper): @property def cpp_name(self): return "fuse_elewise_add_act_pass" + + def _type(self): + return PassType.FUSION_OPT diff --git a/python/paddle/distributed/passes/fuse_all_reduce.py b/python/paddle/distributed/passes/fuse_all_reduce.py index 101f0c3dc38..317a66c008a 100644 --- a/python/paddle/distributed/passes/fuse_all_reduce.py +++ b/python/paddle/distributed/passes/fuse_all_reduce.py @@ -14,7 +14,7 @@ from paddle.framework import core from paddle.fluid import unique_name -from .pass_base import CommOptPass, register_pass +from .pass_base import PassBase, PassType, register_pass from collections import OrderedDict import numpy as np @@ -329,7 +329,7 @@ def insert_fuse_all_reduce_by_memory_size(block, groups, max_memory_size): @register_pass("fuse_all_reduce") -class FuseAllReducePass(CommOptPass): +class FuseAllReducePass(PassBase): def __init__(self): super(FuseAllReducePass, self).__init__() self.set_attr("max_memory_size", -1) @@ -341,6 +341,9 @@ class FuseAllReducePass(CommOptPass): def _check_conflict(self, other_pass): return True + def _type(self): + return PassType.COMM_OPT + # NOTE: why FuseAllReducePass can override apply_single_impl instead of # apply_impl? AllReduce is a collective operation, so the program of each # rank inside the same communication group should have the same diff --git a/python/paddle/distributed/passes/pass_base.py b/python/paddle/distributed/passes/pass_base.py index 4d4585ca6e8..0df6df6ab3f 100644 --- a/python/paddle/distributed/passes/pass_base.py +++ b/python/paddle/distributed/passes/pass_base.py @@ -40,9 +40,20 @@ class PassContext: del self._applied_passes[-1] +class PassType: + UNKNOWN = 0 + COMM_OPT = 1 + CALC_OPT = 2 + PARALLEL_OPT = 3 + FUSION_OPT = 4 + + class PassBase(ABC): _REGISTERED_PASSES = {} _COMMON_RULES = [] + # TODO(zengjinle): add white/black list + + name = None @staticmethod def _register(pass_name, pass_class): @@ -67,6 +78,9 @@ class PassBase(ABC): def _check_conflict(self, other_pass): pass + def _type(self): + return PassType.UNKNOWN + def _check_conflict_including_common_rules(self, other_pass): return self._check_conflict(other_pass) and all( [r(other_pass, self) for r in PassBase._COMMON_RULES]) @@ -142,40 +156,18 @@ class CPPPassWrapper(PassBase): self._attrs, self.cpp_attr_types) -# Like AutoParallel/HybridParallel, etc. -class ParallelOptPass(PassBase): - def __init__(self): - super(ParallelOptPass, self).__init__() - - -# Like AMP, Recompute, etc. -class CalcOptPass(PassBase): - def __init__(self): - super(CalcOptPass, self).__init__() - - -# Like FuseAllReduce, FuseGradientMerge, etc. -class CommOptPass(PassBase): - def __init__(self): - super(CommOptPass, self).__init__() - - -def _make_pass_order_rule(pass_class_before, pass_class_after): - def impl(pass_obj_before, pass_obj_after): - if isinstance(pass_obj_before, pass_class_after) \ - and isinstance(pass_obj_after, pass_class_before): - return False +def _fusion_opt_last_rule(pass_before, pass_after): + if pass_before._type() == PassType.FUSION_OPT and pass_after._type( + ) != PassType.FUSION_OPT: + return False + else: return True - return impl - PassBase._COMMON_RULES = [ - _make_pass_order_rule(CalcOptPass, CommOptPass), - _make_pass_order_rule(ParallelOptPass, CPPPassWrapper), - _make_pass_order_rule(CalcOptPass, CPPPassWrapper), - _make_pass_order_rule(CommOptPass, CPPPassWrapper), + _fusion_opt_last_rule, lambda pass_before, pass_after: type(pass_before) != type(pass_after), + # Add more common rules here ] -- GitLab