From 8562668eff1558081faef30ea35edb4626a3e2fa Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Thu, 24 Mar 2022 11:29:35 +0800 Subject: [PATCH] fix device id env (#40844) --- python/paddle/distributed/fleet/launch.py | 3 ++- .../distributed/launch/context/__init__.py | 5 +++-- .../distributed/launch/context/device.py | 22 ++++++++----------- .../launch/controllers/collective.py | 7 ++++-- .../launch/controllers/controller.py | 2 ++ .../distributed/launch/plugins/__init__.py | 5 +++-- .../paddle/fluid/tests/unittests/test_run.py | 5 ++++- 7 files changed, 28 insertions(+), 21 deletions(-) diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index 0d985a52325..c5a9df50589 100644 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -242,7 +242,8 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra elastic_group.add_argument( "--force", type=bool, default=False, help="update np force") - return parser.parse_args() + known_args, _ = parser.parse_known_args() + return known_args def get_cluster_from_args(args, device_mode, devices_per_proc): diff --git a/python/paddle/distributed/launch/context/__init__.py b/python/paddle/distributed/launch/context/__init__.py index 510f49d8246..e03d832767e 100644 --- a/python/paddle/distributed/launch/context/__init__.py +++ b/python/paddle/distributed/launch/context/__init__.py @@ -25,12 +25,13 @@ class Context(object): def __init__(self, enable_plugin=True): self.args, self.unknown_args = parse_args() self.envs = fetch_envs() - self.logger = self.get_logger() + + self.set_env_in_args() self.node = Node() self.status = Status() - self.set_env_in_args() + self.logger = self.get_logger() # design for event queue, later self.events = [] diff --git a/python/paddle/distributed/launch/context/device.py b/python/paddle/distributed/launch/context/device.py index 9163e7abd91..c2f6896ab6c 100644 --- a/python/paddle/distributed/launch/context/device.py +++ b/python/paddle/distributed/launch/context/device.py @@ -57,7 +57,7 @@ class Device(object): else: self._labels = [] - def get_selected_flag_key(self): + def get_selected_device_key(self): if self._dtype == DeviceType.CPU: return 'FLAGS_selected_cpus' if self._dtype == DeviceType.GPU: @@ -70,19 +70,15 @@ class Device(object): return 'FLAGS_selected_mlus' return 'FLAGS_selected_devices' - def get_selected_flag_label(self, idx): - if idx < len(self._labels): - return self._labels[idx] + def get_selected_devices(self, devices=''): + ''' + return the device label/id relative to the visible devices + ''' + if not devices: + return [str(x) for x in range(0, len(self._labels))] else: - return '0' - - def selected_flags(self, idx=None): - if idx is None: - return {self.get_selected_flag_key(): ','.join(self._labels)} - else: - return { - self.get_selected_flag_key(): self.get_selected_flag_label(idx) - } + devs = [x.strip() for x in devices.split(',')] + return [str(self._labels.index(d)) for d in devs] @classmethod def parse_device(self): diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index 0a6c1c4002a..bbcb7c81d6e 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -75,6 +75,9 @@ class CollectiveController(Controller): job_endpoints = [i['endpoints'] for i in peer_list] self.pod.reset() + selected_dev_key = self.ctx.node.device.get_selected_device_key() + selected_dev_list = self.ctx.node.device.get_selected_devices( + self.ctx.args.devices) for i in range(self.pod.replicas): e = { "PADDLE_MASTER": collective_master, @@ -90,9 +93,9 @@ class CollectiveController(Controller): "PADDLE_RANK_IN_NODE": str(i), } if self.pod.replicas == 1: - e.update(self.ctx.node.device.selected_flags()) + e.update({selected_dev_key: selected_dev_list}) else: - e.update(self.ctx.node.device.selected_flags(i)) + e.update({selected_dev_key: selected_dev_list[i]}) self.add_container(envs=e, log_tag=i) return True diff --git a/python/paddle/distributed/launch/controllers/controller.py b/python/paddle/distributed/launch/controllers/controller.py index 08345a2a1f7..fbe9df4c9a2 100644 --- a/python/paddle/distributed/launch/controllers/controller.py +++ b/python/paddle/distributed/launch/controllers/controller.py @@ -210,6 +210,8 @@ class Controller(ControllerBase): if self.ctx.args.nproc_per_node: return int(self.ctx.args.nproc_per_node) + elif self.ctx.args.devices: + return len(self.ctx.args.devices.split(',')) else: return self.ctx.node.device.count diff --git a/python/paddle/distributed/launch/plugins/__init__.py b/python/paddle/distributed/launch/plugins/__init__.py index 1862f75a77f..35a44ed942c 100644 --- a/python/paddle/distributed/launch/plugins/__init__.py +++ b/python/paddle/distributed/launch/plugins/__init__.py @@ -29,8 +29,9 @@ def process_args(ctx): #argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus argdev = ctx.args.devices if argdev: - ctx.node.device.labels = argdev.split(',') - ctx.logger.debug('Device reset by args {}'.format(argdev)) + for d in argdev.split(','): + assert d in ctx.node.device.labels, 'Device not found {}'.format( + argdev) def collective_compatible(ctx): diff --git a/python/paddle/fluid/tests/unittests/test_run.py b/python/paddle/fluid/tests/unittests/test_run.py index a2f12fbf580..365d3f931c2 100644 --- a/python/paddle/fluid/tests/unittests/test_run.py +++ b/python/paddle/fluid/tests/unittests/test_run.py @@ -64,7 +64,10 @@ class Collective_Test(unittest.TestCase): if args: cmd.extend(args.split(" ")) cmd.extend([pyname]) - proc = subprocess.Popen(cmd, env) + env = os.environ.copy() + # virtual devies for testing + env.update({'CUDA_VISIBLE_DEVICES': '0,1,2,3,4,5,6,7'}) + proc = subprocess.Popen(cmd, env=env) return proc def test_collective_1(self): -- GitLab