From 003e3dab5a5fa8fa1c5a6bd9cf6d191ee06cf8db Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 14 Dec 2021 10:41:55 +0800 Subject: [PATCH] Add white list for dist passes (#37831) * add white list for dist passes * update comment * follow zhiqiu's comment * fix PassContext attrs type --- python/paddle/distributed/passes/pass_base.py | 57 +++++++++++- .../distributed_passes/dist_pass_test_base.py | 20 +++- .../distributed_passes/test_white_lists.py | 92 +++++++++++++++++++ 3 files changed, 163 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/test_white_lists.py diff --git a/python/paddle/distributed/passes/pass_base.py b/python/paddle/distributed/passes/pass_base.py index 0df6df6ab3..b6b35c7c7e 100644 --- a/python/paddle/distributed/passes/pass_base.py +++ b/python/paddle/distributed/passes/pass_base.py @@ -21,7 +21,7 @@ from paddle.fluid.framework import program_guard, _apply_pass as _apply_cpp_pass class PassContext: def __init__(self): self._applied_passes = [] - self._attrs = [] + self._attrs = {} def set_attr(self, key, value): self._attrs[key] = value @@ -51,7 +51,9 @@ class PassType: class PassBase(ABC): _REGISTERED_PASSES = {} _COMMON_RULES = [] - # TODO(zengjinle): add white/black list + + _BEFORE_WHITE_LISTS_DICT = {} + _AFTER_WHITE_LISTS_DICT = {} name = None @@ -164,9 +166,60 @@ def _fusion_opt_last_rule(pass_before, pass_after): return True +def _make_rule_from_white_lists_dict(before_white_lists_dict, + after_white_lists_dict): + def collect_pass_names(white_lists_dict, result): + for k, v in white_lists_dict.items(): + result.add(k) + assert isinstance(v, (list, tuple)) + for pass_name in v: + assert isinstance(pass_name, (bytes, str)) + result.add(pass_name) + + all_pass_names = set() + collect_pass_names(before_white_lists_dict, all_pass_names) + collect_pass_names(after_white_lists_dict, all_pass_names) + + compatible_pass_dict = {} + for pass_name in all_pass_names: + compatible_pass_dict[pass_name] = set() + + for k, v in before_white_lists_dict.items(): + for pass_name in v: + compatible_pass_dict[k].add(pass_name) + + for k, v in after_white_lists_dict.items(): + for pass_name in v: + compatible_pass_dict[pass_name].add(k) + + def rule(pass_before, pass_after): + all_passes_after = compatible_pass_dict.get(pass_before.name) + if all_passes_after is None or pass_after.name not in compatible_pass_dict: + return True + else: + return pass_after.name in all_passes_after + + return rule + + +# The key-value pair (k, [v1, v2, ..., vn]) means the pass k can be +# applied before any of pass [v1, v2, ..., vn] is applied +PassBase._BEFORE_WHITE_LISTS_DICT = { + "fuse_gradient_merge": ["fuse_all_reduce"], + # Add more white lists here +} + +# The key-value pair (k, [v1, v2, ..., vn]) means the pass k can be +# applied after any of pass [v1, v2, ..., vn] is applied +PassBase._AFTER_WHITE_LISTS_DICT = { + # Add more white lists here +} + PassBase._COMMON_RULES = [ _fusion_opt_last_rule, lambda pass_before, pass_after: type(pass_before) != type(pass_after), + _make_rule_from_white_lists_dict(PassBase._BEFORE_WHITE_LISTS_DICT, + PassBase._AFTER_WHITE_LISTS_DICT), # Add more common rules here ] diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py index a5b1cdff0f..44b7766bcd 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py @@ -147,13 +147,25 @@ class DistPassTestBase(unittest.TestCase): with open(dump_file, "wb") as f: pickle.dump(all_fetch_values, f) - def _distributed_launch(self, apply_pass, gpus=None, **kwargs): - if gpus is None: + @classmethod + def _get_default_gpu_lists(cls): + visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + if visible_devices is None: + visible_devices = os.getenv("FLAGS_selected_gpus") + + if visible_devices is None: num_gpus = paddle.device.cuda.device_count() - gpus = list(range(num_gpus)) + return list(range(num_gpus)) else: - num_gpus = len(gpus) + return [ + int(s.strip()) for s in visible_devices.split(",") if s.strip() + ] + + def _distributed_launch(self, apply_pass, gpus=None, **kwargs): + if gpus is None: + gpus = self._get_default_gpu_lists() + num_gpus = len(gpus) gpus = ','.join([str(gpu_id) for gpu_id in gpus]) pid = os.getpid() diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_white_lists.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_white_lists.py new file mode 100644 index 0000000000..37abe1e121 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_white_lists.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from paddle.distributed.passes.pass_base import register_pass, PassBase, new_pass +from paddle.distributed.passes.pass_base import _make_rule_from_white_lists_dict as make_white_lists_rule + + +class TestConcretePass(PassBase): + def __init__(self): + super(TestConcretePass, self).__init__() + + def _check_self(self): + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + pass + + +@register_pass("A") +class A(TestConcretePass): + def __init__(self): + super(A, self).__init__() + + +@register_pass("B") +class B(TestConcretePass): + def __init__(self): + super(B, self).__init__() + + +@register_pass("C") +class C(TestConcretePass): + def __init__(self): + super(C, self).__init__() + + +@register_pass("D") +class D(TestConcretePass): + def __init__(self): + super(D, self).__init__() + + +@register_pass("E") +class E(TestConcretePass): + def __init__(self): + super(E, self).__init__() + + +class TestMakeWhiteListsRule(unittest.TestCase): + def test_main(self): + before_white_lists = {"A": ["B", "C"]} + after_white_lists = {"D": ["C"]} + rule = make_white_lists_rule(before_white_lists, after_white_lists) + + pass_a = new_pass("A") + pass_b = new_pass("B") + pass_c = new_pass("C") + pass_d = new_pass("D") + pass_e = new_pass("E") + + self.assertTrue(rule(pass_a, pass_e)) + self.assertTrue(rule(pass_e, pass_a)) + + self.assertTrue(rule(pass_a, pass_b)) + self.assertFalse(rule(pass_b, pass_a)) + self.assertTrue(rule(pass_a, pass_c)) + self.assertFalse(rule(pass_c, pass_a)) + + self.assertFalse(rule(pass_a, pass_d)) + self.assertFalse(rule(pass_d, pass_a)) + + self.assertTrue(rule(pass_c, pass_d)) + self.assertFalse(rule(pass_d, pass_c)) + + +if __name__ == "__main__": + unittest.main() -- GitLab