From a3f3172c2b080a2101398c851c6b25e2423e603e Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Thu, 4 Aug 2022 10:52:25 +0800 Subject: [PATCH] launch no python script (#44849) --- python/paddle/distributed/launch/context/__init__.py | 2 -- .../paddle/distributed/launch/controllers/controller.py | 6 +++++- python/paddle/distributed/launch/plugins/__init__.py | 9 +++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/launch/context/__init__.py b/python/paddle/distributed/launch/context/__init__.py index 3e8f0de3e69..d273d2355b3 100644 --- a/python/paddle/distributed/launch/context/__init__.py +++ b/python/paddle/distributed/launch/context/__init__.py @@ -57,8 +57,6 @@ class Context(object): return True legacy_env_list = [ - 'DISTRIBUTED_TRAINER_ENDPOINTS', - 'PADDLE_ELASTIC_JOB_ID', 'FLAGS_START_PORT', ] diff --git a/python/paddle/distributed/launch/controllers/controller.py b/python/paddle/distributed/launch/controllers/controller.py index bc628be59dc..0f0513f0a3d 100644 --- a/python/paddle/distributed/launch/controllers/controller.py +++ b/python/paddle/distributed/launch/controllers/controller.py @@ -170,7 +170,11 @@ class Controller(ControllerBase): raise NotImplementedError def _get_entrypoint(self): - entrypoint = [sys.executable, "-u", self.ctx.args.training_script] + if self.ctx.args.training_script.endswith('.py'): + entrypoint = [sys.executable, "-u", self.ctx.args.training_script] + else: + entrypoint = [self.ctx.args.training_script] + entrypoint.extend(self.ctx.args.training_script_args) return entrypoint diff --git a/python/paddle/distributed/launch/plugins/__init__.py b/python/paddle/distributed/launch/plugins/__init__.py index a3a9e8c809a..946768db32c 100644 --- a/python/paddle/distributed/launch/plugins/__init__.py +++ b/python/paddle/distributed/launch/plugins/__init__.py @@ -32,8 +32,10 @@ def process_args(ctx): argdev = ctx.args.devices if argdev: for d in argdev.split(','): - assert d in ctx.node.device.labels, 'Device not found {}'.format( - argdev) + if d not in ctx.node.device.labels: + ctx.logger.error( + f'Device not found {d} from {argdev} for setting {ctx.node.device.labels}' + ) def collective_compatible(ctx): @@ -44,7 +46,7 @@ def collective_compatible(ctx): ctx.args.nnodes = len(hosts) ctx.logger.info( 'args reset by env PADDLE_TRAINER_ENDPOINTS\n{}'.format(eps)) - ''' + if 'DISTRIBUTED_TRAINER_ENDPOINTS' in ctx.envs: eps = ctx.envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',') hosts = set([h.split(':')[0] for h in eps]) @@ -52,7 +54,6 @@ def collective_compatible(ctx): ctx.args.nnodes = len(hosts) ctx.logger.info( 'args reset by env DISTRIBUTED_TRAINER_ENDPOINTS\n{}'.format(eps)) - ''' def rewrite_host_ip(ctx): -- GitLab