未验证 提交 c54c60de 编写于 作者: K kuizhiqing 提交者: GitHub

fleetrun launch in legacy mode (#40568)

上级 49f1ab2a
...@@ -12,74 +12,69 @@ ...@@ -12,74 +12,69 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .job.container import Container __all__ = []
from .job.pod import Pod
from .job.job import Job
from . import plugins
#__all__ = [Container, Pod, Job]
''' '''
Paddle distribution training entry ``python -m paddle.distributed.run``. Paddle distributed training entry ``python -m paddle.distributed.launch``.
Help Help
# for arg usage and explanation, try the following command # for arg usage and explanation, try the following command
# python -m paddle.distributed.run -h # python -m paddle.distributed.launch -h
Collective Mode Collective Mode
Case 1: 1 node Case 1: 1 node
use all visible devices use all visible devices
# python -m paddle.distributed.run train.py # python -m paddle.distributed.launch train.py
use specified devices use specified devices
# python -m paddle.distributed.run --devices=0,1,2,3 train.py # python -m paddle.distributed.launch --devices=0,1,2,3 train.py
Case 2: multi-node, auto detect ip/port Case 2: multi-node, auto detect ip/port
# python -m paddle.distributed.run --np 2 train.py # python -m paddle.distributed.launch --nnodes 2 train.py
# auto print following command # auto print following command
# python -m paddle.distributed.run --master 10.0.0.1:13538 --np 2 demo.py # python -m paddle.distributed.launch --master 10.0.0.1:13538 --nnodes 2 demo.py
# then copy and paste above command to other nodes # then copy and paste above command to other nodes
Case 3: multi-node, specified master/rendezvous server Case 3: multi-node, specified master/rendezvous server
# python -m paddle.distributed.run --np 2 --master 10.0.0.1:2379 train.py # python -m paddle.distributed.launch --nnodes 2 --master 10.0.0.1:2379 train.py
# the master ip must be one of the node and the port must available # the master ip must be one of the node and the port must available
Parameter Server Mode Parameter Server Mode
Case 1.1: 1 node, 1 ps, 1 trainer Case 1.1: 1 node, 1 ps, 1 trainer
# python -m paddle.distributed.run --mode ps train.py # python -m paddle.distributed.launch --mode ps train.py
# python -m paddle.distributed.run --server_num=1 --trainer_num=1 train.py # python -m paddle.distributed.launch --server_num=1 --trainer_num=1 train.py
Case 1.2: 1 node, 2 ps, 2 trainer Case 1.2: 1 node, 2 ps, 2 trainer
# python -m paddle.distributed.run --server_num=2 --trainer_num=2 train.py # python -m paddle.distributed.launch --server_num=2 --trainer_num=2 train.py
Case 2: 2 node, 2 ps, 2 trainer per node Case 2: 2 node, 2 ps, 2 trainer per node
# python -m paddle.distributed.run --server_num=2 --trainer_num=2 --np 2 train.py # python -m paddle.distributed.launch --server_num=2 --trainer_num=2 --nnodes 2 train.py
# auto print following command # auto print following command
# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py # python -m paddle.distributed.launch --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --nnodes 2 train.py
# then copy and paste above command to other nodes # then copy and paste above command to other nodes
Case 3: multi-node, specified master/rendezvous server Case 3: multi-node, specified master/rendezvous server
# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py # python -m paddle.distributed.launch --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --nnodes 2 train.py
# the master ip must be one of the node and the port must available # the master ip must be one of the node and the port must available
Case 4: specified servers and trainers in each node Case 4: specified servers and trainers in each node
python -m paddle.distributed.run --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903 train.py python -m paddle.distributed.launch --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903 train.py
Elastic Mode Elastic Mode
# run following command in 3 node to run immediately, or in 2 node to run after elastic_timeout # run following command in 3 node to run immediately, or in 2 node to run after elastic_timeout
# python -m paddle.distributed.run --master etcd://10.0.0.1:2379 --np 2:3 train.py # python -m paddle.distributed.launch --master etcd://10.0.0.1:2379 --nnodes 2:3 train.py
# once the peer number changes between 2:3, the strategy holds # once the peer number changes between 2:3, the strategy holds
......
...@@ -15,14 +15,28 @@ ...@@ -15,14 +15,28 @@
from .context import Context from .context import Context
from . import controllers from . import controllers
# initialize the context to run
ctx = Context()
# initialize the selected controller def launch():
c = controllers.init(ctx) # initialize the context to run
ctx = Context()
# run the pods if ctx.is_legacy_mode():
c.run()
# manager or just wait pod # legacy mode
c.finalize() from paddle.distributed.fleet import launch
launch.launch()
else:
# initialize the selected controller
c = controllers.init(ctx)
# run the pods
c.run()
# manager or just wait pod
c.finalize()
if __name__ == "__main__":
launch()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.launch import plugins
from .node import Node
from .status import Status
from .args_envs import parse_args, fetch_envs, env_args_mapping
import logging
class Context(object):
def __init__(self, enable_plugin=True):
self.args, self.unknown_args = parse_args()
self.envs = fetch_envs()
self.logger = self.get_logger()
self.node = Node()
self.status = Status()
self.set_env_in_args()
# design for event queue, later
self.events = []
if enable_plugin:
self._enable_plugin()
def is_legacy_mode(self):
if self.args.legacy:
return True
if len(self.unknown_args) > 0:
self.logger.warning("Compatible mode enable with args {}".format(
self.unknown_args))
return True
legacy_env_list = [
'DISTRIBUTED_TRAINER_ENDPOINTS',
'PADDLE_ELASTIC_JOB_ID',
'PADDLE_DISTRI_BACKEND',
'FLAGS_START_PORT',
]
for env in legacy_env_list:
if env in self.envs:
self.logger.warning(
"ENV {} is deprecated, legacy launch enable".format(env))
return True
if self.args.master:
return False
return False
def get_envs(self):
return self.envs.copy()
def _enable_plugin(self):
for pl in plugins.enabled_plugins:
pl(self)
def get_logger(self, level=logging.INFO):
logger = logging.getLogger("LAUNCH")
logger.setLevel(self.args.log_level.upper() or level)
formatter = logging.Formatter(
fmt='%(name)s %(levelname)s %(asctime)s %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
return logger
def set_env_in_args(self):
for k, v in env_args_mapping.items():
if k in self.envs:
setattr(self.args, v, self.envs[k])
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from argparse import ArgumentParser, REMAINDER
env_args_mapping = {
'POD_IP': 'host',
'PADDLE_MASTER': 'master',
'PADDLE_DEVICES': 'devices',
'PADDLE_NNODES': 'nnodes',
'PADDLE_MODE': 'mode',
'PADDLE_LOG_LEVEL': 'log_level',
'PADDLE_NPROC_PER_NODE': 'nproc_per_node',
'PADDLE_JOB_ID': 'job_id',
'PADDLE_RANK': 'rank',
'PADDLE_LOG_DIR': 'log_dir',
'PADDLE_MAX_RESTART': 'max_restart',
'PADDLE_ELASTIC_LEVEL': 'elastic_level',
'PADDLE_ELASTIC_TIMEOUT': 'elastic_timeout',
'PADDLE_SERVER_NUM': 'server_num',
'PADDLE_TRAINER_NUM': 'trainer_num',
'PADDLE_SERVERS_ENDPOINTS': 'servers',
'PADDLE_TRAINERS_ENDPOINTS': 'trainers',
'PADDLE_GLOO_PORT': 'gloo_port',
'PADDLE_WITH_GLOO': 'with_gloo',
}
def fetch_envs():
os.environ.pop('http_proxy', None)
os.environ.pop('https_proxy', None)
return os.environ.copy()
def parse_args():
parser = ArgumentParser()
base_group = parser.add_argument_group("Base Parameters")
base_group.add_argument(
"--master",
type=str,
default=None,
help="the master/rendezvous server, ip:port")
base_group.add_argument(
"--legacy", type=bool, default=False, help="use legacy launch")
base_group.add_argument(
"--rank", type=int, default=-1, help="the peer rank")
base_group.add_argument(
"--log_level", type=str, default="INFO", help="log level. Default INFO")
base_group.add_argument(
"--nnodes",
type=str,
default="1",
help="the number of peers, i.e. pod/node number")
base_group.add_argument(
"--nproc_per_node",
type=int,
default=None,
help="the number of processes in a pod")
base_group.add_argument(
"--log_dir",
type=str,
default="log",
help="the path for each process's log. Default ./log")
base_group.add_argument(
"--mode",
type=str,
default="collective",
help="run mode of the job, collective/ps/ps-heter")
base_group.add_argument(
"--job_id",
type=str,
default="default",
help="unique id of the job. Default default")
base_group.add_argument(
"--devices",
type=str,
default=None,
help="accelerate devices. as --gpus,npus,xps")
base_group.add_argument("--host", type=str, default=None, help="host ip")
base_group.add_argument(
"training_script",
type=str,
help="the full path of py script,"
"followed by arguments for the "
"training script")
base_group.add_argument('training_script_args', nargs=REMAINDER)
ps_group = parser.add_argument_group("Parameter-Server Parameters")
# for parameter server
ps_group.add_argument(
"--servers", type=str, default='', help="servers endpoints full list")
ps_group.add_argument(
"--trainers", type=str, default='', help="trainers endpoints full list")
ps_group.add_argument(
"--trainer_num", type=int, default=None, help="number of trainers")
ps_group.add_argument(
"--server_num", type=int, default=None, help="number of servers")
ps_group.add_argument(
"--gloo_port", type=int, default=6767, help="gloo http port")
ps_group.add_argument(
"--with_gloo", type=str, default="0", help="use gloo or not")
# parameter elastic mode
elastic_group = parser.add_argument_group("Elastic Parameters")
elastic_group.add_argument(
"--max_restart",
type=int,
default=3,
help="the times can restart. Default 3")
elastic_group.add_argument(
"--elastic_level",
type=int,
default=-1,
help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart"
)
elastic_group.add_argument(
"--elastic_timeout",
type=int,
default=30,
help="seconds to wait before elastic perform training")
return parser.parse_known_args()
...@@ -20,36 +20,90 @@ class DeviceType: ...@@ -20,36 +20,90 @@ class DeviceType:
GPU = 'gpu' GPU = 'gpu'
XPU = 'xpu' XPU = 'xpu'
NPU = 'npu' NPU = 'npu'
MLU = 'mlu'
class Device(object): class Device(object):
def __init__(self, dtype=None, count=1, memory="", labels=""): def __init__(self, dtype=None, memory="", labels=""):
self.dtype = dtype self._dtype = dtype
self.count = count self._memory = memory
self.memory = memory self._labels = labels
self.labels = labels
def __str__(self): def __str__(self):
return ",".join(self.labels) return ",".join(self._labels)
@property
def dtype(self):
return self._dtype
@property
def count(self):
return len(self._labels) or 1
@property
def memory(self):
return self._memory
@property
def labels(self):
return self._labels
@labels.setter
def labels(self, lbs):
if isinstance(lbs, str):
self._labels = lbs.split(',')
elif isinstance(lbs, list):
self._labels = lbs
else:
self._labels = []
def get_selected_flag_key(self):
if self._dtype == DeviceType.CPU:
return 'FLAGS_selected_cpus'
if self._dtype == DeviceType.GPU:
return 'FLAGS_selected_gpus'
if self._dtype == DeviceType.NPU:
return 'FLAGS_selected_npus'
if self._dtype == DeviceType.XPU:
return 'FLAGS_selected_xpus'
if self._dtype == DeviceType.MLU:
return 'FLAGS_selected_mlus'
return 'FLAGS_selected_devices'
def get_selected_flag_label(self, idx):
if idx < len(self._labels):
return self._labels[idx]
else:
return '0'
def selected_flags(self, idx=None):
if idx is None:
return {self.get_selected_flag_key(): ','.join(self._labels)}
else:
return {
self.get_selected_flag_key(): self.get_selected_flag_label(idx)
}
@classmethod @classmethod
def parse_device(self): def parse_device(self):
dev = Device() dev = Device()
visible_devices = None visible_devices = None
if 'CUDA_VISIBLE_DEVICES' in os.environ or 'NVIDIA_VISIBLE_DEVICES' in os.environ: if 'CUDA_VISIBLE_DEVICES' in os.environ or 'NVIDIA_VISIBLE_DEVICES' in os.environ:
dev.dtype = DeviceType.GPU dev._dtype = DeviceType.GPU
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv( visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv(
"NVIDIA_VISIBLE_DEVICES") "NVIDIA_VISIBLE_DEVICES")
elif 'XPU_VISIBLE_DEVICES' in os.environ: elif 'XPU_VISIBLE_DEVICES' in os.environ:
dev.dtype = DeviceType.XPU dev._dtype = DeviceType.XPU
visible_devices = os.getenv("XPU_VISIBLE_DEVICES") visible_devices = os.getenv("XPU_VISIBLE_DEVICES")
elif 'ASCEND_VISIBLE_DEVICES' in os.environ: elif 'ASCEND_VISIBLE_DEVICES' in os.environ:
dev.dtype = DeviceType.NPU dev._dtype = DeviceType.NPU
visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES")
elif 'MLU_VISIBLE_DEVICES' in os.environ:
dev._dtype = DeviceType.MLU
visible_devices = os.getenv("MLU_VISIBLE_DEVICES")
if visible_devices and visible_devices != 'all': if visible_devices is not None and visible_devices != 'all':
dev.labels = visible_devices.split(',') dev._labels = visible_devices.split(',')
dev.count = len(dev.labels)
else: else:
return self.detect_device() return self.detect_device()
...@@ -63,26 +117,33 @@ class Device(object): ...@@ -63,26 +117,33 @@ class Device(object):
num = 0 num = 0
visible_devices = None visible_devices = None
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
dev.dtype = DeviceType.GPU dev._dtype = DeviceType.GPU
num = fluid.core.get_cuda_device_count() num = fluid.core.get_cuda_device_count()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv( visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv(
"NVIDIA_VISIBLE_DEVICES") "NVIDIA_VISIBLE_DEVICES")
elif fluid.core.is_compiled_with_xpu(): elif fluid.core.is_compiled_with_xpu():
dev.dtype = DeviceType.XPU dev._dtype = DeviceType.XPU
num = fluid.core.get_xpu_device_count() num = fluid.core.get_xpu_device_count()
visible_devices = os.getenv("XPU_VISIBLE_DEVICES") visible_devices = os.getenv("XPU_VISIBLE_DEVICES")
elif fluid.core.is_compiled_with_npu(): elif fluid.core.is_compiled_with_npu():
dev.dtype = DeviceType.NPU dev._dtype = DeviceType.NPU
num = fluid.core.get_npu_device_count() num = fluid.core.get_npu_device_count()
visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES")
elif fluid.core.is_compiled_with_mlu():
dev._dtype = DeviceType.MLU
num = fluid.core.get_mlu_device_count()
visible_devices = os.getenv("MLU_VISIBLE_DEVICES")
if num == 0: if num == 0:
dev.dtype = DeviceType.CPU dev._dtype = DeviceType.CPU
elif visible_devices is None or visible_devices == "all" or visible_devices == "": elif visible_devices is None or visible_devices == "all":
dev.labels = [str(x) for x in range(0, num)] dev._labels = [str(x) for x in range(0, num)]
dev.count = num
else: else:
dev.labels = visible_devices.split(',') dev._labels = visible_devices.split(',')
dev.count = len(dev.labels)
return dev return dev
if __name__ == '__main__':
d = Device.parse_device()
print(d.get_selected_flag())
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
__all__ = ["init"] __all__ = []
from .collective import CollectiveController from .collective import CollectiveController
from .collective import CollectiveElasticController from .collective import CollectiveElasticController
......
...@@ -89,6 +89,10 @@ class CollectiveController(Controller): ...@@ -89,6 +89,10 @@ class CollectiveController(Controller):
"PADDLE_TRAINERS_NUM": "{}".format(global_size), "PADDLE_TRAINERS_NUM": "{}".format(global_size),
"PADDLE_RANK_IN_NODE": str(i), "PADDLE_RANK_IN_NODE": str(i),
} }
if self.pod.replicas == 1:
e.update(self.ctx.node.device.selected_flags())
else:
e.update(self.ctx.node.device.selected_flags(i))
self.add_container(envs=e, log_tag=i) self.add_container(envs=e, log_tag=i)
return True return True
...@@ -106,7 +110,8 @@ class CollectiveElasticController(CollectiveController): ...@@ -106,7 +110,8 @@ class CollectiveElasticController(CollectiveController):
def register(self): def register(self):
if self.job.id == 'default': if self.job.id == 'default':
self.ctx.logger.warning( self.ctx.logger.warning(
'Using default job name may cause conflict, add --id in args') 'Using default job name may cause conflict, add --job_id in args'
)
self.master.register_heartbeat(self.job.id, self.pod.name) self.master.register_heartbeat(self.job.id, self.pod.name)
...@@ -114,6 +119,8 @@ class CollectiveElasticController(CollectiveController): ...@@ -114,6 +119,8 @@ class CollectiveElasticController(CollectiveController):
''' '''
watch self and peer status, return true to exit watch self and peer status, return true to exit
''' '''
self.ctx.logger.info("Watching {}".format(self.pod))
while not self.ctx.status.is_done(): while not self.ctx.status.is_done():
# self status # self status
status = self.pod.watch(timeout=2) status = self.pod.watch(timeout=2)
...@@ -171,13 +178,8 @@ class CollectiveElasticController(CollectiveController): ...@@ -171,13 +178,8 @@ class CollectiveElasticController(CollectiveController):
continue continue
self.master.set_status(self.ctx.status.RUNNING) self.master.set_status(self.ctx.status.RUNNING)
self.ctx.status.run()
assert len(self.pod.containers) > 0, "No container in the pod"
self.ctx.logger.debug("Run {}".format(self.pod))
self.ctx.logger.debug("Run {}".format(self.pod.containers[0]))
self.pod.deploy() self.deploy_pod()
if self.watch(): if self.watch():
break break
......
...@@ -16,9 +16,9 @@ import sys ...@@ -16,9 +16,9 @@ import sys
import os import os
import signal import signal
from paddle.distributed.run.job import Job from paddle.distributed.launch.job.job import Job
from paddle.distributed.run.job import Pod from paddle.distributed.launch.job.pod import Pod
from paddle.distributed.run.job import Container from paddle.distributed.launch.job.container import Container
from .master import Master from .master import Master
...@@ -39,38 +39,43 @@ class ControllerBase(object): ...@@ -39,38 +39,43 @@ class ControllerBase(object):
self.ctx = ctx self.ctx = ctx
self.master = Master.factory(self.ctx) self.master = Master.factory(self.ctx)
self.job = Job(np=self.ctx.args.np, self.job = Job(nnodes=self.ctx.args.nnodes,
mode=self.ctx.args.mode, mode=self.ctx.args.mode,
id=self.ctx.args.id) jid=self.ctx.args.job_id)
self.pod = Pod() self.pod = Pod()
self.join_server = None self.join_server = None
def run(self): def deploy_pod(self):
self.build_job()
self.build_pod()
if len(self.pod.containers) < 1: assert len(self.pod.containers) > 0, "No container in the pod"
self.ctx.logger.error("No container in the pod {}".format(self.pod))
return
self.ctx.logger.info("Run {}".format(self.pod)) self.ctx.logger.info("Run {}".format(self.pod))
self.ctx.logger.debug(self.pod.containers[0]) self.ctx.logger.debug(self.pod.containers[0])
self.ctx.status.run()
self.pod.deploy() self.pod.deploy()
def run(self):
self.build_job()
self.build_pod()
self.deploy_pod()
self.watch() self.watch()
def watch(self) -> bool: def watch(self) -> bool:
self.ctx.logger.info("Watching {}".format(self.pod))
status = self.pod.watch() status = self.pod.watch()
if status == self.ctx.status.COMPLETED: if status == self.ctx.status.COMPLETED:
self.ctx.logger.info("Pod {}".format(status)) self.ctx.logger.info("Pod {}".format(status))
elif status == self.ctx.status.FAILED: elif status == self.ctx.status.FAILED:
fc = self.pod.failed_container()
self.ctx.logger.info("Pod {}".format(status)) self.ctx.logger.info("Pod {}".format(status))
self.ctx.logger.error("Container failed !!!\n{}".format( self.ctx.logger.error("Container failed !!!\n{}".format(fc[0]))
self.pod.failed_container())) fc[0].tail()
self.pod.tail()
self.pod.stop() self.pod.stop()
def stop(self, sigint=None): def stop(self, sigint=None):
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.distributed.run.utils.kv_client import KVClient from paddle.distributed.launch.utils.kv_client import KVClient
from paddle.distributed.run.utils.kv_server import KVServer from paddle.distributed.launch.utils.kv_server import KVServer
import time import time
import sys import sys
...@@ -84,7 +84,7 @@ class HTTPMaster(Master): ...@@ -84,7 +84,7 @@ class HTTPMaster(Master):
print("Copy the following command to other nodes to run.") print("Copy the following command to other nodes to run.")
cmd = [ cmd = [
sys.executable.split('/')[-1], "-m", "paddle.distributed.run" sys.executable.split('/')[-1], "-m", "paddle.distributed.launch"
] ]
cmd.extend(["--master", self.endpoint]) cmd.extend(["--master", self.endpoint])
cmd.extend(sys.argv[1:]) cmd.extend(sys.argv[1:])
...@@ -118,9 +118,12 @@ class HTTPMaster(Master): ...@@ -118,9 +118,12 @@ class HTTPMaster(Master):
self._stop_server() self._stop_server()
def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int):
if size < 2: if size < 2:
return [value], 0 return [value], 0
self.ctx.logger.info("Waiting peer ready...")
self.lazy_init() self.lazy_init()
while not self.ctx.status.is_done(): while not self.ctx.status.is_done():
...@@ -130,7 +133,7 @@ class HTTPMaster(Master): ...@@ -130,7 +133,7 @@ class HTTPMaster(Master):
self.ctx.logger.warning("master not ready") self.ctx.logger.warning("master not ready")
time.sleep(0.1) time.sleep(0.1)
# 'aaaaaa' make suer main pod (master server) as rank 0 # 'aaaaaa' make sure main pod (master server) as rank 0
ky = 'aaaaaa' if rank < 0 and self.role == Master.MAIN else key ky = 'aaaaaa' if rank < 0 and self.role == Master.MAIN else key
k = "{}/{}/{}".format(prefix, ky, rank) k = "{}/{}/{}".format(prefix, ky, rank)
...@@ -177,6 +180,12 @@ class ETCDMaster(Master): ...@@ -177,6 +180,12 @@ class ETCDMaster(Master):
sync_peers gather all value for key under scope prefix sync_peers gather all value for key under scope prefix
result always be sorted either by rank or alphabet of pod.name result always be sorted either by rank or alphabet of pod.name
''' '''
if size < 2:
return [value], 0
self.ctx.logger.info("Waiting peer ready...")
path = "{}/{}/{}".format(prefix, key, rank) path = "{}/{}/{}".format(prefix, key, rank)
self.client.delete_prefix(prefix) self.client.delete_prefix(prefix)
......
...@@ -22,7 +22,8 @@ class PSController(Controller): ...@@ -22,7 +22,8 @@ class PSController(Controller):
@classmethod @classmethod
def enable(cls, ctx): def enable(cls, ctx):
if ctx.args.mode == ControleMode.PS or ctx.args.server_num or len( if ctx.args.mode == ControleMode.PS or ctx.args.server_num or len(
ctx.args.servers) > 0: ctx.args.servers) > 0 or ctx.args.trainer_num or len(
ctx.args.trainers) > 0:
ctx.logger.debug("{} enabled".format(cls.__name__)) ctx.logger.debug("{} enabled".format(cls.__name__))
ctx.args.mode = ControleMode.PS ctx.args.mode = ControleMode.PS
return True return True
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.distributed.fleet import launch
launch.launch()
__all__ = []
...@@ -13,12 +13,11 @@ ...@@ -13,12 +13,11 @@
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from paddle.distributed.run.utils.process_context import ProcessContext from paddle.distributed.launch.utils.process_context import ProcessContext
from .status import Status from .status import Status
import os, copy, sys import os, copy, sys
import time
class Container(object): class Container(object):
...@@ -78,6 +77,11 @@ class Container(object): ...@@ -78,6 +77,11 @@ class Container(object):
kwargs = {k: v for k, v in kwargs.items() if isinstance(v, str)} kwargs = {k: v for k, v in kwargs.items() if isinstance(v, str)}
self._env.update(kwargs) self._env.update(kwargs)
def _valide_env(self):
for k, v in self._env.items():
assert isinstance(k, str) and isinstance(
v, str), 'env {}:{} must be str'.format(k, v)
def _get_fd(self, pth): def _get_fd(self, pth):
if not pth: if not pth:
return None return None
...@@ -90,12 +94,12 @@ class Container(object): ...@@ -90,12 +94,12 @@ class Container(object):
except: except:
return None return None
def start(self, timeout=-1): def start(self):
end = time.time() + timeout
if self._proc and self._proc.alive(): if self._proc and self._proc.alive():
return True return True
self._valide_env()
self._stdout = self._get_fd(self._out) or sys.stdout self._stdout = self._get_fd(self._out) or sys.stdout
if self._out == self._err: if self._out == self._err:
self._stderr = self._stdout self._stderr = self._stdout
...@@ -106,14 +110,6 @@ class Container(object): ...@@ -106,14 +110,6 @@ class Container(object):
self._entrypoint, env=self._env, out=self._stdout, err=self._stderr) self._entrypoint, env=self._env, out=self._stdout, err=self._stderr)
self._proc.start() self._proc.start()
while timeout > 0 and time.time() < end:
if self._proc.alive():
time.sleep(0.1)
continue
if self._proc.exit_code() == 0:
return True
return False
def terminate(self, force=False): def terminate(self, force=False):
if self._log_handler: if self._log_handler:
self._log_handler.close() self._log_handler.close()
...@@ -125,9 +121,11 @@ class Container(object): ...@@ -125,9 +121,11 @@ class Container(object):
def wait(self, timeout=None): def wait(self, timeout=None):
self._proc.wait(timeout) self._proc.wait(timeout)
@property
def exit_code(self): def exit_code(self):
return self._proc.exit_code() if self._proc else -1 return self._proc.exit_code() if self._proc else -1
@property
def status(self): def status(self):
if not self._proc: if not self._proc:
return Status.UNINIT return Status.UNINIT
...@@ -141,9 +139,9 @@ class Container(object): ...@@ -141,9 +139,9 @@ class Container(object):
def __str__(self): def __str__(self):
return 'Container rank {} status {} cmd {} code {} log {} \nenv {}'.format( return 'Container rank {} status {} cmd {} code {} log {} \nenv {}'.format(
self._rank, self._rank,
self.status(), self.status,
self._entrypoint, self._entrypoint,
self.exit_code(), self.exit_code,
self.errfile, self.errfile,
self._env, ) self._env, )
......
...@@ -20,16 +20,16 @@ class JobMode: ...@@ -20,16 +20,16 @@ class JobMode:
class Job(object): class Job(object):
def __init__(self, id='default', mode=JobMode.COLLECTIVE, np="1"): def __init__(self, jid='default', mode=JobMode.COLLECTIVE, nnodes="1"):
self._mode = mode self._mode = mode
self._id = id self._id = jid
self._replicas = 0 self._replicas = 0
self._replicas_min = self._replicas self._replicas_min = self._replicas
self._replicas_max = self._replicas self._replicas_max = self._replicas
self._elastic = False self._elastic = False
self.set_replicas(str(np)) self.set_replicas(str(nnodes))
def __str__(self): def __str__(self):
return "Job: {}, mode {}, replicas {}[{}:{}], elastic {}".format( return "Job: {}, mode {}, replicas {}[{}:{}], elastic {}".format(
...@@ -64,8 +64,8 @@ class Job(object): ...@@ -64,8 +64,8 @@ class Job(object):
def replicas(self, replicas): def replicas(self, replicas):
self._replicas = replicas self._replicas = replicas
def set_replicas(self, np: str): def set_replicas(self, nnodes: str):
np = str(np) if np else '1' np = str(nnodes) if nnodes else '1'
if ':' in np: if ':' in np:
nps = np.split(':') nps = np.split(':')
......
...@@ -34,7 +34,7 @@ class PodSepc(object): ...@@ -34,7 +34,7 @@ class PodSepc(object):
#self.status: Status = None #self.status: Status = None
self._rank = -1 self._rank = -1
self._init_timeout = 120 # 2 min timeout for each init container self._init_timeout = None
self._restart = -1 self._restart = -1
self._replicas = 0 # number of containers self._replicas = 0 # number of containers
self._exit_code = 0 self._exit_code = 0
...@@ -45,15 +45,15 @@ class Pod(PodSepc): ...@@ -45,15 +45,15 @@ class Pod(PodSepc):
super().__init__() super().__init__()
def __str__(self): def __str__(self):
return "Pod: {}, replicas {}, status {}".format(self.name, return "Pod: {}, replicas {}, status {}".format(
self.replicas, self.name, self.replicas, self.status)
self.status())
def failed_container(self): def failed_container(self):
cs = []
for c in self._containers: for c in self._containers:
if c.status() == Status.FAILED: if c.status == Status.FAILED:
return c cs.append(c)
return None return cs
@property @property
def name(self): def name(self):
...@@ -65,7 +65,7 @@ class Pod(PodSepc): ...@@ -65,7 +65,7 @@ class Pod(PodSepc):
@replicas.setter @replicas.setter
def replicas(self, r): def replicas(self, r):
self._replicas = r self._replicas = max(r, 1)
@property @property
def rank(self): def rank(self):
...@@ -98,13 +98,15 @@ class Pod(PodSepc): ...@@ -98,13 +98,15 @@ class Pod(PodSepc):
@property @property
def exit_code(self): def exit_code(self):
for c in self._containers: for c in self._containers:
if c.exit_code() != 0: if c.exit_code != 0:
return c.exit_code() return c.exit_code
return 0 return 0
def deploy(self): def deploy(self):
# init container should stop before run containers
for i in self._init_containers: for i in self._init_containers:
i.start(self._init_timeout) i.start()
i.wait(self._init_timeout)
for c in self._containers: for c in self._containers:
c.start() c.start()
...@@ -120,6 +122,7 @@ class Pod(PodSepc): ...@@ -120,6 +122,7 @@ class Pod(PodSepc):
for c in self._containers: for c in self._containers:
c.wait(None) c.wait(None)
@property
def status(self): def status(self):
if self.is_failed(): if self.is_failed():
return Status.FAILED return Status.FAILED
...@@ -127,6 +130,9 @@ class Pod(PodSepc): ...@@ -127,6 +130,9 @@ class Pod(PodSepc):
if self.is_completed(): if self.is_completed():
return Status.COMPLETED return Status.COMPLETED
if self.is_running():
return Status.RUNNING
return Status.READY return Status.READY
def reset(self): def reset(self):
...@@ -135,31 +141,31 @@ class Pod(PodSepc): ...@@ -135,31 +141,31 @@ class Pod(PodSepc):
def is_failed(self): def is_failed(self):
for c in self._containers: for c in self._containers:
if c.status() == Status.FAILED: if c.status == Status.FAILED:
return True return True
return False return False
def is_completed(self): def is_completed(self):
for c in self._containers: for c in self._containers:
if c.status() != Status.COMPLETED: if c.status != Status.COMPLETED:
return False
return True
def is_running(self):
for c in self._containers:
if c.status != Status.RUNNING:
return False return False
return True return True
def logs(self, idx=None): def logs(self, idx=None):
if idx is None: if idx is None:
if self.failed_container(): self._containers[0].logs()
self.failed_container().logs()
else:
self._containers[0].logs()
else: else:
self._containers[idx].logs() self._containers[idx].logs()
def tail(self, idx=None): def tail(self, idx=None):
if idx is None: if idx is None:
if self.failed_container(): self._containers[0].tail()
self.failed_container().tail()
else:
self._containers[0].tail()
else: else:
self._containers[idx].tail() self._containers[idx].tail()
...@@ -175,10 +181,10 @@ class Pod(PodSepc): ...@@ -175,10 +181,10 @@ class Pod(PodSepc):
end = time.time() + timeout end = time.time() + timeout
while timeout < 0 or time.time() < end: while timeout < 0 or time.time() < end:
for c in self._containers: for c in self._containers:
if c.status() in any_list: if c.status in any_list:
return c.status() return c.status
s = [c.status() for c in self._containers] s = [c.status for c in self._containers]
if len(set(s)) == 1 and s[0] in all_list: if len(set(s)) == 1 and s[0] in all_list:
return s[0] return s[0]
......
...@@ -30,15 +30,26 @@ def process_args(ctx): ...@@ -30,15 +30,26 @@ def process_args(ctx):
argdev = ctx.args.devices argdev = ctx.args.devices
if argdev: if argdev:
ctx.node.device.labels = argdev.split(',') ctx.node.device.labels = argdev.split(',')
ctx.node.device.count = len(ctx.node.device.labels)
ctx.logger.debug('Device reset by args {}'.format(argdev)) ctx.logger.debug('Device reset by args {}'.format(argdev))
def collective_compatible(ctx): def collective_compatible(ctx):
if 'PADDLE_TRAINER_ENDPOINTS' in ctx.envs: if 'PADDLE_TRAINER_ENDPOINTS' in ctx.envs:
ctx.master = ctx.envs['PADDLE_TRAINER_ENDPOINTS'].split(',')[0] eps = ctx.envs['PADDLE_TRAINER_ENDPOINTS'].split(',')
hosts = set([h.split(':')[0] for h in eps])
ctx.args.master = eps[0] if ':' in eps[0] else '{}:6768'.format(eps[0])
ctx.args.nnodes = len(hosts)
ctx.logger.info('args reset by env PADDLE_TRAINER_ENDPOINTS\n{}'.format(
eps))
'''
if 'DISTRIBUTED_TRAINER_ENDPOINTS' in ctx.envs: if 'DISTRIBUTED_TRAINER_ENDPOINTS' in ctx.envs:
ctx.master = ctx.envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',')[0] eps = ctx.envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',')
hosts = set([h.split(':')[0] for h in eps])
ctx.args.master = eps[0]
ctx.args.nnodes = len(hosts)
ctx.logger.info(
'args reset by env DISTRIBUTED_TRAINER_ENDPOINTS\n{}'.format(eps))
'''
def rewrite_host_ip(ctx): def rewrite_host_ip(ctx):
......
...@@ -11,15 +11,3 @@ ...@@ -11,15 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .pod import Pod
from .job import Job
from .container import Container
from .status import Status
__all__ = [
'Pod',
'Job',
'Container',
'Status',
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser, REMAINDER
import os, copy
from paddle.distributed.run import plugins
from .node import Node
from .status import Status
import logging
class Context(object):
def __init__(self, enable_plugin=True):
os.environ.pop('http_proxy', None)
os.environ.pop('https_proxy', None)
self.args = self.parse_args()
self.envs = self.fetch_envs()
self.logger = self.get_logger()
self.node = Node()
self.status = Status()
self.set_env_in_args()
# design for event queue, later
self.events = []
if enable_plugin:
self._enable_plugin()
def get_envs(self):
return self.envs.copy()
def _enable_plugin(self):
for pl in plugins.enabled_plugins:
pl(self)
def parse_args(self):
parser = ArgumentParser()
base_group = parser.add_argument_group("Base Parameters")
base_group.add_argument(
"--master",
type=str,
default=None,
help="the master/rendezvous server, ip:port")
base_group.add_argument(
"--rank", type=int, default=-1, help="the peer rank")
base_group.add_argument(
"--log", type=str, default="INFO", help="log level. Default INFO")
base_group.add_argument(
"--np",
type=str,
default="1",
help="the number of peers, i.e. pod/node number")
base_group.add_argument(
"--nproc_per_node",
type=int,
default=None,
help="the number of processes in a pod")
base_group.add_argument(
"--log_dir",
type=str,
default="log",
help="the path for each process's log. Default ./log")
base_group.add_argument(
"--mode",
type=str,
default="collective",
help="run mode of the job, collective/ps/ps-heter")
base_group.add_argument(
"--id",
type=str,
default="default",
help="unique id of the job. Default default")
base_group.add_argument(
"--devices",
type=str,
default=None,
help="accelerate devices. as --gpus,npus,xps")
base_group.add_argument(
"--host", type=str, default=None, help="host ip")
base_group.add_argument(
"training_script",
type=str,
help="the full path of py script,"
"followed by arguments for the "
"training script")
base_group.add_argument('training_script_args', nargs=REMAINDER)
ps_group = parser.add_argument_group("Parameter-Server Parameters")
# for parameter server
ps_group.add_argument(
"--servers",
type=str,
default='',
help="servers endpoints full list")
ps_group.add_argument(
"--trainers",
type=str,
default='',
help="trainers endpoints full list")
ps_group.add_argument(
"--trainer_num", type=int, default=None, help="number of trainers")
ps_group.add_argument(
"--server_num", type=int, default=None, help="number of servers")
ps_group.add_argument(
"--gloo_port", type=int, default=6767, help="gloo http port")
ps_group.add_argument(
"--with_gloo", type=str, default="0", help="use gloo or not")
# parameter elastic mode
elastic_group = parser.add_argument_group("Elastic Parameters")
elastic_group.add_argument(
"--max_restart",
type=int,
default=3,
help="the times can restart. Default 3")
elastic_group.add_argument(
"--elastic_level",
type=int,
default=-1,
help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart"
)
elastic_group.add_argument(
"--elastic_timeout",
type=int,
default=30,
help="seconds to wait before elastic perform training")
return parser.parse_args()
def _valide_env(self, key):
if key in ['POD_IP']:
return True
if key.endswith('_VISIBLE_DEVICES'):
return True
if key.startswith('PADDLE_'):
return True
return False
def fetch_envs(self):
ge = os.environ.copy()
black_env_list = ['http_proxy', 'https_proxy']
for key in black_env_list:
ge.pop(key, None)
return ge
'''
# use black list instead white list
return {k: ge[k] for k in ge if self._valide_env(k)}
'''
def get_logger(self, level=logging.INFO):
logger = logging.getLogger("PADDLERUN")
logger.setLevel(self.args.log.upper() or level)
formatter = logging.Formatter(
fmt='%(name)s %(levelname)s %(asctime)s %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
return logger
def set_env_in_args(self):
env_args = {
'POD_IP': 'host',
'PADDLE_MASTER': 'master',
'PADDLE_DEVICES': 'devices',
'PADDLE_NP': 'np',
'PADDLE_MODE': 'mode',
'PADDLE_LOG': 'log',
'PADDLE_NPROC_PER_NODE': 'nproc_per_node',
'PADDLE_JOB_ID': 'id',
'PADDLE_RANK': 'rank',
'PADDLE_LOG_DIR': 'log_dir',
'PADDLE_MAX_RESTlRT': 'max_restart',
'PADDLE_ELASTIC_LEVEL': 'elastic_level',
'PADDLE_ELASTIC_TIMEOUT': 'elastic_timeout',
'PADDLE_SERVER_NUM': 'server_num',
'PADDLE_TRAINER_NUM': 'trainer_num',
'PADDLE_SERVERS_ENDPOINTS': 'servers',
'PADDLE_TRAINERS_ENDPOINTS': 'trainers',
'PADDLE_GLOO_PORT': 'gloo_port',
'PADDLE_WITH_GLOO': 'with_gloo',
}
for k, v in env_args.items():
if k in self.envs:
setattr(self.args, v, self.envs[k])
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
def get_local_ip(ctx):
_, ip = _get_host_name_ip()
ctx.args.host = ip
ctx.envs["POD_IP"] = ip
def _get_host_name_ip():
try:
host_name = socket.gethostname()
host_ip = socket.gethostbyname(host_name)
return host_name, host_ip
except:
return None
...@@ -949,7 +949,7 @@ if (WITH_DISTRIBUTE AND NOT APPLE) ...@@ -949,7 +949,7 @@ if (WITH_DISTRIBUTE AND NOT APPLE)
endif() endif()
# setting timeout value as 15S # setting timeout value as 15S
set_tests_properties(test_run PROPERTIES TIMEOUT 200) set_tests_properties(test_run PROPERTIES TIMEOUT 120)
set_tests_properties(test_sync_batch_norm_op PROPERTIES TIMEOUT 120) set_tests_properties(test_sync_batch_norm_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120) set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_imperative_lod_tensor_to_selected_rows PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_lod_tensor_to_selected_rows PROPERTIES TIMEOUT 200)
......
...@@ -17,5 +17,4 @@ ...@@ -17,5 +17,4 @@
set -e set -e
# use default values # use default values
# FIXME: random fails on Unknown command lines -c (or -m). # FIXME: random fails on Unknown command lines -c (or -m).
launch_py=${PADDLE_BINARY_DIR}/python/paddle/distributed/launch.py CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch c_comm_init_op.py
CUDA_VISIBLE_DEVICES=0,1 python ${launch_py} c_comm_init_op.py
...@@ -39,9 +39,7 @@ import os ...@@ -39,9 +39,7 @@ import os
env = os.environ.copy() env = os.environ.copy()
assert "PADDLE_PSERVERS_IP_PORT_LIST" in env assert "PADDLE_PSERVERS_IP_PORT_LIST" in env
assert "PADDLE_TRAINER_ENDPOINTS" in env assert "PADDLE_TRAINER_ENDPOINTS" in env
#assert "PADDLE_PSERVER_ENDPOINTS" in env assert "PADDLE_ROLE" in env
#assert "PADDLE_TRAINER_ENDPOINTS" in env
#assert "PADDLE_ROLE" in env
#assert "PADDLE_RANK" in env #assert "PADDLE_RANK" in env
''' '''
...@@ -62,27 +60,24 @@ class Collective_Test(unittest.TestCase): ...@@ -62,27 +60,24 @@ class Collective_Test(unittest.TestCase):
write_file(pyname, colpyfile) write_file(pyname, colpyfile)
def pdrun(self, args, env=None): def pdrun(self, args, env=None):
cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.run"] cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.launch"]
if args: if args:
cmd.extend(args.split(" ")) cmd.extend(args.split(" "))
cmd.extend([pyname]) cmd.extend([pyname])
proc = subprocess.Popen(cmd, env) proc = subprocess.Popen(cmd, env)
return proc return proc
'''
def test_collective_1(self): def test_collective_1(self):
args = "--id test1" args = "--job_id test1"
p = self.pdrun(args) p = self.pdrun(args)
p.wait() p.wait()
self.assertTrue(p.poll() == 0) self.assertTrue(p.poll() == 0)
'''
def test_collective_2(self): def test_collective_2(self):
if os.path.exists('./log'): if os.path.exists('./log'):
shutil.rmtree('./log') shutil.rmtree('./log')
args = "--id test2 --devices 0,1,2" args = "--job_id test2 --devices 0,1,2"
p = self.pdrun(args) p = self.pdrun(args)
p.wait() p.wait()
self.assertTrue(p.poll() == 0) self.assertTrue(p.poll() == 0)
...@@ -95,7 +90,7 @@ class Collective_Test(unittest.TestCase): ...@@ -95,7 +90,7 @@ class Collective_Test(unittest.TestCase):
shutil.rmtree('./log') shutil.rmtree('./log')
port = random.randrange(6000, 8000) port = random.randrange(6000, 8000)
args = "--id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format( args = "--job_id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format(
port) port)
p1 = self.pdrun(args) p1 = self.pdrun(args)
p2 = self.pdrun(args) p2 = self.pdrun(args)
...@@ -113,14 +108,13 @@ class PS_Test(unittest.TestCase): ...@@ -113,14 +108,13 @@ class PS_Test(unittest.TestCase):
write_file(pyname, pspyfile) write_file(pyname, pspyfile)
def pdrun(self, args, env=None): def pdrun(self, args, env=None):
cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.run"] cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.launch"]
if args: if args:
cmd.extend(args.split(" ")) cmd.extend(args.split(" "))
cmd.extend([pyname]) cmd.extend([pyname])
proc = subprocess.Popen(cmd, env) proc = subprocess.Popen(cmd, env)
return proc return proc
'''
def test_ps_1(self): def test_ps_1(self):
args = "--mode ps" args = "--mode ps"
p = self.pdrun(args) p = self.pdrun(args)
...@@ -131,21 +125,20 @@ class PS_Test(unittest.TestCase): ...@@ -131,21 +125,20 @@ class PS_Test(unittest.TestCase):
if os.path.exists('./log'): if os.path.exists('./log'):
shutil.rmtree('./log') shutil.rmtree('./log')
args = "--id ps2 --server_num=2 --trainer_num=2" args = "--job_id ps2 --server_num=2 --trainer_num=2"
p = self.pdrun(args) p = self.pdrun(args)
p.wait() p.wait()
self.assertTrue(p.poll() == 0) self.assertTrue(p.poll() == 0)
c = get_files('log', 'ps2') c = get_files('log', 'ps2')
self.assertTrue(len(c) == 5) self.assertTrue(len(c) == 5)
'''
def test_ps_3(self): def test_ps_3(self):
if os.path.exists('./log'): if os.path.exists('./log'):
shutil.rmtree('./log') shutil.rmtree('./log')
port = random.randrange(6000, 8000) port = random.randrange(6000, 8000)
args = "--id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format( args = "--job_id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format(
port) port)
p1 = self.pdrun(args) p1 = self.pdrun(args)
p2 = self.pdrun(args) p2 = self.pdrun(args)
...@@ -161,7 +154,7 @@ class PS_Test(unittest.TestCase): ...@@ -161,7 +154,7 @@ class PS_Test(unittest.TestCase):
if os.path.exists('./log'): if os.path.exists('./log'):
shutil.rmtree('./log') shutil.rmtree('./log')
args = "--id ps4 --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903" args = "--job_id ps4 --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903"
p1 = self.pdrun(args) p1 = self.pdrun(args)
p1.wait() p1.wait()
self.assertTrue(p1.poll() == 0) self.assertTrue(p1.poll() == 0)
......
...@@ -282,6 +282,12 @@ packages=['paddle', ...@@ -282,6 +282,12 @@ packages=['paddle',
'paddle.distribution', 'paddle.distribution',
'paddle.distributed.sharding', 'paddle.distributed.sharding',
'paddle.distributed.fleet', 'paddle.distributed.fleet',
'paddle.distributed.launch',
'paddle.distributed.launch.context',
'paddle.distributed.launch.controllers',
'paddle.distributed.launch.job',
'paddle.distributed.launch.plugins',
'paddle.distributed.launch.utils',
'paddle.distributed.fleet.base', 'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.elastic',
'paddle.distributed.fleet.meta_optimizers', 'paddle.distributed.fleet.meta_optimizers',
...@@ -727,7 +733,7 @@ with redirect_stdout(): ...@@ -727,7 +733,7 @@ with redirect_stdout():
}, },
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'fleetrun = paddle.distributed.fleet.launch:launch' 'fleetrun = paddle.distributed.launch.__main__:launch'
] ]
}, },
classifiers=[ classifiers=[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册