未验证 提交 67c6ddff 编写于 作者: K kuizhiqing 提交者: GitHub

New design for launch/run (#40086)

上级 464f65b1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .job.container import Container
from .job.pod import Pod
from .job.job import Job
from . import plugins
#__all__ = [Container, Pod, Job]
'''
Paddle distribution training entry ``python -m paddle.distributed.run``.
Help
# for arg usage and explanation, try the following command
# python -m paddle.distributed.run -h
Collective Mode
Case 1: 1 node
use all visible devices
# python -m paddle.distributed.run train.py
use specified devices
# python -m paddle.distributed.run --devices=0,1,2,3 train.py
Case 2: multi-node, auto detect ip/port
# python -m paddle.distributed.run --np 2 train.py
# auto print following command
# python -m paddle.distributed.run --master 10.0.0.1:13538 --np 2 demo.py
# then copy and paste above command to other nodes
Case 3: multi-node, specified master/rendezvous server
# python -m paddle.distributed.run --np 2 --master 10.0.0.1:2379 train.py
# the master ip must be one of the node and the port must available
Parameter Server Mode
Case 1.1: 1 node, 1 ps, 1 trainer
# python -m paddle.distributed.run --mode ps train.py
# python -m paddle.distributed.run --server_num=1 --trainer_num=1 train.py
Case 1.2: 1 node, 2 ps, 2 trainer
# python -m paddle.distributed.run --server_num=2 --trainer_num=2 train.py
Case 2: 2 node, 2 ps, 2 trainer per node
# python -m paddle.distributed.run --server_num=2 --trainer_num=2 --np 2 train.py
# auto print following command
# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py
# then copy and paste above command to other nodes
Case 3: multi-node, specified master/rendezvous server
# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py
# the master ip must be one of the node and the port must available
Case 4: specified servers and trainers in each node
python -m paddle.distributed.run --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903 train.py
Elastic Mode
# run following command in 3 node to run immediately, or in 2 node to run after elastic_timeout
# python -m paddle.distributed.run --master etcd://10.0.0.1:2379 --np 2:3 train.py
# once the peer number changes between 2:3, the strategy holds
'''
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .context import Context
from . import controllers
# initialize the context to run
ctx = Context()
# initialize the selected controller
c = controllers.init(ctx)
# run the pods
c.run()
# manager or just wait pod
c.finalize()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser, REMAINDER
import os, copy
from paddle.distributed.run import plugins
from .node import Node
from .status import Status
import logging
class Context(object):
def __init__(self, enable_plugin=True):
os.environ.pop('http_proxy', None)
os.environ.pop('https_proxy', None)
self.args = self.parse_args()
self.envs = self.fetch_envs()
self.logger = self.get_logger()
self.node = Node()
self.status = Status()
self.set_env_in_args()
# design for event queue, later
self.events = []
if enable_plugin:
self._enable_plugin()
def get_envs(self):
return self.envs.copy()
def _enable_plugin(self):
for pl in plugins.enabled_plugins:
pl(self)
def parse_args(self):
parser = ArgumentParser()
base_group = parser.add_argument_group("Base Parameters")
base_group.add_argument(
"--master",
type=str,
default=None,
help="the master/rendezvous server, ip:port")
base_group.add_argument(
"--rank", type=int, default=-1, help="the peer rank")
base_group.add_argument(
"--log", type=str, default="INFO", help="log level. Default INFO")
base_group.add_argument(
"--np",
type=str,
default="1",
help="the number of peers, i.e. pod/node number")
base_group.add_argument(
"--nproc_per_node",
type=int,
default=None,
help="the number of processes in a pod")
base_group.add_argument(
"--log_dir",
type=str,
default="log",
help="the path for each process's log. Default ./log")
base_group.add_argument(
"--mode",
type=str,
default="collective",
help="run mode of the job, collective/ps/ps-heter")
base_group.add_argument(
"--id",
type=str,
default="default",
help="unique id of the job. Default default")
base_group.add_argument(
"--devices",
type=str,
default=None,
help="accelerate devices. as --gpus,npus,xps")
base_group.add_argument(
"--host", type=str, default=None, help="host ip")
base_group.add_argument(
"training_script",
type=str,
help="the full path of py script,"
"followed by arguments for the "
"training script")
base_group.add_argument('training_script_args', nargs=REMAINDER)
ps_group = parser.add_argument_group("Parameter-Server Parameters")
# for parameter server
ps_group.add_argument(
"--servers",
type=str,
default='',
help="servers endpoints full list")
ps_group.add_argument(
"--trainers",
type=str,
default='',
help="trainers endpoints full list")
ps_group.add_argument(
"--trainer_num", type=int, default=None, help="number of trainers")
ps_group.add_argument(
"--server_num", type=int, default=None, help="number of servers")
ps_group.add_argument(
"--gloo_port", type=int, default=6767, help="gloo http port")
ps_group.add_argument(
"--with_gloo", type=str, default="0", help="use gloo or not")
# parameter elastic mode
elastic_group = parser.add_argument_group("Elastic Parameters")
elastic_group.add_argument(
"--max_restart",
type=int,
default=3,
help="the times can restart. Default 3")
elastic_group.add_argument(
"--elastic_level",
type=int,
default=-1,
help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart"
)
elastic_group.add_argument(
"--elastic_timeout",
type=int,
default=30,
help="seconds to wait before elastic perform training")
return parser.parse_args()
def _valide_env(self, key):
if key in ['POD_IP']:
return True
if key.endswith('_VISIBLE_DEVICES'):
return True
if key.startswith('PADDLE_'):
return True
return False
def fetch_envs(self):
ge = os.environ.copy()
black_env_list = ['http_proxy', 'https_proxy']
for key in black_env_list:
ge.pop(key, None)
return ge
'''
# use black list instead white list
return {k: ge[k] for k in ge if self._valide_env(k)}
'''
def get_logger(self, level=logging.INFO):
logger = logging.getLogger("PADDLERUN")
logger.setLevel(self.args.log.upper() or level)
formatter = logging.Formatter(
fmt='%(name)s %(levelname)s %(asctime)s %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
return logger
def set_env_in_args(self):
env_args = {
'POD_IP': 'host',
'PADDLE_MASTER': 'master',
'PADDLE_DEVICES': 'devices',
'PADDLE_NP': 'np',
'PADDLE_MODE': 'mode',
'PADDLE_LOG': 'log',
'PADDLE_NPROC_PER_NODE': 'nproc_per_node',
'PADDLE_JOB_ID': 'id',
'PADDLE_RANK': 'rank',
'PADDLE_LOG_DIR': 'log_dir',
'PADDLE_MAX_RESTlRT': 'max_restart',
'PADDLE_ELASTIC_LEVEL': 'elastic_level',
'PADDLE_ELASTIC_TIMEOUT': 'elastic_timeout',
'PADDLE_SERVER_NUM': 'server_num',
'PADDLE_TRAINER_NUM': 'trainer_num',
'PADDLE_SERVERS_ENDPOINTS': 'servers',
'PADDLE_TRAINERS_ENDPOINTS': 'trainers',
'PADDLE_GLOO_PORT': 'gloo_port',
'PADDLE_WITH_GLOO': 'with_gloo',
}
for k, v in env_args.items():
if k in self.envs:
setattr(self.args, v, self.envs[k])
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
class DeviceType:
CPU = 'cpu'
GPU = 'gpu'
XPU = 'xpu'
NPU = 'npu'
class Device(object):
def __init__(self, dtype=None, count=1, memory="", labels=""):
self.dtype = dtype
self.count = count
self.memory = memory
self.labels = labels
def __str__(self):
return ",".join(self.labels)
@classmethod
def parse_device(self):
dev = Device()
visible_devices = None
if 'CUDA_VISIBLE_DEVICES' in os.environ or 'NVIDIA_VISIBLE_DEVICES' in os.environ:
dev.dtype = DeviceType.GPU
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv(
"NVIDIA_VISIBLE_DEVICES")
elif 'XPU_VISIBLE_DEVICES' in os.environ:
dev.dtype = DeviceType.XPU
visible_devices = os.getenv("XPU_VISIBLE_DEVICES")
elif 'ASCEND_VISIBLE_DEVICES' in os.environ:
dev.dtype = DeviceType.NPU
visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES")
if visible_devices and visible_devices != 'all':
dev.labels = visible_devices.split(',')
dev.count = len(dev.labels)
else:
return self.detect_device()
return dev
@classmethod
def detect_device(self):
import paddle.fluid as fluid
dev = Device()
num = 0
visible_devices = None
if fluid.core.is_compiled_with_cuda():
dev.dtype = DeviceType.GPU
num = fluid.core.get_cuda_device_count()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv(
"NVIDIA_VISIBLE_DEVICES")
elif fluid.core.is_compiled_with_xpu():
dev.dtype = DeviceType.XPU
num = fluid.core.get_xpu_device_count()
visible_devices = os.getenv("XPU_VISIBLE_DEVICES")
elif fluid.core.is_compiled_with_npu():
dev.dtype = DeviceType.NPU
num = fluid.core.get_npu_device_count()
visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES")
if num == 0:
dev.dtype = DeviceType.CPU
elif visible_devices is None or visible_devices == "all" or visible_devices == "":
dev.labels = [str(x) for x in range(0, num)]
dev.count = num
else:
dev.labels = visible_devices.split(',')
dev.count = len(dev.labels)
return dev
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Event(object):
def __init__(self, kind="status", message="", fatal=False):
self.kind = kind
self.message = message
self.fatal = fatal
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .device import Device
import socket
import struct
from contextlib import closing
class Node(object):
def __init__(self):
# self.device = Device.detect_device()
self.device = Device.parse_device()
self.ip = self.get_host_ip()
self.free_ports = []
def get_host_ip(self):
try:
self.hostname = socket.gethostname()
self.ip = socket.gethostbyname(socket.getfqdn(self.hostname))
return self.ip
except:
return '127.0.0.1'
def get_free_ports(self, n=1):
free_ports = [self.get_free_port() for i in range(n)]
self.free_ports += free_ports
return free_ports
def get_ports_occupied(self):
return self.free_ports
@classmethod
def get_free_port(self):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 0))
s.bind(('', 0))
return s.getsockname()[1]
@classmethod
def is_server_ready(self, ip, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
#sock.settimeout(0.01)
#sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
result = sock.connect_ex((ip, int(port)))
if result == 0:
return True
else:
return False
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Resource(object):
def __init__(self):
self.devices = []
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Status(object):
UNINIT = "uninit"
READY = "ready"
RUNNING = "running"
FAILED = "failed"
TERMINATING = "terminating"
RESTARTING = "restarting"
UNKNOWN = "unknown"
COMPLETED = "completed"
DONE = "done" # should exit whatever status
def __init__(self):
self._current_status = None
def current(self):
return self._current_status
def is_running(self):
return self._current_status == self.RUNNING
def is_restarting(self):
return self._current_status == self.RESTARTING
def is_done(self):
if self._current_status in [self.DONE, self.COMPLETED, self.FAILED]:
return True
else:
return False
def run(self):
self._current_status = self.RUNNING
def fail(self):
self._current_status = self.FAILED
def complete(self):
self._current_status = self.COMPLETED
def restart(self):
self._current_status = self.RESTARTING
def done(self):
self._current_status = self.DONE
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["init"]
from .collective import CollectiveController
from .collective import CollectiveElasticController
from .ps import PSController
# the order is extremely important
_controllers = [
CollectiveElasticController,
PSController,
CollectiveController,
]
def init(ctx):
for c in _controllers:
if c.enable(ctx):
return c(ctx)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .controller import Controller
import json
import os
import six
import time
class CollectiveController(Controller):
@classmethod
def enable(cls, ctx):
if ctx:
ctx.logger.debug("{} enabled".format(cls.__name__))
return True
else:
return False
def build_pod(self):
self.pod.replicas = self.pod_replicas()
# rank will be reset when restart
self.pod.rank = self.ctx.args.rank
port = self.ctx.node.get_free_port()
# compatible
endpoints = [
"{}:{}".format(self.ctx.node.ip, p)
for p in self.ctx.node.get_free_ports(self.pod.replicas)
]
data = json.dumps({
'name': self.pod.name,
'rank': self.pod.rank,
'replicas': self.pod.replicas,
'dtype': self.ctx.node.device.dtype,
'candidate': '{}:{}'.format(self.ctx.node.ip, port),
'endpoints': ",".join(endpoints),
})
peer_list, rank = self.master.sync_peers(
'/{}/info'.format(self.job.id), self.pod.name, data,
self.job.replicas, self.pod.rank)
self.pod.rank = rank
if len(peer_list) < 1:
return False
peer_list = [json.loads(i) for i in peer_list]
self.ctx.logger.debug("sync peers done {}".format(peer_list))
self.save_pod_log(peer_list)
global_size = sum([i['replicas'] for i in peer_list])
rank_offset = sum([i['replicas'] for i in peer_list[:rank]])
'''
The new designed collective need nothing but a master endpoint
'''
collective_master = peer_list[0]['candidate']
job_endpoints = [i['endpoints'] for i in peer_list]
self.pod.reset()
for i in range(self.pod.replicas):
e = {
"PADDLE_MASTER": collective_master,
"PADDLE_GLOBAL_SIZE": "{}".format(global_size),
"PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas),
"PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset),
"PADDLE_LOCAL_RANK": "{}".format(i),
## compatible env
"PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints),
"PADDLE_CURRENT_ENDPOINT": endpoints[i],
"PADDLE_TRAINER_ID": "{}".format(i + rank_offset),
"PADDLE_TRAINERS_NUM": "{}".format(global_size),
"PADDLE_RANK_IN_NODE": str(i),
}
self.add_container(envs=e, log_tag=i)
return True
class CollectiveElasticController(CollectiveController):
@classmethod
def enable(cls, ctx):
if ctx.args.master and ctx.args.master.startswith("etcd://"):
ctx.logger.debug("{} enabled".format(cls.__name__))
return True
else:
return False
def register(self):
if self.job.id == 'default':
self.ctx.logger.warning(
'Using default job name may cause conflict, add --id in args')
self.master.register_heartbeat(self.job.id, self.pod.name)
def watch(self) -> bool:
'''
watch self and peer status, return true to exit
'''
while not self.ctx.status.is_done():
# self status
status = self.pod.watch(timeout=2)
self.ctx.logger.debug("Pod status {}, Ctx status {}".format(
status, self.ctx.status.current()))
# completed
if status == self.ctx.status.COMPLETED:
self.master.set_status(status)
self.ctx.status.complete()
self.ctx.logger.info("Pod complete {}".format(status))
return True
# self failure
elif status == self.ctx.status.FAILED:
self.master.set_status(status)
self.master.restart_peer()
self.ctx.logger.info("Pod failed {}".format(status))
self.pod.stop()
if self.ctx.args.elastic_level <= 0:
return True
else:
return False
# peer failure
if self.ctx.status.is_restarting() and self.master.get_status(
) != self.ctx.status.COMPLETED:
self.pod.stop()
return False
#peers = self.master.fetch_peer_alive()
#print("peers {}".format(peers))
def run(self):
timeout = self.ctx.args.elastic_timeout if self.job.elastic else self.ctx.args.elastic_timeout * 10
self.register()
while self.pod.restart <= self.ctx.args.max_restart:
self.build_job()
ok, replicas = self.master.wait_peer_ready(
self.job.replicas_min, self.job.replicas_max, timeout)
if ok:
self.job.replicas = replicas
else:
self.ctx.logger.warnning("peer not ready {}".format(self.job))
break
self.ctx.logger.debug("Run {}".format(self.job))
if not self.build_pod():
continue
self.master.set_status(self.ctx.status.RUNNING)
self.ctx.status.run()
assert len(self.pod.containers) > 0, "No container in the pod"
self.ctx.logger.debug("Run {}".format(self.pod))
self.ctx.logger.debug("Run {}".format(self.pod.containers[0]))
self.pod.deploy()
if self.watch():
break
self.ctx.logger.debug("Job done {}".format(self.job))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import signal
from paddle.distributed.run.job import Job
from paddle.distributed.run.job import Pod
from paddle.distributed.run.job import Container
from .master import Master
import time
class ControleMode:
COLLECTIVE = "collective"
PS = "ps"
class ControllerBase(object):
def __init__(self, ctx):
signal.signal(signal.SIGTERM, self.signal_handler)
signal.signal(signal.SIGABRT, self.signal_handler)
signal.signal(signal.SIGINT, self.signal_handler)
self.ctx = ctx
self.master = Master.factory(self.ctx)
self.job = Job(np=self.ctx.args.np,
mode=self.ctx.args.mode,
id=self.ctx.args.id)
self.pod = Pod()
self.join_server = None
def run(self):
self.build_job()
self.build_pod()
if len(self.pod.containers) < 1:
self.ctx.logger.error("No container in the pod {}".format(self.pod))
return
self.ctx.logger.info("Run {}".format(self.pod))
self.ctx.logger.debug(self.pod.containers[0])
self.pod.deploy()
self.watch()
def watch(self) -> bool:
status = self.pod.watch()
if status == self.ctx.status.COMPLETED:
self.ctx.logger.info("Pod {}".format(status))
elif status == self.ctx.status.FAILED:
self.ctx.logger.info("Pod {}".format(status))
self.ctx.logger.error("Container failed !!!\n{}".format(
self.pod.failed_container()))
self.pod.tail()
self.pod.stop()
def stop(self, sigint=None):
self.ctx.logger.debug("Controller stop")
self.master.stop()
self.pod.stop(sigint)
def finalize(self):
self.pod.join()
self.master.stop()
self.ctx.logger.info("Exit code {}".format(self.pod.exit_code))
sys.exit(self.pod.exit_code)
def signal_handler(self, sigint, frame):
self.ctx.logger.info("Terminating with signal {}".format(sigint))
if hasattr(self, 'sigint'):
time.sleep(5)
sys.exit(sigint)
self.sigint = sigint
self.ctx.status.done()
self.stop(sigint)
time.sleep(1)
self.ctx.logger.debug("Exit with signal {}".format(sigint))
sys.exit(sigint)
class Controller(ControllerBase):
'''
Controller API for customization
'''
def build_job(self):
'''
build job fill the job info.
'''
self.ctx.logger.info(self.job)
def build_pod(self) -> bool:
'''
build pod includes creating containers etc.
Return True if succeed
'''
raise NotImplementedError
def _get_entrypoint(self):
entrypoint = [sys.executable, "-u", self.ctx.args.training_script]
entrypoint.extend(self.ctx.args.training_script_args)
return entrypoint
def _get_out_err_file(self, out=None, err=None):
if out and self.ctx.args.log_dir != "":
out = os.path.join(self.ctx.args.log_dir, out)
if err and self.ctx.args.log_dir != "":
err = os.path.join(self.ctx.args.log_dir, err)
return out, (err or out)
def new_container(self,
entrypoint=None,
envs={},
use_ctx_env=True,
out=None,
err=None):
c = Container(
entrypoint=(entrypoint or self._get_entrypoint()),
env=(self.ctx.get_envs() if use_ctx_env else {}), )
c.outfile, c.errfile = self._get_out_err_file(out, err)
c.update_env(envs)
return c
def add_container(self,
container=None,
entrypoint=None,
envs={},
log_tag=None,
is_init=False):
if not is_init and log_tag is not None:
log_file = "{}.{}.{}.log".format(self.job.id, self.pod.name,
log_tag)
else:
log_file = None
if not container:
container = self.new_container(
entrypoint=entrypoint, envs=envs, out=log_file, err=log_file)
if is_init:
self.pod.add_init_container(container)
else:
self.pod.add_container(container)
def pod_replicas(self):
'''
how many process/container should be run in pod
'''
if self.ctx.args.nproc_per_node:
return int(self.ctx.args.nproc_per_node)
else:
return self.ctx.node.device.count
def save_pod_log(self, info):
'''
save_pod_log append *info* to the log file of pod.name
'''
if not self.ctx.args.log_dir:
return
f = os.path.join(self.ctx.args.log_dir,
'{}.{}.log'.format(self.job.id, self.pod.name))
try:
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'a+') as fd:
fd.write(str(info))
except Exception as e:
self.ctx.logger.error("save log failed because {}".format(e))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.run.utils.kv_client import KVClient
from paddle.distributed.run.utils.kv_server import KVServer
import time
import sys
import six
import threading
import copy
import random
ETCD_PROTOCAL = 'etcd://'
class Master(object):
'''
Master is a distributed store design to exchange info among nodes
'''
MAIN = "main"
STANDBY = "standby"
PATICIPANT = "participant"
def __init__(self, ctx):
self.ctx = ctx
self.server = None
self.initialized = False
self.endpoint = None
def stop(self):
raise NotImplementedError
def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
raise NotImplementedError
@classmethod
def factory(cls, ctx):
if ctx.args.master and ctx.args.master.startswith(ETCD_PROTOCAL):
return ETCDMaster(ctx)
else:
return HTTPMaster(ctx)
class HTTPMaster(Master):
def lazy_init(self):
if self.initialized:
return
self.role = Master.PATICIPANT
if self.ctx.args.master:
self.endpoint = self.ctx.args.master
ip, port = self.endpoint.split(':')
if ip in ['127.0.0.1', self.ctx.node.ip]:
time.sleep(2 * random.random())
while not self.ctx.node.is_server_ready(ip, int(port)):
try:
self.server = KVServer(int(port))
self.role = Master.MAIN
break
except Exception as e:
self.ctx.logger.warning("start master failed {}".format(
e))
time.sleep(0.1)
continue
else:
port = self.ctx.node.get_free_port()
self.endpoint = "{}:{}".format(self.ctx.node.ip, port)
self.server = KVServer(port)
self.role = Master.MAIN
print("Copy the following command to other nodes to run.")
cmd = [
sys.executable.split('/')[-1], "-m", "paddle.distributed.run"
]
cmd.extend(["--master", self.endpoint])
cmd.extend(sys.argv[1:])
print("-" * 80)
print(" ".join(cmd))
print("-" * 80)
if self.ctx.args.rank >= 0:
self.ctx.logger.warning(
"--rank set in the command may not compatible in auto mode")
if '127.0.0.1' in self.endpoint:
self.endpoint = self.endpoint.replace('127.0.0.1', self.ctx.node.ip)
self.client = KVClient(self.endpoint)
self.initialized = True
self._start_server()
def _start_server(self):
if self.server and not self.server.started:
self.server.start()
self.ctx.logger.debug("KV server start at {}".format(self.endpoint))
def _stop_server(self):
if self.server and not self.server.stopped:
self.server.stop()
self.ctx.logger.debug("KV server stopped")
def stop(self):
self._stop_server()
def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
if size < 2:
return [value], 0
self.lazy_init()
while not self.ctx.status.is_done():
if self.client.wait_server_ready(timeout=5):
break
else:
self.ctx.logger.warning("master not ready")
time.sleep(0.1)
# 'aaaaaa' make suer main pod (master server) as rank 0
ky = 'aaaaaa' if rank < 0 and self.role == Master.MAIN else key
k = "{}/{}/{}".format(prefix, ky, rank)
while not self.ctx.status.is_done():
if not self.client.put(k, value):
self.ctx.logger.warning("put value failed")
time.sleep(0.1)
continue
rjson = self.client.get_prefix(prefix)
self.ctx.logger.debug("sync peers {}".format(rjson))
if rjson and len(rjson) == size:
if rank < 0:
keys = list(rjson.keys())
keys.sort()
ret = [rjson[k] for k in keys]
idx = ret.index(value)
return ret, idx
else:
ret = [None] * size
for k, v in rjson.items():
ret[int(k.split('/')[-1])] = v
return ret, rank
else:
time.sleep(0.5)
return [], 0
class ETCDMaster(Master):
def __init__(self, ctx):
super().__init__(ctx)
if self.ctx.args.master:
# etcd://localhost:2379
self.endpoint = self.ctx.args.master.strip("etcd://")
import etcd3
host, port = self.endpoint.split(':')
self.client = etcd3.client(host=host, port=port)
def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
'''
sync_peers gather all value for key under scope prefix
result always be sorted either by rank or alphabet of pod.name
'''
path = "{}/{}/{}".format(prefix, key, rank)
self.client.delete_prefix(prefix)
self.ctx.logger.debug("sync path {} value {}".format(path, value))
while not self.ctx.status.is_done():
self.client.put(path, six.b(value))
result = [i for i in self.client.get_prefix(prefix)]
result = copy.deepcopy(result)
self.ctx.logger.debug("sync peers {}".format(result))
if len(result) == size:
if rank < 0:
keys = [six.ensure_str(i[1].key) for i in result]
sorted_keys = [six.ensure_str(i[1].key) for i in result]
sorted_keys.sort()
values = [six.ensure_str(i[0]) for i in result]
ret = [values[keys.index(k)] for k in sorted_keys]
idx = ret.index(value)
return ret, idx
else:
ret = [None] * size
for v, k in result:
ii = int(six.ensure_str(k.key).split('/')[-1])
if ii < 0:
self.ctx.logger.error(
"rank {} error in sync".format(ii))
ret[ii] = six.ensure_str(v)
return ret, rank
else:
time.sleep(0.5)
def register_heartbeat(self, job_id, pod_id, ttl=10):
if hasattr(self, 'heartbeat_prefix'):
self.ctx.logger.warning("Heartbeat already done")
return
self.job_prefix = '/paddle/{}'.format(job_id)
self.heartbeat_prefix = '{}/heartbeat'.format(self.job_prefix)
lease = self.client.lease(ttl)
#self.client.delete_prefix(self.job_prefix)
beat_path = "{}/{}".format(self.heartbeat_prefix, pod_id)
self.client.put(beat_path, six.b(pod_id), lease=lease)
def _beat_watch(event):
self.ctx.status.restart()
beat_watch = self.client.add_watch_prefix_callback(
self.heartbeat_prefix, _beat_watch)
def _heartbeat():
while not self.ctx.status.is_done():
try:
lease.refresh()
if pod_id not in self.fetch_peer_alive():
self.client.put(beat_path, six.b(pod_id), lease=lease)
self.ctx.logger.debug("Heartbeat register again")
except Exception as e:
self.ctx.logger.error("Heartbeat error {}".format(e))
time.sleep(ttl / 2)
self.ctx.logger.debug("Heartbeat done")
self.client.cancel_watch(beat_watch)
self.beat_thread = threading.Thread(
name='heartbeat', target=_heartbeat, daemon=True)
self.beat_thread.start()
def fetch_peer_alive(self):
peer_alive = [
six.ensure_str(i[0])
for i in self.client.get_prefix(self.heartbeat_prefix)
]
self.ctx.logger.debug("peer alive {}".format(peer_alive))
return peer_alive
def wait_peer_ready(self, replicas_min, replicas_max, timeout):
end = time.time() + timeout
while not self.ctx.status.is_done() and time.time() < end:
if len(self.fetch_peer_alive()) == replicas_max:
return (True, replicas_max)
else:
time.sleep(0.5)
np = len(self.fetch_peer_alive())
if np >= replicas_min and np <= replicas_max:
return (True, np)
else:
return (False, np)
def restart_peer(self):
self.client.delete_prefix(self.heartbeat_prefix)
def set_status(self, status):
assert self.client.put(
self.job_prefix, six.b(status),
lease=self.client.lease(600)), "set status failed {}".format(status)
def get_status(self):
return six.ensure_str(self.client.get(self.job_prefix)[0] or '')
def stop(self):
if hasattr(self, 'beat_thread'):
self.ctx.status.done()
# TODO(kuizhiqing) thread should exit
#self.beat_thread.join()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .controller import Controller, ControleMode
import json
import os, shutil
class PSController(Controller):
@classmethod
def enable(cls, ctx):
if ctx.args.mode == ControleMode.PS or ctx.args.server_num or len(
ctx.args.servers) > 0:
ctx.logger.debug("{} enabled".format(cls.__name__))
ctx.args.mode = ControleMode.PS
return True
else:
return False
def build_pod(self):
if self.ctx.args.servers and self.ctx.args.trainers:
self._build_pod_with_args()
else:
self._build_pod_with_master()
def _build_pod_with_args(self):
if '127.0.0.1' in self.ctx.args.servers:
host = '127.0.0.1'
else:
host = self.ctx.node.ip
server_endpoints = [s for s in self.ctx.args.servers.split(",")]
trainer_endpoints = [s for s in self.ctx.args.trainers.split(",")]
servers = [
s for s in self.ctx.args.servers.split(",") if s.startswith(host)
]
trainers = [
s for s in self.ctx.args.trainers.split(",") if s.startswith(host)
]
server_num = len(servers)
trainer_num = len(trainers)
self.pod.replicas = server_num + trainer_num
self.save_pod_log([server_endpoints, trainer_endpoints])
import tempfile
gloo_rendezvous_dir = tempfile.mkdtemp()
if os.path.exists(gloo_rendezvous_dir):
shutil.rmtree(gloo_rendezvous_dir)
gloo_port = self.ctx.args.gloo_port
gloo_http = "{}:{}".format(server_endpoints[0].split(":")[0], gloo_port)
_gloo_envs = {
"PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir,
"PADDLE_GLOO_HTTP_ENDPOINT": gloo_http,
"PADDLE_WITH_GLOO": self.ctx.args.with_gloo
}
for i in range(server_num):
e = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.ctx.args.servers,
"PADDLE_TRAINER_ENDPOINTS": self.ctx.args.trainers,
"PADDLE_PORT": servers[i].split(":")[1],
"PADDLE_ROLE": "PSERVER",
"TRAINING_ROLE": "PSERVER",
"PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
"POD_IP": self.ctx.node.ip,
}
e.update(_gloo_envs)
log_tag = "ps.{}".format(i)
self.add_container(envs=e, log_tag=log_tag)
trainer_rank_offset = 0
for s in trainer_endpoints:
if s.startswith(host):
break
else:
trainer_rank_offset += 1
for i in range(trainer_num):
e = {
"PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_PORT": trainers[i].split(":")[1],
"PADDLE_ROLE": "TRAINER",
"TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": "{}".format(i + trainer_rank_offset),
"PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
"POD_IP": self.ctx.node.ip,
}
e.update(_gloo_envs)
log_tag = "trainer.{}".format(i)
self.add_container(envs=e, log_tag=log_tag)
def _build_pod_with_master(self):
self.pod.rank = self.ctx.args.rank
server_num = self.ctx.args.server_num or 1
servers = [
"{}:{}".format(self.ctx.node.ip, p)
for p in self.ctx.node.get_free_ports(server_num)
]
trainer_num = self.ctx.args.trainer_num or 1
trainers = [
"{}:{}".format(self.ctx.node.ip, p)
for p in self.ctx.node.get_free_ports(trainer_num)
]
data = json.dumps({
'name': self.pod.name,
'rank': self.pod.rank,
'servers': servers,
'trainers': trainers,
'dtype': self.ctx.node.device.dtype,
'gloo_port': self.ctx.node.get_free_port(),
})
peer_list, rank = self.master.sync_peers(
'/{}/info'.format(self.job.id), self.pod.name, data,
self.job.replicas, self.pod.rank)
self.ctx.logger.debug("sync peers done {}".format(peer_list))
peer_list = [json.loads(i) for i in peer_list]
self.save_pod_log(peer_list)
server_endpoints = [j for i in peer_list for j in i['servers']]
trainer_endpoints = [j for i in peer_list for j in i['trainers']]
#rank_offset = sum([i['replicas'] for i in peer_list[:rank]])
server_rank_offset = sum([len(i['servers']) for i in peer_list[:rank]])
trainer_rank_offset = sum(
[len(i['trainers']) for i in peer_list[:rank]])
self.pod.rank = rank
self.pod.replicas = server_num + trainer_num
import tempfile
gloo_rendezvous_dir = tempfile.mkdtemp()
if os.path.exists(gloo_rendezvous_dir):
shutil.rmtree(gloo_rendezvous_dir)
gloo_port = peer_list[0]['gloo_port']
gloo_http = "{}:{}".format(server_endpoints[0].split(":")[0], gloo_port)
_gloo_envs = {
"PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir,
"PADDLE_GLOO_HTTP_ENDPOINT": gloo_http,
"PADDLE_WITH_GLOO": self.ctx.args.with_gloo
}
for i in range(server_num):
e = {
"PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_PORT":
server_endpoints[i + server_rank_offset].split(":")[1],
"PADDLE_ROLE": "PSERVER",
"TRAINING_ROLE": "PSERVER",
"PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
"POD_IP": self.ctx.node.ip,
}
e.update(_gloo_envs)
log_tag = "ps.{}".format(i)
self.add_container(envs=e, log_tag=log_tag)
for i in range(trainer_num):
e = {
"PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_PORT":
trainer_endpoints[i + trainer_rank_offset].split(":")[1],
"PADDLE_ROLE": "TRAINER",
"TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": "{}".format(i + trainer_rank_offset),
"PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)),
"POD_IP": self.ctx.node.ip,
}
e.update(_gloo_envs)
log_tag = "trainer.{}".format(i)
self.add_container(envs=e, log_tag=log_tag)
''' NEW VERSION
for i in range(server_num):
e = {
"PADDLE_PSERVER_ENDPOINTS": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_ROLE": "PSERVER",
"PADDLE_RANK": "{}".format(i + server_rank_offset),
}
log_tag = "ps.{}".format(i)
self.add_container(envs=e, log_tag=log_tag)
for i in range(trainer_num):
e = {
"PADDLE_PSERVER_ENDPOINTS": ",".join(server_endpoints),
"PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints),
"PADDLE_ROLE": "TRAINER_CPU",
"PADDLE_RANK": "{}".format(i + trainer_rank_offset),
}
log_tag = "trainer.{}".format(i)
self.add_container(envs=e, log_tag=log_tag)
'''
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pod import Pod
from .job import Job
from .container import Container
from .status import Status
__all__ = [
'Pod',
'Job',
'Container',
'Status',
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from paddle.distributed.run.utils.process_context import ProcessContext
from .status import Status
import os, copy, sys
import time
class Container(object):
'''
TODO(kuizhiqing) A container can be run by process/thread or just a callable function
'''
def __init__(self, entrypoint=[], rank=-1, env={}):
self._entrypoint = entrypoint
self._rank = rank
self._out = None
self._err = None
self._env = env
self._proc = None
self._retry: int = 3
self._grace_period = 10
self._log_handler = None
@property
def entrypoint(self):
return self._entrypoint
@entrypoint.setter
def entrypoint(self, entry):
self._entrypoint = entry
@property
def rank(self):
return self._rank
@rank.setter
def rank(self, r):
self._rank = r
@property
def outfile(self):
return self._out
@outfile.setter
def outfile(self, out):
self._out = out
@property
def errfile(self):
return self._err
@errfile.setter
def errfile(self, err):
self._err = err
def update_env(self, env={}, **kwargs):
env = {k: v for k, v in env.items() if isinstance(v, str)}
self._env.update(env)
kwargs = {k: v for k, v in kwargs.items() if isinstance(v, str)}
self._env.update(kwargs)
def _get_fd(self, pth):
if not pth:
return None
try:
d = os.path.dirname(pth)
if not os.path.isdir(d):
os.makedirs(d, exist_ok=True)
return open(pth, 'w')
except:
return None
def start(self, timeout=-1):
end = time.time() + timeout
if self._proc and self._proc.alive():
return True
self._stdout = self._get_fd(self._out) or sys.stdout
if self._out == self._err:
self._stderr = self._stdout
elif self._err:
self._stderr = self._get_fd(self._err) or sys.stderr
self._proc = ProcessContext(
self._entrypoint, env=self._env, out=self._stdout, err=self._stderr)
self._proc.start()
while timeout > 0 and time.time() < end:
if self._proc.alive():
time.sleep(0.1)
continue
if self._proc.exit_code() == 0:
return True
return False
def terminate(self, force=False):
if self._log_handler:
self._log_handler.close()
self._log_handler = None
if self._proc and self._proc.alive():
return self._proc.terminate(force)
def wait(self, timeout=None):
self._proc.wait(timeout)
def exit_code(self):
return self._proc.exit_code() if self._proc else -1
def status(self):
if not self._proc:
return Status.UNINIT
if self._proc.alive():
return Status.RUNNING
elif self._proc.exit_code() == 0:
return Status.COMPLETED
else:
return Status.FAILED
def __str__(self):
return 'Container rank {} status {} cmd {} code {} log {} \nenv {}'.format(
self._rank,
self.status(),
self._entrypoint,
self.exit_code(),
self.errfile,
self._env, )
def logs(self, fn=None, offset=0, whence=1, lines=1000):
if not self._log_handler:
self._log_handler = open(self._out)
if fn is None:
fn = sys.stdout
self._log_handler.seek(offset, whence)
try:
idx = 0
for line in self._log_handler:
fn.write(line)
idx += 1
if idx > lines:
break
finally:
return self._log_handler.tell()
def tail(self, length=3000):
if not self._log_handler:
self._log_handler = open(self._out)
self._log_handler.seek(0, 2)
ed = self._log_handler.tell()
if ed > length:
self.logs(offset=ed - length, whence=0)
else:
self.logs(offset=0, whence=0)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class JobMode:
COLLECTIVE = 'collective'
PS = 'ps'
HETER = 'heter'
class Job(object):
def __init__(self, id='default', mode=JobMode.COLLECTIVE, np="1"):
self._mode = mode
self._id = id
self._replicas = 0
self._replicas_min = self._replicas
self._replicas_max = self._replicas
self._elastic = False
self.set_replicas(str(np))
def __str__(self):
return "Job: {}, mode {}, replicas {}[{}:{}], elastic {}".format(
self.id, self.mode, self._replicas, self._replicas_min,
self._replicas_max, self.elastic)
@property
def mode(self):
return self._mode
@property
def id(self):
return self._id
@property
def elastic(self):
return self._elastic
@property
def replicas(self):
return self._replicas
@property
def replicas_min(self):
return self._replicas_min
@property
def replicas_max(self):
return self._replicas_max
@replicas.setter
def replicas(self, replicas):
self._replicas = replicas
def set_replicas(self, np: str):
np = str(np) if np else '1'
if ':' in np:
nps = np.split(':')
self._replicas_min, self._replicas_max = int(nps[0]), int(nps[1])
self._replicas = self._replicas_max # default to max
self._elastic = True
else:
self._replicas = int(np)
self._replicas_min, self._replicas_max = self._replicas, self._replicas
self._elastic = False
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from .container import Container
from .status import Status
import random
import time
class PodSepc(object):
def __init__(self):
self._name = ''.join(
random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(6))
# by controller
self._init_containers: List[Container] = []
self._containers: List[Container] = []
#self.resource: Resource = None
#self.status: Status = None
self._rank = -1
self._init_timeout = 120 # 2 min timeout for each init container
self._restart = -1
self._replicas = 0 # number of containers
self._exit_code = 0
class Pod(PodSepc):
def __init__(self):
super().__init__()
def __str__(self):
return "Pod: {}, replicas {}, status {}".format(self.name,
self.replicas,
self.status())
def failed_container(self):
for c in self._containers:
if c.status() == Status.FAILED:
return c
return None
@property
def name(self):
return self._name
@property
def replicas(self):
return self._replicas
@replicas.setter
def replicas(self, r):
self._replicas = r
@property
def rank(self):
return self._rank
@rank.setter
def rank(self, r):
self._rank = r
@property
def restart(self):
return self._restart
@property
def containers(self):
return self._containers
def add_container(self, c):
c.rank = len(self._containers)
self._containers.append(c)
@property
def init_containers(self):
return self._init_containers
def add_init_container(self, c):
c.rank = len(self._init_containers)
self._init_containers.append(c)
@property
def exit_code(self):
for c in self._containers:
if c.exit_code() != 0:
return c.exit_code()
return 0
def deploy(self):
for i in self._init_containers:
i.start(self._init_timeout)
for c in self._containers:
c.start()
self._restart += 1
def stop(self, sigint=0):
for c in self._containers:
force = True if sigint == 9 else False
c.terminate(force)
def join(self):
for c in self._containers:
c.wait(None)
def status(self):
if self.is_failed():
return Status.FAILED
if self.is_completed():
return Status.COMPLETED
return Status.READY
def reset(self):
self._init_containers = []
self._containers = []
def is_failed(self):
for c in self._containers:
if c.status() == Status.FAILED:
return True
return False
def is_completed(self):
for c in self._containers:
if c.status() != Status.COMPLETED:
return False
return True
def logs(self, idx=None):
if idx is None:
if self.failed_container():
self.failed_container().logs()
else:
self._containers[0].logs()
else:
self._containers[idx].logs()
def tail(self, idx=None):
if idx is None:
if self.failed_container():
self.failed_container().tail()
else:
self._containers[0].tail()
else:
self._containers[idx].tail()
def watch(self,
all_list=[Status.COMPLETED],
any_list=[Status.FAILED],
interval=1,
timeout=-1):
'''
watch return if any container status in any_list
or all container status in all_list
'''
end = time.time() + timeout
while timeout < 0 or time.time() < end:
for c in self._containers:
if c.status() in any_list:
return c.status()
s = [c.status() for c in self._containers]
if len(set(s)) == 1 and s[0] in all_list:
return s[0]
time.sleep(interval)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Status(object):
UNINIT = "uninit"
READY = "ready"
RUNNING = "running"
FAILED = "failed"
TERMINATING = "terminating"
RESTARTING = "restarting"
UNKNOWN = "unknown"
COMPLETED = "completed"
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
__all__ = []
def log(ctx):
ctx.logger.info("----------- Configuration ----------------------")
for arg, value in sorted(six.iteritems(vars(ctx.args))):
ctx.logger.info("%s: %s" % (arg, value))
ctx.logger.info("--------------------------------------------------")
def process_args(ctx):
# reset device by args
#argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus
argdev = ctx.args.devices
if argdev:
ctx.node.device.labels = argdev.split(',')
ctx.node.device.count = len(ctx.node.device.labels)
ctx.logger.debug('Device reset by args {}'.format(argdev))
def collective_compatible(ctx):
if 'PADDLE_TRAINER_ENDPOINTS' in ctx.envs:
ctx.master = ctx.envs['PADDLE_TRAINER_ENDPOINTS'].split(',')[0]
if 'DISTRIBUTED_TRAINER_ENDPOINTS' in ctx.envs:
ctx.master = ctx.envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',')[0]
def rewrite_host_ip(ctx):
if ctx.args.host is not None and "." in ctx.args.host:
ctx.logger.warning('Host ip reset to {}'.format(ctx.args.host))
ctx.node.ip = ctx.args.host
enabled_plugins = [collective_compatible, rewrite_host_ip, process_args, log]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
def get_local_ip(ctx):
_, ip = _get_host_name_ip()
ctx.args.host = ip
ctx.envs["POD_IP"] = ip
def _get_host_name_ip():
try:
host_name = socket.gethostname()
host_ip = socket.gethostbyname(host_name)
return host_name, host_ip
except:
return None
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
import time
class KVClient(object):
def __init__(self, endpoint='localhost:2379'):
self.endpoint = endpoint if endpoint.startswith(
"http://") else "http://{}".format(endpoint)
def put(self, key, value):
key = key if key.startswith('/') else "/{}".format(key)
u = "{}{}".format(self.endpoint, key)
try:
r = requests.post(u, data=value, timeout=3)
if r.status_code == 200:
return True
else:
return False
except:
return False
def get(self, key):
key = key if key.startswith('/') else "/{}".format(key)
u = "{}{}".format(self.endpoint, key)
try:
r = requests.get(u, timeout=3)
if r.status_code == 200:
ret = r.json()
return ret.get(key, '')
else:
return "error"
except:
return ""
def get_prefix(self, key):
key = key if key.startswith('/') else "/{}".format(key)
u = "{}{}".format(self.endpoint, key)
try:
r = requests.get(u, timeout=3)
if r.status_code == 200:
return r.json()
except:
return ""
def delete(self, key):
key = key if key.startswith('/') else "/{}".format(key)
u = "{}{}".format(self.endpoint, key)
try:
r = requests.delete(u, timeout=3)
if r.status_code == 200:
return True
else:
return False
except:
return False
def wait_server_ready(self, timeout=3):
end = time.time() + timeout
while time.time() < end:
if self.get("/healthy") == "ok":
return True
if __name__ == '__main__':
cli = PKVClient("http://localhost:8090")
data = {"/workers/1": "rank1", "/workers/2": "rank2"}
for k, v in data.items():
cli.put(k, v)
x = cli.get_prefix("/workers")
print(x)
for k, v in data.items():
assert x[k] == v
cli.put("key", "value")
print(cli.get("key"))
assert cli.get("key") == "value"
cli.delete("key")
print(cli.get("/key"))
print(cli.get("/healthy"))
assert cli.get("/healthy") == "ok"
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http.server import HTTPServer
import http.server as SimpleHTTPServer
from multiprocessing import Process
import threading
import json
class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
def do_GET(self):
with self.server.kv_lock:
ret = {}
for k, v in self.server.kv.items():
if k.startswith(self.path):
ret[k] = v.decode(encoding="utf-8")
if ret:
self.output(200, json.dumps(ret).encode("utf-8"))
else:
self.output(404)
def do_PUT(self):
self.do_POST()
def do_POST(self):
content_length = int(self.headers['Content-Length'] or 0)
try:
value = self.rfile.read(content_length)
with self.server.kv_lock:
self.server.kv[self.path] = value
self.output(200)
return
except:
self.output(500)
def do_DELETE(self):
with self.server.kv_lock:
if self.path in self.server.kv:
del self.server.kv[self.path]
self.output(200)
else:
self.output(404)
def output(self, code, value=''):
self.send_response(code)
self.send_header("Content-Length", len(value))
self.send_header("Content-Type", "application/json; charset=utf8")
self.end_headers()
if value:
self.wfile.write(value)
def log_message(self, format, *args):
return
class KVServer(HTTPServer, object):
def __init__(self, port):
super(KVServer, self).__init__(('', port), KVHandler)
self.kv_lock = threading.Lock()
self.kv = {'/healthy': b'ok'}
self.port = port
self.stopped = False
self.started = False
def start(self):
self.listen_thread = threading.Thread(target=self.serve_forever)
self.listen_thread.start()
self.started = True
def stop(self):
self.shutdown()
self.listen_thread.join()
self.server_close()
self.stopped = True
class PKVServer():
def __init__(self, port):
self._server = KVServer(port)
def start(self):
self.proc = Process(target=self._server.start)
self.proc.daemon = True
self.proc.start()
def stop(self):
self._server.stop()
self.proc.join()
@property
def started(self):
return self._server.started
@property
def stopped(self):
return self._server.stopped
if __name__ == '__main__':
#kv = PKVServer(8090)
kv = KVServer(8090)
kv.start()
import time
#print("serve at 8090 for 600 s")
time.sleep(600)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import os, sys, signal, time
class ProcessContext(object):
def __init__(self,
cmd,
env=os.environ,
out=sys.stdout,
err=sys.stderr,
group=True,
preexec_fn=None):
self._cmd = cmd
self._env = env
self._preexec_fn = preexec_fn
self._stdout = out
self._stderr = err
self._group = group if os.name != 'nt' else False
self._proc = None
self._code = None
def _start(self):
pre_fn = os.setsid if self._group else None
self._proc = subprocess.Popen(
self._cmd,
env=self._env,
stdout=self._stdout,
stderr=self._stderr,
preexec_fn=self._preexec_fn or pre_fn)
def _close_std(self):
try:
if not self._stdout.isatty():
self._stdout.close()
if not self._stderr.isatty():
self._stderr.close()
except:
pass
def alive(self):
return self._proc and self._proc.poll() is None
def exit_code(self):
return self._proc.poll() if self._proc else None
def start(self):
self._start()
def terminate(self, force=False, max_retry=3):
for i in range(max_retry):
if self.alive():
if self._group:
os.killpg(os.getpgid(self._proc.pid), signal.SIGTERM)
else:
self._proc.terminate()
time.sleep(0.2)
else:
break
if force and self.alive():
self._proc.kill()
self._close_std()
return self.alive()
def wait(self, timeout=None):
self._proc.wait(timeout)
......@@ -949,6 +949,7 @@ if (WITH_DISTRIBUTE AND NOT APPLE)
endif()
# setting timeout value as 15S
set_tests_properties(test_run PROPERTIES TIMEOUT 200)
set_tests_properties(test_sync_batch_norm_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_lod_tensor_to_selected_rows PROPERTIES TIMEOUT 200)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import subprocess
import sys, os
import json
import shutil
import random
from os import listdir
from os.path import isfile, join
pyname = 'train.py'
colpyfile = '''# train.py for unitest
import os
env = os.environ.copy()
assert "PADDLE_MASTER" in env
assert "PADDLE_GLOBAL_SIZE" in env
assert "PADDLE_LOCAL_SIZE" in env
assert "PADDLE_GLOBAL_RANK" in env
assert "PADDLE_LOCAL_RANK" in env
'''
pspyfile = '''# train.py for unitest
import os
env = os.environ.copy()
assert "PADDLE_PSERVERS_IP_PORT_LIST" in env
assert "PADDLE_TRAINER_ENDPOINTS" in env
#assert "PADDLE_PSERVER_ENDPOINTS" in env
#assert "PADDLE_TRAINER_ENDPOINTS" in env
#assert "PADDLE_ROLE" in env
#assert "PADDLE_RANK" in env
'''
def write_file(name, ct):
with open(name, "w") as f:
f.write(ct)
def get_files(pth, prefix):
return [
f for f in listdir(pth) if isfile(join(pth, f)) and f.startswith(prefix)
]
class Collective_Test(unittest.TestCase):
def setUp(self):
write_file(pyname, colpyfile)
def pdrun(self, args, env=None):
cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.run"]
if args:
cmd.extend(args.split(" "))
cmd.extend([pyname])
proc = subprocess.Popen(cmd, env)
return proc
'''
def test_collective_1(self):
args = "--id test1"
p = self.pdrun(args)
p.wait()
self.assertTrue(p.poll() == 0)
'''
def test_collective_2(self):
if os.path.exists('./log'):
shutil.rmtree('./log')
args = "--id test2 --devices 0,1,2"
p = self.pdrun(args)
p.wait()
self.assertTrue(p.poll() == 0)
c = get_files('log', 'test2')
self.assertTrue(len(c) == 4)
def test_collective_3(self):
if os.path.exists('./log'):
shutil.rmtree('./log')
port = random.randrange(6000, 8000)
args = "--id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format(
port)
p1 = self.pdrun(args)
p2 = self.pdrun(args)
p1.wait()
p2.wait()
self.assertTrue(p1.poll() == 0)
self.assertTrue(p2.poll() == 0)
c = get_files('log', 'test3')
self.assertTrue(len(c) == 6)
class PS_Test(unittest.TestCase):
def setUp(self):
write_file(pyname, pspyfile)
def pdrun(self, args, env=None):
cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.run"]
if args:
cmd.extend(args.split(" "))
cmd.extend([pyname])
proc = subprocess.Popen(cmd, env)
return proc
'''
def test_ps_1(self):
args = "--mode ps"
p = self.pdrun(args)
p.wait()
self.assertTrue(p.poll() == 0)
def test_ps_2(self):
if os.path.exists('./log'):
shutil.rmtree('./log')
args = "--id ps2 --server_num=2 --trainer_num=2"
p = self.pdrun(args)
p.wait()
self.assertTrue(p.poll() == 0)
c = get_files('log', 'ps2')
self.assertTrue(len(c) == 5)
'''
def test_ps_3(self):
if os.path.exists('./log'):
shutil.rmtree('./log')
port = random.randrange(6000, 8000)
args = "--id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format(
port)
p1 = self.pdrun(args)
p2 = self.pdrun(args)
p1.wait()
p2.wait()
self.assertTrue(p1.poll() == 0)
self.assertTrue(p2.poll() == 0)
c = get_files('log', 'ps3')
self.assertTrue(len(c) == 6)
def test_ps_4(self):
if os.path.exists('./log'):
shutil.rmtree('./log')
args = "--id ps4 --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903"
p1 = self.pdrun(args)
p1.wait()
self.assertTrue(p1.poll() == 0)
c = get_files('log', 'ps4')
self.assertTrue(len(c) == 5)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册