diff --git a/python/paddle/distributed/launch/context/__init__.py b/python/paddle/distributed/launch/context/__init__.py
index b252e966021bc4228bab543decd833d44c65962c..9083347a51158badae93b8bce28a58bb7ffa856a 100644
--- a/python/paddle/distributed/launch/context/__init__.py
+++ b/python/paddle/distributed/launch/context/__init__.py
@@ -97,8 +97,27 @@ class Context:
 
     def set_env_in_args(self):
-        for k, v in env_args_mapping.items():
-            if k in self.envs:
-                print(
-                    f"LAUNCH WARNNING args {v} is override by env {self.envs[k]}"
-                )
-                setattr(self.args, v, self.envs[k])
+        """Override parsed launch arguments with values from ``self.envs``.
+
+        For each entry of ``env_args_mapping`` whose key is present in
+        ``self.envs``, the mapped attribute of ``self.args`` is replaced by the
+        env value converted with the callable paired with it in the mapping.
+
+        Raises:
+            ValueError: if an env value cannot be converted; the message names
+                the offending variable.
+        """
+        for k, (attr, attr_type) in env_args_mapping.items():
+            if k not in self.envs:
+                continue
+            raw = self.envs[k]
+            print(
+                f"LAUNCH WARNING args {attr} will be overridden by env: {k} value: {raw}"
+            )
+            try:
+                value = attr_type(raw)
+            except ValueError as err:
+                # The bare int()/strtobool() error does not say which variable was bad.
+                raise ValueError(
+                    f"invalid value {raw!r} for env {k} (launch arg {attr})"
+                ) from err
+            setattr(self.args, attr, value)
diff --git a/python/paddle/distributed/launch/context/args_envs.py b/python/paddle/distributed/launch/context/args_envs.py
index 7dc410de3450d1f8f3ee65e8949e9d7467b3f49e..56eac96f1b8b9e845a1cdb9dacab76658300f857 100644
--- a/python/paddle/distributed/launch/context/args_envs.py
+++ b/python/paddle/distributed/launch/context/args_envs.py
@@ -17,28 +17,30 @@ from argparse import REMAINDER, ArgumentParser
 from distutils.util import strtobool
 
+# Maps each launch environment variable to a pair of (attribute name on the
+# parsed args, callable used to convert the env value before it is assigned).
 env_args_mapping = {
-    'POD_IP': 'host',
-    'PADDLE_MASTER': 'master',
-    'PADDLE_DEVICES': 'devices',
-    'PADDLE_NNODES': 'nnodes',
-    'PADDLE_RUN_MODE': 'run_mode',
-    'PADDLE_LOG_LEVEL': 'log_level',
-    'PADDLE_LOG_OVERWRITE': 'log_overwrite',
-    'PADDLE_SORT_IP': 'sort_ip',
-    'PADDLE_NPROC_PER_NODE': 'nproc_per_node',
-    'PADDLE_JOB_ID': 'job_id',
-    'PADDLE_RANK': 'rank',
-    'PADDLE_LOG_DIR': 'log_dir',
-    'PADDLE_MAX_RESTART': 'max_restart',
-    'PADDLE_ELASTIC_LEVEL': 'elastic_level',
-    'PADDLE_ELASTIC_TIMEOUT': 'elastic_timeout',
-    'PADDLE_SERVER_NUM': 'server_num',
-    'PADDLE_TRAINER_NUM': 'trainer_num',
-    'PADDLE_SERVERS_ENDPOINTS': 'servers',
-    'PADDLE_TRAINERS_ENDPOINTS': 'trainers',
-    'PADDLE_GLOO_PORT': 'gloo_port',
-    'PADDLE_WITH_GLOO': 'with_gloo',
-    'PADDLE_START_PORT': 'start_port',
-    'PADDLE_IPS': 'ips',
-    "PADDLE_AUTO_PARALLEL_CONFIG": 'auto_parallel_config',
+    'POD_IP': ('host', str),
+    'PADDLE_MASTER': ('master', str),
+    'PADDLE_DEVICES': ('devices', str),
+    'PADDLE_NNODES': ('nnodes', str),
+    'PADDLE_RUN_MODE': ('run_mode', str),
+    'PADDLE_LOG_LEVEL': ('log_level', str),
+    'PADDLE_LOG_OVERWRITE': ('log_overwrite', strtobool),
+    'PADDLE_SORT_IP': ('sort_ip', strtobool),
+    'PADDLE_NPROC_PER_NODE': ('nproc_per_node', int),
+    'PADDLE_JOB_ID': ('job_id', str),
+    'PADDLE_RANK': ('rank', int),
+    'PADDLE_LOG_DIR': ('log_dir', str),
+    'PADDLE_MAX_RESTART': ('max_restart', int),
+    'PADDLE_ELASTIC_LEVEL': ('elastic_level', int),
+    'PADDLE_ELASTIC_TIMEOUT': ('elastic_timeout', int),
+    'PADDLE_SERVER_NUM': ('server_num', int),
+    'PADDLE_TRAINER_NUM': ('trainer_num', int),
+    'PADDLE_SERVERS_ENDPOINTS': ('servers', str),
+    'PADDLE_TRAINERS_ENDPOINTS': ('trainers', str),
+    'PADDLE_GLOO_PORT': ('gloo_port', int),
+    'PADDLE_WITH_GLOO': ('with_gloo', str),
+    'PADDLE_START_PORT': ('start_port', int),
+    'PADDLE_IPS': ('ips', str),
+    "PADDLE_AUTO_PARALLEL_CONFIG": ('auto_parallel_config', str),
 }