From e3766da649244ab9290eeb4fc475b693e0ef3227 Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Fri, 29 Jul 2022 15:06:12 +0800 Subject: [PATCH] [LAUNCH] fix set args bug (#44717) --- python/paddle/distributed/launch/context/__init__.py | 3 +-- python/paddle/distributed/launch/controllers/collective.py | 3 ++- python/paddle/distributed/launch/plugins/__init__.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/launch/context/__init__.py b/python/paddle/distributed/launch/context/__init__.py index f93b30b4dd1..3e8f0de3e69 100644 --- a/python/paddle/distributed/launch/context/__init__.py +++ b/python/paddle/distributed/launch/context/__init__.py @@ -101,7 +101,6 @@ class Context(object): return False def set_env_in_args(self): - # this logic may not propre to replace args with env, but ... for k, v in env_args_mapping.items(): if k in self.envs: - setattr(self.args, v, type(getattr(self.args, v))(self.envs[k])) + setattr(self.args, v, self.envs[k]) diff --git a/python/paddle/distributed/launch/controllers/collective.py b/python/paddle/distributed/launch/controllers/collective.py index 6b4972c003c..1595bcd1efb 100644 --- a/python/paddle/distributed/launch/controllers/collective.py +++ b/python/paddle/distributed/launch/controllers/collective.py @@ -131,7 +131,8 @@ class CollectiveElasticController(CollectiveController): def run(self): - timeout = self.ctx.args.elastic_timeout if self.job.elastic else self.ctx.args.elastic_timeout * 10 + timeout = int(self.ctx.args.elastic_timeout) + timeout = timeout if self.job.elastic else timeout * 10 self.register() while self.pod.restart <= self.ctx.args.max_restart: diff --git a/python/paddle/distributed/launch/plugins/__init__.py b/python/paddle/distributed/launch/plugins/__init__.py index 4c414a177d1..a3a9e8c809a 100644 --- a/python/paddle/distributed/launch/plugins/__init__.py +++ b/python/paddle/distributed/launch/plugins/__init__.py @@ -62,7 +62,7 @@ def rewrite_host_ip(ctx): def test_mode(ctx): - if ctx.args.training_script == 'test': + if ctx.args.training_script == 'run_check': ctx.logger.info('Paddle Distributed Test begin...') if int(ctx.args.nnodes) < 2: ctx.args.nnodes = 2 -- GitLab