launch.py 22.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
r"""
fleetrun is a module that spawns multiple distributed
processes on each training node for gpu training and cpu training.
Usage:
    In both single node training and multiple node training, this module
    launches a process on each of the given gpu cards or cpu machines.
    GPU training:
    1. for single node training with all visible gpu cards:
       fleetrun your_training_py (arg1 arg2 and all others)
    2. for single node training with [0,4) cards
       fleetrun --gpus="0,1,2,3" your_training_py (arg1 arg2 and all others)
    3. for multiple node training such as two nodes: 192.168.0.16, 192.168.0.17
        on 192.168.0.16:
            fleetrun --ips="192.168.0.16,192.168.0.17" \
                your_training_py (arg1 arg2 and all others)
        on 192.168.0.17:
            fleetrun --ips="192.168.0.16,192.168.0.17" \
                your_training_py (arg1 arg2 and all others)
    CPU training:
    1. for single node training with multiple servers and workers:
        fleetrun --server_num=2 --worker_num=2 your_training_py (arg1 arg2 and all others)
    2. for multiple node training such as two nodes: 192.168.0.16, 192.168.0.17 \
        with 2 servers and 4 workers. Pass the same --servers and --workers
        lists on every node.
        on 192.168.0.16:
            fleetrun --servers="192.168.0.16:6170,192.168.0.17:6170" \
                --workers="192.168.0.16,192.168.0.17,192.168.0.16,192.168.0.17" \
                your_training_py (arg1 arg2 and all others)
        on 192.168.0.17:
            fleetrun --servers="192.168.0.16:6170,192.168.0.17:6170" \
                --workers="192.168.0.16,192.168.0.17,192.168.0.16,192.168.0.17" \
                your_training_py (arg1 arg2 and all others)
    3. use the gloo backend for multiple node training such as two nodes: 192.168.0.16, 192.168.0.17 \
        with 2 servers and 4 workers (workers must specify a port).
        on 192.168.0.16:
            fleetrun --servers="192.168.0.16:6170,192.168.0.17:6170" \
                --workers="192.168.0.16:6171,192.168.0.17:6171,192.168.0.16:6172,192.168.0.17:6172" \
                your_training_py (arg1 arg2 and all others)
        on 192.168.0.17:
            fleetrun --servers="192.168.0.16:6170,192.168.0.17:6170" \
                --workers="192.168.0.16:6171,192.168.0.17:6171,192.168.0.16:6172,192.168.0.17:6172" \
                your_training_py (arg1 arg2 and all others)
"""

from __future__ import print_function
58 59

import shutil
60
import sys
61
import tempfile
62 63 64 65 66 67 68 69 70
from sys import version
import subprocess
import os
import time
import six
import copy
from argparse import ArgumentParser, REMAINDER
import paddle
import paddle.fluid as fluid
71
from paddle.distributed.fleet import launch_utils
72

73
# TODO(danleifeng): Don't import * from a module
74 75
from paddle.distributed.fleet.launch_utils import *
import paddle.distributed.fleet.cloud_utils as cloud_utils
76
import paddle.distributed.fleet.ascend_utils as ascend_utils
77

K
kuizhiqing 已提交
78
from paddle.distributed.fleet.elastic import enable_elastic, launch_elastic
79

80 81
__all__ = []

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98

def _print_arguments(args):
    """Write every parsed option to stdout, one ``name: value`` per line.

    Options are listed in alphabetical order of their attribute names and
    framed by a header and a footer rule.

    Args:
        args: an ``argparse.Namespace`` (anything ``vars()`` accepts).
    """
    options = vars(args)
    print("-----------  Configuration Arguments -----------")
    for name in sorted(options):
        print("%s: %s" % (name, options[name]))
    print("------------------------------------------------")


def _str_to_bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)``. That is ``True``
    for every non-empty string, including ``"False"``, so boolean-valued
    options need an explicit converter.

    Args:
        value: the raw option string (a ``bool`` is returned unchanged).

    Returns:
        bool: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling. argparse reports this as a usage error.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise ArgumentTypeError(
        "invalid boolean value: {!r} (expected true/false)".format(value))


