未验证 提交 c75658a9 编写于 作者: S sneaxiy 提交者: GitHub

Fix launch error when PADDLE_TRAINER_ENDPOINTS is too long (#55450)

* fix new launch

* fix ps uit
上级 8511e030
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
g_backup_envs = None
def getenv_or_backup(name, default=None):
global g_backup_envs
if g_backup_envs is None:
backup_path = os.getenv('PADDLE_BACKUP_ENV_PATH')
if backup_path is None:
g_backup_envs = {}
else:
with open(backup_path, 'r') as f:
g_backup_envs = json.load(f)
value = os.getenv(name)
if value is not None:
return value
else:
return g_backup_envs.get(name, default)
......@@ -25,6 +25,8 @@ from paddle.distributed.fleet.base.private_helper_function import (
)
from paddle.fluid import core
from ...backup_env import getenv_or_backup
__all__ = []
......@@ -845,7 +847,9 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._server_endpoints = self._server_endpoints.split(",")
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
self._worker_endpoints = getenv_or_backup(
"PADDLE_TRAINER_ENDPOINTS", None
)
if self._worker_endpoints is not None:
self._worker_endpoints = self._worker_endpoints.split(",")
else:
......@@ -1067,7 +1071,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
assert self._training_role == "TRAINER"
self._role = Role.WORKER
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
self._worker_endpoints = getenv_or_backup("PADDLE_TRAINER_ENDPOINTS")
self._cur_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
if self._worker_endpoints is None:
# back to non_distributed execution.
......
......@@ -25,6 +25,8 @@ import traceback
from paddle.distributed.fleet import cloud_utils, launch_utils
from paddle.distributed.utils.log_utils import get_logger
from ...backup_env import getenv_or_backup
logger = get_logger("INFO", "ELASTIC")
ELASTIC_EXIT_CODE = 101
......@@ -150,7 +152,7 @@ class ElasticManager:
self.np = len(self.trainers.split(","))
self.start_port = int(os.getenv("PADDLE_PORT", "6170"))
self.dist_endpoints = os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS', '')
trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS', '')
trainer_endpoints = getenv_or_backup('PADDLE_TRAINER_ENDPOINTS', '')
self.trainer_endpoints_list = trainer_endpoints.split(",")
else:
self.trainers = args.ips or os.getenv('PADDLE_TRAINERS', '')
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import signal
import sys
......@@ -221,6 +222,8 @@ class Controller(ControllerBase):
):
if not container:
envs = copy.deepcopy(envs)
envs['PADDLE_LOG_DIR'] = str(os.path.abspath(self.ctx.args.log_dir))
container = self.new_container(
entrypoint=entrypoint, envs=envs, out=log_file, err=log_file
)
......
......@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import signal
import subprocess
import sys
import time
LIMIT_LEN_ENVS = ["TRAINER_IP_PORT_LIST", "PADDLE_TRAINER_ENDPOINTS"]
class ProcessContext:
def __init__(
......@@ -42,9 +45,30 @@ class ProcessContext:
def _start(self):
pre_fn = os.setsid if self._group else None
log_dir = self._env["PADDLE_LOG_DIR"]
os.makedirs(log_dir, exist_ok=True)
rank = self._env.get("PADDLE_TRAINER_ID")
if rank is not None:
rank = int(rank)
backup_env_path = str(
os.path.join(log_dir, f'backup_env.{rank}.json')
)
envs = {"PADDLE_BACKUP_ENV_PATH": backup_env_path}
max_len = int(os.getenv('PADDLE_ENV_LIMIT_LEN', 48000))
for k, v in self._env.items():
if k not in LIMIT_LEN_ENVS or len(v) < max_len:
envs[k] = v
with open(backup_env_path, 'w') as f:
json.dump(dict(self._env), f, indent=4, sort_keys=True)
else:
envs = self._env
self._proc = subprocess.Popen(
self._cmd,
env=self._env,
env=envs,
stdout=self._stdout,
stderr=self._stderr,
preexec_fn=self._preexec_fn or pre_fn,
......
......@@ -52,6 +52,7 @@ from paddle.nn.layer import layers
from paddle.utils import deprecated
from . import parallel_helper
from .backup_env import getenv_or_backup
__all__ = []
......@@ -723,7 +724,7 @@ class ParallelEnv:
selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",")
self._device_id = int(selected_xpus[0])
self._trainer_endpoints = os.getenv(
self._trainer_endpoints = getenv_or_backup(
"PADDLE_TRAINER_ENDPOINTS", ""
).split(",")
self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT", "")
......@@ -898,7 +899,7 @@ def _is_cpuonly(backend):
def _check_var_exists(var_name):
var = os.environ.get(var_name, None)
var = getenv_or_backup(var_name, None)
if var is None:
raise ValueError(
"paddle.distributed initialize error, "
......@@ -1081,7 +1082,9 @@ def init_parallel_env():
if endpoints is None:
endpoints = os.getenv("PADDLE_MASTER", None)
if endpoints is None:
endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
endpoints = getenv_or_backup("PADDLE_TRAINER_ENDPOINTS").split(',')[
0
]
assert endpoints, (
"The environment variable 'MASTER_ADDR' and 'MASTER_PORT' "
"must be specified, for example 'export MASTER_ADDR=127.0.0.1' "
......
......@@ -495,8 +495,9 @@ def _to_name_str(var):
def _prepare_fleet_executor():
from ..distributed.fleet.proto import fleet_executor_desc_pb2
from ..distributed.backup_env import getenv_or_backup
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "")
trainer_endpoints_str = getenv_or_backup("PADDLE_TRAINER_ENDPOINTS", "")
trainer_endpoints = trainer_endpoints_str.split(',')
fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
......
......@@ -55,6 +55,7 @@ def get_files(pth, prefix):
if isfile(join(pth, f))
and not f.endswith('gpu.log')
and not f.startswith('envlog')
and not f.startswith('backup_env')
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册