# args_envs.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentParser, REMAINDER

# Maps an environment variable name to the name of the command-line argument
# (as declared in parse_args) that carries the same setting.
# NOTE(review): presumably consumed by the launcher to fill arguments from the
# environment; the consumer is not in this file -- confirm against the caller.
env_args_mapping = {
    'POD_IP': 'host',
    'PADDLE_MASTER': 'master',
    'PADDLE_DEVICES': 'devices',
    'PADDLE_NNODES': 'nnodes',
    'PADDLE_RUN_MODE': 'run_mode',
    'PADDLE_LOG_LEVEL': 'log_level',
    'PADDLE_NPROC_PER_NODE': 'nproc_per_node',
    'PADDLE_JOB_ID': 'job_id',
    'PADDLE_RANK': 'rank',
    'PADDLE_LOG_DIR': 'log_dir',
    'PADDLE_MAX_RESTART': 'max_restart',
    'PADDLE_ELASTIC_LEVEL': 'elastic_level',
    'PADDLE_ELASTIC_TIMEOUT': 'elastic_timeout',
    'PADDLE_SERVER_NUM': 'server_num',
    'PADDLE_TRAINER_NUM': 'trainer_num',
    'PADDLE_SERVERS_ENDPOINTS': 'servers',
    'PADDLE_TRAINERS_ENDPOINTS': 'trainers',
    'PADDLE_GLOO_PORT': 'gloo_port',
    'PADDLE_WITH_GLOO': 'with_gloo',
}


def fetch_envs():
    """Return a dict snapshot of the process environment without proxies.

    The ``http_proxy`` and ``https_proxy`` variables are removed from
    ``os.environ`` itself (not only from the returned copy), so the current
    process is affected as well.

    Returns:
        dict: a copy of ``os.environ`` taken after the proxy variables
        have been removed.
    """
    for proxy_var in ('http_proxy', 'https_proxy'):
        # Missing variables are ignored thanks to the default value.
        os.environ.pop(proxy_var, None)

    return dict(os.environ)


def _str_to_bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because any non-empty string is truthy. This converter accepts the usual
    spellings (case-insensitive) and rejects everything else.

    Args:
        value (str): the raw command-line value.

    Returns:
        bool: True for y/yes/t/true/on/1, False for n/no/f/false/off/0.

    Raises:
        ValueError: if the value is none of the accepted spellings; argparse
            turns this into a regular usage error.
    """
    normalized = value.strip().lower()
    if normalized in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if normalized in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError("invalid boolean value: {!r}".format(value))


def parse_args():
    """Build the launch argument parser and parse the command line.

    Arguments are read from ``sys.argv``. Everything after the positional
    ``training_script`` is collected verbatim into ``training_script_args``.

    Returns:
        tuple: ``(namespace, unknown)`` as returned by
        ``ArgumentParser.parse_known_args`` -- the parsed namespace and the
        list of arguments the parser did not recognize.
    """
    parser = ArgumentParser()

    base_group = parser.add_argument_group("Base Parameters")

    base_group.add_argument(
        "--master",
        type=str,
        default=None,
        help="the master/rendezvous server, ip:port")

    # Parsed with _str_to_bool so that "--legacy False" really yields False.
    base_group.add_argument(
        "--legacy", type=_str_to_bool, default=False, help="use legacy launch")

    base_group.add_argument(
        "--rank", type=int, default=-1, help="the node rank")

    base_group.add_argument(
        "--log_level", type=str, default="INFO", help="log level. Default INFO")

    # Kept as a string, not an int.
    # NOTE(review): presumably to allow non-integer forms such as a range
    # in elastic mode -- confirm against the consumer of this argument.
    base_group.add_argument(
        "--nnodes",
        type=str,
        default="1",
        help="the number of nodes, i.e. pod/node number")

    base_group.add_argument(
        "--nproc_per_node",
        type=int,
        default=None,
        help="the number of processes in a pod")

    base_group.add_argument(
        "--log_dir",
        type=str,
        default="log",
        help="the path for each process's log. Default ./log")
    base_group.add_argument(
        "--run_mode",
        type=str,
        default=None,
        help="run mode of the job, collective/ps/ps-heter")

    base_group.add_argument(
        "--job_id",
        type=str,
        default="default",
        help="unique id of the job. Default default")

    base_group.add_argument(
        "--devices",
        type=str,
        default=None,
        help="accelerate devices. as --gpus,npus,xps")

    base_group.add_argument("--host", type=str, default=None, help="host ip")

    base_group.add_argument(
        "training_script",
        type=str,
        help="the full path of py script,"
        "followed by arguments for the "
        "training script")

    # REMAINDER: everything after the script path belongs to the script.
    base_group.add_argument('training_script_args', nargs=REMAINDER)

    ps_group = parser.add_argument_group("Parameter-Server Parameters")
    # for parameter server
    ps_group.add_argument(
        "--servers", type=str, default='', help="servers endpoints full list")
    ps_group.add_argument(
        "--trainers", type=str, default='', help="trainers endpoints full list")

    ps_group.add_argument(
        "--trainer_num", type=int, default=None, help="number of trainers")
    ps_group.add_argument(
        "--server_num", type=int, default=None, help="number of servers")
    ps_group.add_argument(
        "--gloo_port", type=int, default=6767, help="gloo http port")
    ps_group.add_argument(
        "--with_gloo", type=str, default="1", help="use gloo or not")

    # parameter elastic mode
    elastic_group = parser.add_argument_group("Elastic Parameters")
    elastic_group.add_argument(
        "--max_restart",
        type=int,
        default=3,
        help="the times can restart. Default 3")

    elastic_group.add_argument(
        "--elastic_level",
        type=int,
        default=-1,
        help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart"
    )

    elastic_group.add_argument(
        "--elastic_timeout",
        type=int,
        default=30,
        help="seconds to wait before elastic job begin to train")

    return parser.parse_known_args()