def _parse_args():
    """Build the launcher's command-line parser and parse ``sys.argv``.

    Options are organised in four argument groups: base, collective,
    parameter-server and elastic. ``--gpus``/``--selected_gpus`` are only
    registered when Paddle is compiled with CUDA. ``--xpus``/``--selected_xpus``
    are only registered when it is compiled with XPU.

    Returns:
        argparse.Namespace: the parsed options. Everything after
        ``training_script`` is collected in ``training_script_args``.
    """
    parser = ArgumentParser(
        description='''start paddle training using multi-process mode.
see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/training/cluster_howto.html#permalink-8--nccl2-
''')
    base_group = parser.add_argument_group("Base Parameters")

    base_group.add_argument(
        "--log_dir",
        type=str,
        default="log",
        help="The path for each process's log. Default --log_dir=log/")

    base_group.add_argument(
        "--nproc_per_node",
        type=int,
        default=None,
        help="The number of processes to launch on a node. "
        "In gpu training, it should be less or equal to the gpus number of "
        "your system (or the number you set by --gpus), so that each process "
        "can be bound to one or an average number of gpus.")

    base_group.add_argument(
        "--run_mode",
        type=str,
        default=None,
        # Validate early so a typo fails with a usage message instead of an
        # assertion deep inside mode detection.
        choices=["collective", "ps", "ps-heter"],
        help="run mode of job, can be:collective/ps/ps-heter")

    if fluid.core.is_compiled_with_cuda():
        base_group.add_argument(
            "--gpus",
            type=str,
            default=None,
            help="It's for gpu training. "
            "For example: "
            "--gpus=\"0,1,2,3\" will launch four training processes each bound to one gpu."
        )
        # Alias of --gpus; both options write to args.gpus.
        base_group.add_argument("--selected_gpus", dest="gpus")

    if fluid.core.is_compiled_with_xpu():
        base_group.add_argument(
            "--xpus",
            type=str,
            default=None,
            help="It's for xpu training. For example: "
            "--xpus=\"0,1,2,3\" will launch four training processes each bound to one xpu."
        )
        # Alias of --xpus; both options write to args.xpus.
        base_group.add_argument("--selected_xpus", dest="xpus")

    base_group.add_argument(
        "training_script",
        type=str,
        help="The full path to the single GPU training "
        "program/script to be launched in parallel, "
        "followed by all the arguments for the "
        "training script")

    # REMAINDER: everything after the script is forwarded to it untouched.
    base_group.add_argument('training_script_args', nargs=REMAINDER)

    # Optional arguments for the launch helper
    # for collective
    collective_group = parser.add_argument_group("Collective Parameters")
    collective_group.add_argument(
        "--ips",
        type=str,
        default="127.0.0.1",
        help="Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17..")

    ps_group = parser.add_argument_group("Parameter-Server Parameters")
    # for parameter server
    ps_group.add_argument(
        "--servers", type=str, default="", help="User defined servers ip:port")
    ps_group.add_argument(
        "--workers", type=str, default="", help="User defined workers ip:port")
    ps_group.add_argument(
        "--heter_workers",
        type=str,
        default="",
        help="User defined heter workers ip:port")

    ps_group.add_argument("--worker_num", type=int, help="number of workers")
    ps_group.add_argument("--server_num", type=int, help="number of servers")
    ps_group.add_argument(
        "--heter_worker_num", type=int, help="number of heter_workers")
    ps_group.add_argument("--http_port", type=int, help="Gloo http Port")

    # parameter elastic mode
    elastic_group = parser.add_argument_group("Elastic Parameters")
    elastic_group.add_argument(
        "--elastic_server", type=str, help="etcd server host:port")
    elastic_group.add_argument("--job_id", type=str, help="job unique id")
    elastic_group.add_argument("--np", type=int, help="job pod/node number")
    elastic_group.add_argument("--scale", type=int, default=0, help="scale np")
    elastic_group.add_argument(
        "--host", type=str, help="bind host, default to POD_IP env")
    elastic_group.add_argument(
        "--force",
        type=_str_to_bool,
        default=False,
        help="update np force (true/false)")

    return parser.parse_args()


