diff --git a/python/paddle/distributed/launch/context/args_envs.py b/python/paddle/distributed/launch/context/args_envs.py index df81569be952fa37faf48f3592f85834d4816572..7dc410de3450d1f8f3ee65e8949e9d7467b3f49e 100644 --- a/python/paddle/distributed/launch/context/args_envs.py +++ b/python/paddle/distributed/launch/context/args_envs.py @@ -222,4 +222,9 @@ def parse_args(): help="seconds to wait before elastic job begin to train", ) - return parser.parse_known_args() + args = parser.parse_known_args() + env_rank = int(os.getenv('PADDLE_TRAINER_ID', -1)) + if env_rank >= 0: + assert hasattr(args[0], "rank") + args[0].rank = env_rank + return args