diff --git a/python/paddle/distributed/launch/context/args_envs.py b/python/paddle/distributed/launch/context/args_envs.py index 9e0565b1c45e52a1e9cfa7a63191ca0c3c598ad7..5eb00958ee62b344951f4b0bd0100c010d28e845 100644 --- a/python/paddle/distributed/launch/context/args_envs.py +++ b/python/paddle/distributed/launch/context/args_envs.py @@ -200,4 +200,9 @@ def parse_args(): help="seconds to wait before elastic job begin to train", ) - return parser.parse_known_args() + args = parser.parse_known_args() + env_rank = int(os.getenv('PADDLE_TRAINER_ID', -1)) + if env_rank >= 0: + assert hasattr(args[0], "rank") + args[0].rank = env_rank + return args