194
def _static_trainer_ports(num_ports, default_start_port=6070):
    """Return ``num_ports`` consecutive ports for the trainers of one node.

    The range starts at the ``FLAGS_START_PORT`` environment variable when it
    is set, and at ``default_start_port`` otherwise.
    """
    start_port = int(os.environ.get('FLAGS_START_PORT', default_start_port))
    return list(range(start_port, start_port + num_ports))


def get_cluster_from_args(args, device_mode, devices_per_proc):
    """Build the cluster description from ``--ips`` and the local devices.

    The local node ip is chosen as follows:

    * the only entry of ``--ips`` when a single ip is given;
    * otherwise ``--host`` when set;
    * otherwise the ip reported by ``get_host_name_ip()``.

    It must appear in ``--ips``, and its position there is the node rank.

    Trainer ports depend on the situation:

    * Single node, not on PaddleCloud, no ``FLAGS_START_PORT``: randomly
      found free ports are used.
    * Otherwise, or if the free-port search fails: consecutive ports starting
      at ``FLAGS_START_PORT`` (default 6070) are used.

    Args:
        args: parsed launcher options (reads ``ips`` and ``host``).
        device_mode: device mode forwarded to ``get_cluster``.
        devices_per_proc: device groups. One trainer port is allocated per
            entry.

    Returns:
        Whatever ``get_cluster`` returns (unpacked by callers as
        ``cluster, pod``).
    """
    node_ips = [x.strip() for x in args.ips.split(',')]
    if len(node_ips) == 1:
        node_ip = node_ips[0]
    elif args.host:
        node_ip = args.host
    else:
        _, node_ip = get_host_name_ip()

    assert node_ip in node_ips, "Can't find your local ip {%s} in node_ips: {%s}" \
        % (node_ip, node_ips)
    node_rank = node_ips.index(node_ip)

    logger.debug("parsed from args: node_ips:{} node_ip:{} node_rank:{}".format(
        node_ips, node_ip, node_rank))

    num_ports = len(devices_per_proc)
    free_ports = None
    if not cloud_utils.use_paddlecloud() and len(
            node_ips) <= 1 and os.environ.get('FLAGS_START_PORT') is None:
        found_ports = find_free_ports(num_ports)
        # find_free_ports may give up and return None. Previously that
        # crashed below with a TypeError; now we fall back to static ports.
        if found_ports is not None:
            free_ports = list(found_ports)
    if free_ports is None:
        # A deterministic range means the same port list is reused for every
        # ip in node_ips below, so each node derives identical endpoints.
        free_ports = _static_trainer_ports(num_ports)

    trainer_endpoints = [["%s:%d" % (ip, port) for port in free_ports]
                         for ip in node_ips]
    return get_cluster(node_ips, node_ip, trainer_endpoints, device_mode,
                       devices_per_proc)
231 232


