未验证 提交 003e3dab 编写于 作者: S sneaxiy 提交者: GitHub

Add white list for dist passes (#37831)

* add white list for dist passes

* update comment

* follow zhiqiu's comment

* fix PassContext attrs type
上级 c117dfba
......@@ -21,7 +21,7 @@ from paddle.fluid.framework import program_guard, _apply_pass as _apply_cpp_pass
class PassContext:
def __init__(self):
self._applied_passes = []
self._attrs = []
self._attrs = {}
def set_attr(self, key, value):
self._attrs[key] = value
......@@ -51,7 +51,9 @@ class PassType:
class PassBase(ABC):
_REGISTERED_PASSES = {}
_COMMON_RULES = []
# TODO(zengjinle): add white/black list
_BEFORE_WHITE_LISTS_DICT = {}
_AFTER_WHITE_LISTS_DICT = {}
name = None
......@@ -164,9 +166,60 @@ def _fusion_opt_last_rule(pass_before, pass_after):
return True
def _make_rule_from_white_lists_dict(before_white_lists_dict,
after_white_lists_dict):
def collect_pass_names(white_lists_dict, result):
for k, v in white_lists_dict.items():
result.add(k)
assert isinstance(v, (list, tuple))
for pass_name in v:
assert isinstance(pass_name, (bytes, str))
result.add(pass_name)
all_pass_names = set()
collect_pass_names(before_white_lists_dict, all_pass_names)
collect_pass_names(after_white_lists_dict, all_pass_names)
compatible_pass_dict = {}
for pass_name in all_pass_names:
compatible_pass_dict[pass_name] = set()
for k, v in before_white_lists_dict.items():
for pass_name in v:
compatible_pass_dict[k].add(pass_name)
for k, v in after_white_lists_dict.items():
for pass_name in v:
compatible_pass_dict[pass_name].add(k)
def rule(pass_before, pass_after):
all_passes_after = compatible_pass_dict.get(pass_before.name)
if all_passes_after is None or pass_after.name not in compatible_pass_dict:
return True
else:
return pass_after.name in all_passes_after
return rule
# The key-value pair (k, [v1, v2, ..., vn]) means the pass k can be
# applied before any of pass [v1, v2, ..., vn] is applied
PassBase._BEFORE_WHITE_LISTS_DICT = {
"fuse_gradient_merge": ["fuse_all_reduce"],
# Add more white lists here
}
# The key-value pair (k, [v1, v2, ..., vn]) means the pass k can be
# applied after any of pass [v1, v2, ..., vn] is applied
PassBase._AFTER_WHITE_LISTS_DICT = {
# Add more white lists here
}
PassBase._COMMON_RULES = [
_fusion_opt_last_rule,
lambda pass_before, pass_after: type(pass_before) != type(pass_after),
_make_rule_from_white_lists_dict(PassBase._BEFORE_WHITE_LISTS_DICT,
PassBase._AFTER_WHITE_LISTS_DICT),
# Add more common rules here
]
......
......@@ -147,13 +147,25 @@ class DistPassTestBase(unittest.TestCase):
with open(dump_file, "wb") as f:
pickle.dump(all_fetch_values, f)
def _distributed_launch(self, apply_pass, gpus=None, **kwargs):
if gpus is None:
@classmethod
def _get_default_gpu_lists(cls):
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
if visible_devices is None:
visible_devices = os.getenv("FLAGS_selected_gpus")
if visible_devices is None:
num_gpus = paddle.device.cuda.device_count()
gpus = list(range(num_gpus))
return list(range(num_gpus))
else:
num_gpus = len(gpus)
return [
int(s.strip()) for s in visible_devices.split(",") if s.strip()
]
def _distributed_launch(self, apply_pass, gpus=None, **kwargs):
if gpus is None:
gpus = self._get_default_gpu_lists()
num_gpus = len(gpus)
gpus = ','.join([str(gpu_id) for gpu_id in gpus])
pid = os.getpid()
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle.distributed.passes.pass_base import register_pass, PassBase, new_pass
from paddle.distributed.passes.pass_base import _make_rule_from_white_lists_dict as make_white_lists_rule
class TestConcretePass(PassBase):
def __init__(self):
super(TestConcretePass, self).__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
pass
@register_pass("A")
class A(TestConcretePass):
def __init__(self):
super(A, self).__init__()
@register_pass("B")
class B(TestConcretePass):
def __init__(self):
super(B, self).__init__()
@register_pass("C")
class C(TestConcretePass):
def __init__(self):
super(C, self).__init__()
@register_pass("D")
class D(TestConcretePass):
def __init__(self):
super(D, self).__init__()
@register_pass("E")
class E(TestConcretePass):
def __init__(self):
super(E, self).__init__()
class TestMakeWhiteListsRule(unittest.TestCase):
def test_main(self):
before_white_lists = {"A": ["B", "C"]}
after_white_lists = {"D": ["C"]}
rule = make_white_lists_rule(before_white_lists, after_white_lists)
pass_a = new_pass("A")
pass_b = new_pass("B")
pass_c = new_pass("C")
pass_d = new_pass("D")
pass_e = new_pass("E")
self.assertTrue(rule(pass_a, pass_e))
self.assertTrue(rule(pass_e, pass_a))
self.assertTrue(rule(pass_a, pass_b))
self.assertFalse(rule(pass_b, pass_a))
self.assertTrue(rule(pass_a, pass_c))
self.assertFalse(rule(pass_c, pass_a))
self.assertFalse(rule(pass_a, pass_d))
self.assertFalse(rule(pass_d, pass_a))
self.assertTrue(rule(pass_c, pass_d))
self.assertFalse(rule(pass_d, pass_c))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册