diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index 0d985a523251754ff4335d76cd4ced7ef3f42f49..c5a9df50589ccd36bbd228822da7c29094ad9b1e 100644 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -242,7 +242,8 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra elastic_group.add_argument( "--force", type=bool, default=False, help="update np force") - return parser.parse_args() + known_args, _ = parser.parse_known_args() + return known_args def get_cluster_from_args(args, device_mode, devices_per_proc): diff --git a/python/paddle/distributed/launch/context/__init__.py b/python/paddle/distributed/launch/context/__init__.py index 510f49d8246f128c896712e9e0ad0776fa6f7626..e03d832767e6fac85c242d6563da363f7cbdd4a3 100644 --- a/python/paddle/distributed/launch/context/__init__.py +++ b/python/paddle/distributed/launch/context/__init__.py @@ -25,12 +25,13 @@ class Context(object): def __init__(self, enable_plugin=True): self.args, self.unknown_args = parse_args() self.envs = fetch_envs() - self.logger = self.get_logger() + + self.set_env_in_args() self.node = Node() self.status = Status() - self.set_env_in_args() + self.logger = self.get_logger() # design for event queue, later self.events = [] diff --git a/python/paddle/distributed/launch/context/device.py b/python/paddle/distributed/launch/context/device.py index 9163e7abd918371ddf4eca388bc912b630684f1f..c2f6896ab6c045da23a142b3ba5a6511c1d9b6ed 100644 --- a/python/paddle/distributed/launch/context/device.py +++ b/python/paddle/distributed/launch/context/device.py @@ -57,7 +57,7 @@ class Device(object): else: self._labels = [] - def get_selected_flag_key(self): + def get_selected_device_key(self): if self._dtype == DeviceType.CPU: return 'FLAGS_selected_cpus' if self._dtype == DeviceType.GPU: @@ -70,19 +70,15 @@ class Device(object): return 'FLAGS_selected_mlus' return 'FLAGS_selected_devices' 
- def get_selected_flag_label(self, idx): - if idx < len(self._labels): - return self._labels[idx] + def get_selected_devices(self, devices=''): + ''' + return the device label/id relative to the visible devices + ''' + if not devices: + return [str(x) for x in range(0, len(self._labels))] else: - return '0' - - def selected_flags(self, idx=None): - if idx is None: - return {self.get_selected_flag_key(): ','.join(self._labels)} - else: - return { - self.get_selected_flag_key(): self.get_selected_flag_label(idx) - } + devs = [x.strip() for x in devices.split(',')] + return [str(self._labels.index(d)) for d in devs] @classmethod def parse_device(self): diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index 0a6c1c4002abb3d291c47748eddad201fc0d2839..bbcb7c81d6e65c2e570ad3234619d95d9d7fdb20 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -75,6 +75,9 @@ class CollectiveController(Controller): job_endpoints = [i['endpoints'] for i in peer_list] self.pod.reset() + selected_dev_key = self.ctx.node.device.get_selected_device_key() + selected_dev_list = self.ctx.node.device.get_selected_devices( + self.ctx.args.devices) for i in range(self.pod.replicas): e = { "PADDLE_MASTER": collective_master, @@ -90,9 +93,9 @@ class CollectiveController(Controller): "PADDLE_RANK_IN_NODE": str(i), } if self.pod.replicas == 1: - e.update(self.ctx.node.device.selected_flags()) + e.update({selected_dev_key: selected_dev_list}) else: - e.update(self.ctx.node.device.selected_flags(i)) + e.update({selected_dev_key: selected_dev_list[i]}) self.add_container(envs=e, log_tag=i) return True diff --git a/python/paddle/distributed/launch/controllers/controller.py b/python/paddle/distributed/launch/controllers/controller.py index 08345a2a1f76b84cfde96667e6329bc1b28c18d4..fbe9df4c9a22398df2343cff6b8091506c159f2f 100644 --- 
a/python/paddle/distributed/launch/controllers/controller.py +++ b/python/paddle/distributed/launch/controllers/controller.py @@ -210,6 +210,8 @@ class Controller(ControllerBase): if self.ctx.args.nproc_per_node: return int(self.ctx.args.nproc_per_node) + elif self.ctx.args.devices: + return len(self.ctx.args.devices.split(',')) else: return self.ctx.node.device.count diff --git a/python/paddle/distributed/launch/plugins/__init__.py b/python/paddle/distributed/launch/plugins/__init__.py index 1862f75a77f65d39715e031b0ba72ebea6ab5523..35a44ed942c204a3793a7e49fde915e98743ce27 100644 --- a/python/paddle/distributed/launch/plugins/__init__.py +++ b/python/paddle/distributed/launch/plugins/__init__.py @@ -29,8 +29,9 @@ def process_args(ctx): #argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus argdev = ctx.args.devices if argdev: - ctx.node.device.labels = argdev.split(',') - ctx.logger.debug('Device reset by args {}'.format(argdev)) + for d in argdev.split(','): + assert d in ctx.node.device.labels, 'Device not found {}'.format( + argdev) def collective_compatible(ctx): diff --git a/python/paddle/fluid/tests/unittests/test_run.py b/python/paddle/fluid/tests/unittests/test_run.py index a2f12fbf5809ba9f026b4160754e850f96182df6..365d3f931c27c180eebd9d3b72c80dac5f9227e5 100644 --- a/python/paddle/fluid/tests/unittests/test_run.py +++ b/python/paddle/fluid/tests/unittests/test_run.py @@ -64,7 +64,10 @@ class Collective_Test(unittest.TestCase): if args: cmd.extend(args.split(" ")) cmd.extend([pyname]) - proc = subprocess.Popen(cmd, env) + env = os.environ.copy() + # virtual devices for testing + env.update({'CUDA_VISIBLE_DEVICES': '0,1,2,3,4,5,6,7'}) + proc = subprocess.Popen(cmd, env=env) return proc def test_collective_1(self):