K
kuizhiqing 已提交
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278
def launch_collective(args):
    """Launch and supervise this node's collective-mode trainer processes.

    The cluster/pod layout comes from one of three sources:

    * PaddleCloud, when running there with more than one trainer;
    * the Ascend rank table (``RANK_TABLE_FILE``), for ASCEND_NPU devices;
    * ``get_cluster_from_args`` otherwise.

    Local trainers for ``pod`` are started with gloo rendezvous variables
    pointing at a fresh temporary directory. They are then polled every 3
    seconds until ``watch_local_trainers`` reports that none are alive.

    On any exception or interrupt while polling, the local processes are
    terminated and the launcher exits with status 1. The temporary gloo
    directory is removed on every exit path.

    Args:
        args: parsed launcher options.
    """
    # parse arguments, used for cloud-single-machine and local
    (device_mode, devices_per_proc) = launch_utils.get_device_proc_info(args)
    trainers_num = cloud_utils.get_trainers_num()
    logger.debug("parsed from args trainerss_num:{} mode:{} devices:{}".format(
        trainers_num, device_mode, devices_per_proc))

    # Environment values are strings, but the cluster builders need an int
    # port. Previously the raw string was forwarded when FLAGS_START_PORT
    # was set.
    start_port = int(os.environ.get('FLAGS_START_PORT', 6170))
    if cloud_utils.use_paddlecloud() and trainers_num != 1:
        cluster, pod = cloud_utils.get_cloud_cluster(
            args.ips, device_mode, devices_per_proc, start_port)
        logger.debug("get cluster from cloud:{}".format(cluster))
    elif device_mode == DeviceMode.ASCEND_NPU:
        # for ascend
        cluster, pod = ascend_utils.get_cloud_cluster(
            rank_table_file=os.getenv("RANK_TABLE_FILE", None),
            device_mode=device_mode,
            start_port=start_port)
    else:
        # trainers_num = 1 or not use paddlecloud ips="a,b"
        cluster, pod = get_cluster_from_args(args, device_mode,
                                             devices_per_proc)
        logger.debug("get cluster from args:{}".format(cluster))

    global_envs = os.environ.copy()
    gloo_rendezvous_dir = tempfile.mkdtemp()
    try:
        # add gloo env
        global_envs["PADDLE_WITH_GLOO"] = str(
            os.getenv("PADDLE_WITH_GLOO", "0"))
        global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3"
        global_envs["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir

        procs = start_local_trainers(
            cluster,
            pod,
            training_script=args.training_script,
            training_script_args=args.training_script_args,
            log_dir=args.log_dir,
            envs=global_envs)

        for idx, proc in enumerate(procs):
            print("launch proc_id:{} idx:{}".format(proc.proc.pid, idx))

        while True:
            try:
                alive = watch_local_trainers(procs, cluster.trainers_nranks())
                if not alive:
                    logger.info("Local processes completed.")
                    logger.debug("POD info:{}".format(pod))
                    break
                time.sleep(3)
            except BaseException:
                # Deliberately broad (same scope as a bare ``except:``). Errors
                # and KeyboardInterrupt alike must tear down the children.
                logger.warning("Terminating... exit")
                logger.debug("launcher stopped while polling", exc_info=True)
                terminate_local_procs(procs)
                sys.exit(1)
    finally:
        # Runs on normal completion and on the sys.exit(1) above, so the
        # rendezvous directory no longer leaks when trainers are interrupted.
        shutil.rmtree(gloo_rendezvous_dir, ignore_errors=True)
298

299

300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
def launch_ps(args, distribute_mode):
    """Start a parameter-server style job.

    On PaddleCloud there are two special cases:

    * a plain PS job is handed to ``direct_start`` and nothing else runs;
    * a PS-heter job first loads its environment via
      ``cloud_ps_heter_env_set`` and copies the worker, server and
      heter-worker endpoint lists from environment variables into ``args``.

    All remaining cases are run by a ``ParameterServerLauncher``.

    Args:
        args: parsed launcher options (may be updated in place, see above).
        distribute_mode: a ``DistributeMode`` value (PS or PS_HETER).
    """
    if cloud_utils.use_paddlecloud():
        if distribute_mode == DistributeMode.PS:
            # for ps-cpu on paddlecloud
            direct_start(args)
            return
        if distribute_mode == DistributeMode.PS_HETER:
            cloud_ps_heter_env_set(args)
            env = os.getenv
            args.workers = env("PADDLE_TRAINER_ENDPOINTS")
            args.servers = env("PADDLE_PSERVERS_IP_PORT_LIST")
            args.heter_workers = env("PADDLE_HETER_TRAINER_IP_PORT_LIST")

    ParameterServerLauncher(args, distribute_mode).start_ps()


def _find_launch_options(candidates, argv, training_script):
    """Return the options in ``candidates`` that were passed to the launcher.

    Only the tokens of ``argv`` that come before ``training_script`` are
    inspected, so arguments forwarded to the training script are never
    mistaken for launcher options. A token matches an option when it equals
    the option (``--servers host:port``) or starts with the option followed
    by ``=`` (``--servers=...``). Plain substring matching would, for
    example, report ``--workers`` for ``--heter_workers=...`` or ``--ips``
    for ``--ips_file=...``.

    Args:
        candidates: option names such as ``"--servers"``.
        argv: command-line tokens without the program name.
        training_script: the positional training script. Scanning stops at
            its first occurrence, or at the end of ``argv`` if it is absent.

    Returns:
        list: the matching option names, in ``candidates`` order.
    """
    tokens = list(argv)
    end = tokens.index(training_script) if training_script in tokens else len(
        tokens)
    launcher_tokens = tokens[:end]
    return [
        opt for opt in candidates
        if any(tok == opt or tok.startswith(opt + "=")
               for tok in launcher_tokens)
    ]


def which_distributed_mode(args):
    """Decide whether to run in collective, PS or PS-heter mode.

    The decision is made in this order:

    1. An explicit ``--run_mode`` wins.
    2. Any parameter-server option on the command line selects PS mode, or
       PS-heter mode if one of them is a heter option.
    3. ``--ips`` selects collective mode.
    4. Otherwise collective is the default when Paddle is compiled with CUDA
       or XPU, and PS when it is not.

    Args:
        args: parsed launcher options (reads ``run_mode`` and
            ``training_script``).

    Returns:
        DistributeMode: the selected mode.

    Raises:
        ValueError: if ``--run_mode`` is not collective/ps/ps-heter, or if
            both parameter-server and collective options were given.
    """
    if args.run_mode is not None:
        explicit_modes = {
            "collective": DistributeMode.COLLECTIVE,
            "ps": DistributeMode.PS,
            "ps-heter": DistributeMode.PS_HETER,
        }
        if args.run_mode not in explicit_modes:
            raise ValueError(
                "--run_mode must be one of collective/ps/ps-heter, but got {}".
                format(args.run_mode))
        return explicit_modes[args.run_mode]

    ps_args = [
        '--worker_num', '--server_num', '--heter_worker_num', '--servers',
        '--workers', '--heter_workers', '--http_port'
    ]
    collective_args = ['--ips']
    ps_heter_args = ["--heter_worker_num", "--heter_workers"]

    # Only scan the launcher's own options, not the training script's args.
    launcher_argv = sys.argv[1:]
    has_ps_args = _find_launch_options(ps_args, launcher_argv,
                                       args.training_script)
    has_collective_args = _find_launch_options(
        collective_args, launcher_argv, args.training_script)

    # The previous ``len(...) > 1`` test could never be true, because
    # collective_args has a single entry, so this check was dead code.
    if has_ps_args and has_collective_args:
        raise ValueError(
            "Only one mode(Collective or Parameter-Server) can be selected at the same time, but more than one configuration was received."
        )

    if fluid.core.is_compiled_with_cuda():
        accelerators = fluid.core.get_cuda_device_count()
    elif fluid.core.is_compiled_with_npu():
        accelerators = fluid.core.get_npu_device_count()
    elif fluid.core.is_compiled_with_xpu():
        accelerators = fluid.core.get_xpu_device_count()
    else:
        accelerators = 0

    if has_ps_args:
        logger.info(
            "Run parameter-sever mode. pserver arguments:{}, accelerators count:{}".
            format(has_ps_args, accelerators))
        if set(has_ps_args) & set(ps_heter_args):
            return DistributeMode.PS_HETER
        return DistributeMode.PS

    if has_collective_args:
        logger.info("Run collective mode. gpu arguments:{}, cuda count:{}".
                    format(has_collective_args, accelerators))
        return DistributeMode.COLLECTIVE

    if not fluid.core.is_compiled_with_cuda(
    ) and not fluid.core.is_compiled_with_xpu():
        logger.warning(
            "Not found distinct arguments and not compiled with cuda or xpu. Default use ps mode"
        )
        return DistributeMode.PS

    logger.warning(
        "Not found distinct arguments and compiled with cuda or xpu. Default use collective mode"
    )
    return DistributeMode.COLLECTIVE
384 385 386


def launch():
    """
    Paddle distribution training entry ``python -m paddle.distributed.launch``.

    Usage:
        .. code-block:: bash
            :name: code-block-bash1

            python -m paddle.distributed.launch [-h] [--log_dir LOG_DIR] [--nproc_per_node NPROC_PER_NODE] [--run_mode RUN_MODE] [--gpus GPUS]
                             [--selected_gpus GPUS] [--ips IPS] [--servers SERVERS] [--workers WORKERS] [--heter_workers HETER_WORKERS]
                             [--worker_num WORKER_NUM] [--server_num SERVER_NUM] [--heter_worker_num HETER_WORKER_NUM]
                             [--http_port HTTP_PORT] [--elastic_server ELASTIC_SERVER] [--job_id JOB_ID] [--np NP] [--scale SCALE]
                             [--host HOST] [--force FORCE]
                             training_script ...


    Base Parameters:
        - ``--log_dir``: The path for each process's log. e.g ``--log_dir=output_dir``. Default ``--log_dir=log``.

        - ``--nproc_per_node``: The number of processes to launch on a node. In gpu training, it should be less than or equal to the number of gpus in your system (or the number set by --gpus), so that each process can be bound to one gpu or an equal share of gpus. e.g ``--nproc_per_node=8``

        - ``--run_mode``: run mode of job, can be: collective/ps/ps-heter. e.g ``--run_mode=ps``. When omitted, the mode is auto-detected. Parameter-server options select ps (or ps-heter when heter options are given). ``--ips`` selects collective. Otherwise collective is used when Paddle is compiled with cuda or xpu, and ps when it is not.

        - ``--gpus``: It's for gpu training. e.g ``--gpus=0,1,2,3`` will launch four training processes each bound to one gpu.

        - ``--selected_gpus``: gpus aliases, recommend to use ``--gpus``.

        - ``--xpus``: It's for xpu training if xpu is available. e.g ``--xpus=0,1,2,3``.

        - ``--selected_xpus``: xpus aliases, recommend to use ``--xpus``.

        - ``training_script``: The full path to the single GPU training program/script to be launched in parallel, followed by all the arguments for the training script. e.g ``train.py``

        - ``training_script_args``: The args of training_script. e.g ``--lr=0.1``

    Collective Parameters:
        - ``--ips``: Paddle cluster nodes ips, e.g ``--ips=192.168.0.16,192.168.0.17``. Default ``--ips=127.0.0.1``.

    Parameter-Server Parameters:
        - ``--servers``: User defined servers ip:port, e.g ``--servers="192.168.0.16:6170,192.168.0.17:6170"``

        - ``--workers``: User defined workers ip:port, e.g ``--workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172"``

        - ``--heter_workers``: User defined heter workers ip:port, e.g ``--heter_workers="192.168.0.16:6172,192.168.0.17:6172"``

        - ``--worker_num``: Number of workers (recommended when emulating a distributed environment on a single node)

        - ``--server_num``: Number of servers (recommended when emulating a distributed environment on a single node)

        - ``--heter_worker_num``: Number of heter_workers (recommended when emulating a distributed environment on a single node)

        - ``--http_port``: Gloo http Port

    Elastic Parameters:
        - ``--elastic_server``: etcd server host:port, e.g ``--elastic_server=127.0.0.1:2379``

        - ``--job_id``: job unique id, e.g ``--job_id=job1``

        - ``--np``: job pod/node number, e.g ``--np=2``

        - ``--scale``: scale np, not used currently.

        - ``--host``: bind host, default to POD_IP env.

        - ``--force``: update np force, not used currently.

    Returns:
        ``None``

    Examples 1 (collective, single node):
        .. code-block:: bash
            :name: code-block-example-bash1

            # For single node training using 4 gpus

            python -m paddle.distributed.launch --gpus=0,1,2,3 train.py --lr=0.01

    Examples 2 (collective, multi node):
        .. code-block:: bash
            :name: code-block-example-bash2

            # For multiple node training such as two node:192.168.0.16, 192.168.0.17

            # On 192.168.0.16:

            python -m paddle.distributed.launch --gpus=0,1,2,3 --ips=192.168.0.16,192.168.0.17 train.py --lr=0.01

            # On 192.168.0.17:
            python -m paddle.distributed.launch --gpus=0,1,2,3 --ips=192.168.0.16,192.168.0.17 train.py --lr=0.01

    Examples 3 (ps, cpu, single node):
        .. code-block:: bash
            :name: code-block-example-bash3

            # The emulated distributed environment using single node, 2 server and 4 worker

            python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01

    Examples 4 (ps, cpu, multi node):
        .. code-block:: bash
            :name: code-block-example-bash4

            # For multiple node training such as two node:192.168.0.16, 192.168.0.17 with 2 servers and total 4 workers

            # On 192.168.0.16:

            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01

            # On 192.168.0.17:

            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01

    Examples 5 (ps, gpu, single node):
        .. code-block:: bash
            :name: code-block-example-bash5

            # The emulated distributed environment using single node, 2 server and 4 worker, each worker use single gpu

            export CUDA_VISIBLE_DEVICES=0,1,2,3
            python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01

    Examples 6 (ps, gpu, multi node):
        .. code-block:: bash
            :name: code-block-example-bash6

            # For multiple node training such as two node:192.168.0.16, 192.168.0.17 with 2 servers and total 4 workers

            # On 192.168.0.16:

            export CUDA_VISIBLE_DEVICES=0,1
            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01

            # On 192.168.0.17:

            export CUDA_VISIBLE_DEVICES=0,1
            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01

    Examples 7 (ps-heter, cpu + gpu, single node):
        .. code-block:: bash
            :name: code-block-example-bash7

            # The emulated distributed environment using single node, 2 server and 4 worker, two worker use gpu, two worker use cpu

            export CUDA_VISIBLE_DEVICES=0,1
            python -m paddle.distributed.launch --server_num=2 --worker_num=2 --heter_worker_num=2 train.py --lr=0.01

    Examples 8 (ps-heter, cpu + gpu, multi node):
        .. code-block:: bash
            :name: code-block-example-bash8

            # For multiple node training such as two node:192.168.0.16, 192.168.0.17 with 2 servers and total 4 workers

            # On 192.168.0.16:

            export CUDA_VISIBLE_DEVICES=0
            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.17:6171" --heter_workers="192.168.0.16:6172,192.168.0.17:6172" train.py --lr=0.01

            # On 192.168.0.17:

            export CUDA_VISIBLE_DEVICES=0
            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.17:6171" --heter_workers="192.168.0.16:6172,192.168.0.17:6172" train.py --lr=0.01

    Examples 9 (elastic):
        .. code-block:: bash
            :name: code-block-example-bash9

            python -m paddle.distributed.launch --elastic_server=127.0.0.1:2379 --np=2 --job_id=job1  --gpus=0,1,2,3 train.py

    """

    args = _parse_args()
    # Called for its side effect only; the returned logger is not used here.
    # NOTE(review): presumably this configures the module-level ``logger``
    # used by the other launch helpers. Confirm in launch_utils.get_logger.
    get_logger()
    _print_arguments(args)

    distribute_mode = which_distributed_mode(args)

    # Elastic jobs are handled entirely by launch_elastic.
    if enable_elastic(args, distribute_mode):
        launch_elastic(args, distribute_mode)
        return

    if distribute_mode == DistributeMode.COLLECTIVE:
        launch_collective(args)
    else:
        launch_ps(args, distribute_mode)
570 571 572 573


if __name__ == '__main__':
    # Script entry point: run the launcher with the process's command line.
    launch()