diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py
index 95ec31870287c941480dda241ea694b9d1b94f2d..4934f2093f3fb768fa718364f7018cc185035611 100644
--- a/python/paddle/distributed/__init__.py
+++ b/python/paddle/distributed/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .spawn import spawn  # noqa: F401
+from paddle.distributed.fleet.launch import launch  # noqa: F401
 
 from .parallel import init_parallel_env  # noqa: F401
 from .parallel import get_rank  # noqa: F401
@@ -60,6 +61,7 @@ from . import utils  # noqa: F401
 
 __all__ = [  # noqa
     "spawn",
+    "launch",
     "scatter",
     "broadcast",
     "ParallelEnv",
diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py
index bc7942826e1eaaeb89dd5854acf8abaa710148fa..2920dd5870ac1aee53e7667fe5cf9238cb2b5963 100644
--- a/python/paddle/distributed/fleet/launch.py
+++ b/python/paddle/distributed/fleet/launch.py
@@ -102,8 +102,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
         "--log_dir",
         type=str,
         default="log",
-        help="The path for each process's log.If it's not set, the log will printed to default pipe."
-    )
+        help="The path for each process's log. Default --log_dir=log/")
 
     base_group.add_argument(
         "--nproc_per_node",
@@ -385,6 +384,175 @@ def which_distributed_mode(args):
 
 
 def launch():
+    """
+    Paddle distributed training entry point, ``python -m paddle.distributed.launch``.
+
+    Usage:
+        .. code-block:: bash
+            :name: code-block-bash1
+
+            python -m paddle.distributed.launch [-h] [--log_dir LOG_DIR] [--nproc_per_node NPROC_PER_NODE] [--run_mode RUN_MODE] [--gpus GPUS]
+                             [--selected_gpus GPUS] [--ips IPS] [--servers SERVERS] [--workers WORKERS] [--heter_workers HETER_WORKERS]
+                             [--worker_num WORKER_NUM] [--server_num SERVER_NUM] [--heter_worker_num HETER_WORKER_NUM]
+                             [--http_port HTTP_PORT] [--elastic_server ELASTIC_SERVER] [--job_id JOB_ID] [--np NP] [--scale SCALE]
+                             [--host HOST] [--force FORCE]
+                             training_script ...
+
+
+    Base Parameters:
+        - ``--log_dir``: The path for each process's log, e.g. ``--log_dir=output_dir``. Default ``--log_dir=log``.
+
+        - ``--nproc_per_node``: The number of processes to launch on a node. For GPU training, it should be less than or equal to the number of GPUs on your system (or the number set by ``--gpus``), so that each process can be bound to one GPU or to an equal share of the GPUs. e.g. ``--nproc_per_node=8``
+
+        - ``--run_mode``: The run mode of the job, one of collective/ps/ps-heter, e.g. ``--run_mode=ps``. Default ``--run_mode=collective``.
+
+        - ``--gpus``: For GPU training, e.g. ``--gpus=0,1,2,3`` launches four training processes, each bound to one GPU.
+
+        - ``--selected_gpus``: An alias for ``--gpus``; using ``--gpus`` is recommended.
+
+        - ``--xpus``: For XPU training, if XPUs are available, e.g. ``--xpus=0,1,2,3``.
+
+        - ``--selected_xpus``: An alias for ``--xpus``; using ``--xpus`` is recommended.
+
+        - ``training_script``: The full path to the single-GPU training program/script to be launched in parallel, followed by all the arguments for the training script, e.g. ``train.py``. A minimal sketch of such a script is shown below.
+
+        - ``training_script_args``: The args of the training_script, e.g. ``--lr=0.1``
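+
+    A minimal ``train.py`` sketch that the collective examples below can launch
+    (illustrative only; the linear model, random data, and loss are placeholders
+    to be replaced by your own network and dataset):
+
+        .. code-block:: python
+            :name: code-block-example-train-sketch
+
+            # train.py: a single-card training script that the launcher
+            # replicates across the selected devices.
+            import argparse
+
+            import paddle
+            import paddle.distributed as dist
+
+            def main():
+                parser = argparse.ArgumentParser()
+                parser.add_argument("--lr", type=float, default=0.01)
+                args = parser.parse_args()
+
+                # Pick up the parallel environment prepared by the launcher.
+                dist.init_parallel_env()
+
+                # Placeholder model wrapped for data-parallel training.
+                model = paddle.DataParallel(paddle.nn.Linear(10, 1))
+                opt = paddle.optimizer.SGD(learning_rate=args.lr,
+                                           parameters=model.parameters())
+
+                for step in range(10):
+                    x = paddle.rand([4, 10])
+                    y = paddle.rand([4, 1])
+                    loss = paddle.mean((model(x) - y) ** 2)
+                    loss.backward()
+                    opt.step()
+                    opt.clear_grad()
+
+            if __name__ == "__main__":
+                main()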
+
+    Collective Parameters:
+        - ``--ips``: The ips of the nodes in the Paddle cluster, e.g. ``--ips=192.168.0.16,192.168.0.17``. Default ``--ips=127.0.0.1``.
+
+    Parameter-Server Parameters:
+        - ``--servers``: User-defined servers as ip:port pairs, e.g. ``--servers="192.168.0.16:6170,192.168.0.17:6170"``
+
+        - ``--workers``: User-defined workers as ip:port pairs, e.g. ``--workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172"``
+
+        - ``--heter_workers``: User-defined heter workers as ip:port pairs, e.g. ``--heter_workers="192.168.0.16:6172,192.168.0.17:6172"``
+
+        - ``--worker_num``: The number of workers (recommended when emulating a distributed environment on a single node).
+
+        - ``--server_num``: The number of servers (recommended when emulating a distributed environment on a single node).
+
+        - ``--heter_worker_num``: The number of heter_workers (recommended when emulating a distributed environment on a single node).
+
+        - ``--http_port``: The Gloo HTTP port.
+
+    Elastic Parameters:
+        - ``--elastic_server``: The etcd server host:port, e.g. ``--elastic_server=127.0.0.1:2379``
+
+        - ``--job_id``: The unique job id, e.g. ``--job_id=job1``
+
+        - ``--np``: The number of job pods/nodes, e.g. ``--np=2``
+
+        - ``--scale``: Scale np; not used yet.
+
+        - ``--host``: The host to bind to; defaults to the POD_IP environment variable.
+
+        - ``--force``: Force update of np; not used yet.
+
+    Returns:
+        ``None``
+
+    Examples 1 (collective, single node):
+        .. code-block:: bash
+            :name: code-block-example-bash1
+
+            # For single-node training using 4 GPUs.
+
+            python -m paddle.distributed.launch --gpus=0,1,2,3 train.py --lr=0.01
+
+    Examples 2 (collective, multi node):
+        .. code-block:: bash
+            :name: code-block-example-bash2
+
+            # For multi-node training, e.g. two nodes: 192.168.0.16 and 192.168.0.17.
+
+            # On 192.168.0.16:
+
+            python -m paddle.distributed.launch --gpus=0,1,2,3 --ips=192.168.0.16,192.168.0.17 train.py --lr=0.01
+
+            # On 192.168.0.17:
+            python -m paddle.distributed.launch --gpus=0,1,2,3 --ips=192.168.0.16,192.168.0.17 train.py --lr=0.01
+
+    Examples 3 (ps, cpu, single node):
+        .. code-block:: bash
+            :name: code-block-example-bash3
+
+            # Emulate a distributed environment on a single node with 2 servers and 4 workers.
+
+            python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01
+
+    Examples 4 (ps, cpu, multi node):
+        .. code-block:: bash
+            :name: code-block-example-bash4
+
+            # For multi-node training, e.g. two nodes (192.168.0.16 and 192.168.0.17) with 2 servers and 4 workers in total.
+
+            # On 192.168.0.16:
+
+            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
+
+            # On 192.168.0.17:
+
+            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
+
+    Examples 5 (ps, gpu, single node):
+        .. code-block:: bash
+            :name: code-block-example-bash5
+
+            # Emulate a distributed environment on a single node with 2 servers and 4 workers, each worker using one GPU.
+
+            export CUDA_VISIBLE_DEVICES=0,1,2,3
+            python -m paddle.distributed.launch --server_num=2 --worker_num=4 train.py --lr=0.01
+
+    Examples 6 (ps, gpu, multi node):
+        .. code-block:: bash
+            :name: code-block-example-bash6
+
+            # For multi-node training, e.g. two nodes (192.168.0.16 and 192.168.0.17) with 2 servers and 4 workers in total.
+
+            # On 192.168.0.16:
+
+            export CUDA_VISIBLE_DEVICES=0,1
+            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
+
+            # On 192.168.0.17:
+
+            export CUDA_VISIBLE_DEVICES=0,1
+            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172" train.py --lr=0.01
+
+    Examples 7 (ps-heter, cpu + gpu, single node):
+        .. code-block:: bash
+            :name: code-block-example-bash7
+
+            # Emulate a distributed environment on a single node with 2 servers and 4 workers, two workers using GPU and two using CPU.
+
+            export CUDA_VISIBLE_DEVICES=0,1
+            python -m paddle.distributed.launch --server_num=2 --worker_num=2 --heter_worker_num=2 train.py --lr=0.01
+
+    Examples 8 (ps-heter, cpu + gpu, multi node):
+        .. code-block:: bash
+            :name: code-block-example-bash8
+
+            # For multi-node training, e.g. two nodes (192.168.0.16 and 192.168.0.17) with 2 servers and 4 workers in total.
+
+            # On 192.168.0.16:
+
+            export CUDA_VISIBLE_DEVICES=0
+            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.17:6171" --heter_workers="192.168.0.16:6172,192.168.0.17:6172" train.py --lr=0.01
+
+            # On 192.168.0.17:
+
+            export CUDA_VISIBLE_DEVICES=0
+            python -m paddle.distributed.launch --servers="192.168.0.16:6170,192.168.0.17:6170" --workers="192.168.0.16:6171,192.168.0.17:6171" --heter_workers="192.168.0.16:6172,192.168.0.17:6172" train.py --lr=0.01
+
+    Examples 9 (elastic):
+        .. code-block:: bash
+            :name: code-block-example-bash9
+
+            python -m paddle.distributed.launch --elastic_server=127.0.0.1:2379 --np=2 --job_id=job1 --gpus=0,1,2,3 train.py
+
+    """
+
     args = _parse_args()
     logger = get_logger()
     _print_arguments(args)