cluster.py 12.0 KB
Newer Older
T
tangwei 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
T
tangwei 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import unicode_literals
T
tangwei 已提交
17

T
tangwei 已提交
18
import copy
T
tangwei 已提交
19 20
import os
import subprocess
C
Chengmo 已提交
21
import warnings
22
import logging
T
tangwei 已提交
23

24 25 26
from paddlerec.core.engine.engine import Engine
from paddlerec.core.factory import TrainerFactory
from paddlerec.core.utils import envs
T
tangwei 已提交
27

28 29 30 31
# Configure root-logger output once at import time. Per the stdlib docs,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
# getLogger() with no name returns the *root* logger, so setting INFO here also
# changes the effective level of every logger that does not set its own.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

T
tangwei 已提交
32

T
tangwei 已提交
33
class ClusterEngine(Engine):
    """Engine that submits / runs PaddleRec training on a cluster.

    The process role comes from the ``engine_role`` runtime setting:

    * ``MASTER`` -- validates the backend configuration and launches the
      backend's ``cluster.sh`` submit script.
    * ``WORKER`` -- builds the trainer via ``TrainerFactory`` and runs it.
    """

    def __init_impl__(self):
        """Resolve the role and, for non-worker processes, the submit backend.

        Raises:
            ValueError: if the ``backend`` runtime setting is neither
                ``PADDLECLOUD`` nor ``KUBERNETES`` (compared upper-cased).
        """
        self.role = envs.get_runtime_environ("engine_role")
        if self.role == "WORKER":
            # Workers only run the trainer; no submit backend is resolved.
            return

        abs_dir = os.path.dirname(os.path.abspath(__file__))
        # Exported to the environment inherited by the submit script.
        os.environ["abs_dir"] = str(abs_dir)

        # A missing backend is normalised to "" so it hits the error below.
        self.backend = (envs.get_runtime_environ("backend") or "").upper()
        if self.backend == "PADDLECLOUD":
            self.submit_script = os.path.join(abs_dir, "cloud/cluster.sh")
        elif self.backend == "KUBERNETES":
            self.submit_script = os.path.join(abs_dir, "k8s/cluster.sh")
        else:
            raise ValueError("{} can not be supported now".format(
                self.backend))

    def start_worker_procs(self):
        """Create the trainer named by ``self.trainer`` and run it in-process.

        NOTE(review): ``self.trainer`` is presumably set by the ``Engine``
        base class -- confirm.
        """
        trainer = TrainerFactory.create(self.trainer)
        trainer.run()

    def start_master_procs(self):
        """Validate the backend config, then run the submit script and wait."""
        if self.backend == "PADDLECLOUD":
            self.paddlecloud_env_check()
        elif self.backend == "KUBERNETES":
            self.kubernetes_env_check()

        # The submit script runs with the current environment minus proxies.
        current_env = os.environ.copy()
        current_env.pop("http_proxy", None)
        current_env.pop("https_proxy", None)

        # Pass argv as a list: splitting a formatted string on spaces would
        # break a submit_script path that contains a space.
        cmd = ["bash", self.submit_script]
        proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
        returncode = proc.wait()
        if returncode != 0:
            # Surface the failure instead of silently ignoring it.
            logger.warning("Submit script %s exited with code %s",
                           self.submit_script, returncode)

    @staticmethod
    def workspace_replace():
        """Replace ``{workspace}`` in every environment variable value with
        the ``remote_workspace`` runtime setting.
        """
        remote_workspace = envs.get_runtime_environ("remote_workspace")

        # Iterate over a snapshot so os.environ can be updated in the loop,
        # and only rewrite values that actually contain the placeholder.
        for key, value in list(os.environ.items()):
            if "{workspace}" in value:
                os.environ[key] = str(
                    value.replace("{workspace}", remote_workspace))

    def run(self):
        """Dispatch on the role resolved in ``__init_impl__``.

        Raises:
            ValueError: if the role is neither ``MASTER`` nor ``WORKER``.
        """
        if self.role == "MASTER":
            self.start_master_procs()

        elif self.role == "WORKER":
            self.start_worker_procs()

        else:
            raise ValueError("role {} error, must in MASTER/WORKER".format(
                self.role))

    def paddlecloud_env_check(self):
        """Pick the PaddleCloud env checker for this job, then run and apply it.

        The checker is chosen from the ``cluster_type``, ``fleet_mode`` and
        ``device`` runtime settings.

        Raises:
            ValueError: for an unknown ``cluster_type`` or any
                ``fleet_mode``/``device`` combination without a checker.
        """
        fleet_mode = envs.get_runtime_environ("fleet_mode")
        device = envs.get_runtime_environ("device")
        cluster_type = envs.get_runtime_environ("cluster_type")
        # Guard against a missing cluster_type so the ValueError below fires
        # instead of an AttributeError on None.upper().
        normalized_cluster_type = (cluster_type or "").upper()

        cluster_env_check_tool = None
        if normalized_cluster_type == "MPI":
            if device == "CPU" and fleet_mode == "PS":
                cluster_env_check_tool = PaddleCloudMpiEnv()
            else:
                raise ValueError(
                    "Paddlecloud with Mpi don't support GPU training, check your config.yaml & backend.yaml"
                )
        elif normalized_cluster_type == "K8S":
            if fleet_mode == "PS":
                if device == "CPU":
                    cluster_env_check_tool = CloudPsCpuEnv()
                elif device == "GPU":
                    raise ValueError(
                        "PS-GPU on paddlecloud is not supported at this time, comming soon"
                    )
            elif fleet_mode == "COLLECTIVE":
                if device == "GPU":
                    cluster_env_check_tool = CloudCollectiveEnv()
                elif device == "CPU":
                    raise ValueError(
                        "Unexpected config -> device: CPU with fleet_mode: Collective, check your config.yaml"
                    )
        else:
            raise ValueError("cluster_type {} error, must in MPI/K8S".format(
                cluster_type))

        # Any combination not matched above (e.g. an unknown fleet_mode or
        # device on K8S) used to crash later with None.env_check().
        if cluster_env_check_tool is None:
            raise ValueError(
                "Unsupported paddlecloud config -> cluster_type: {}, fleet_mode: {}, device: {}, check your config.yaml & backend.yaml".
                format(cluster_type, fleet_mode, device))

        cluster_env_check_tool.env_check()
        cluster_env_check_tool.env_set()

    def kubernetes_env_check(self):
        """Validate the KUBERNETES backend config; no checks exist yet."""
        pass


class ClusterEnvBase(object):
    """Collect and validate settings shared by every PaddleCloud job type.

    Settings are read from the YAML file named by the ``backend_yaml``
    runtime setting and flattened with ``"."`` so they can be looked up as
    e.g. ``"config.fs_name"``. ``env_check`` fills ``self.cluster_env``;
    ``env_set`` publishes it via ``envs.set_runtime_environs`` and logs it.
    """

    def __init__(self):
        backend_yaml = envs.get_runtime_environ("backend_yaml")
        _env = envs.load_yaml(backend_yaml)
        # Flattened backend.yaml, keyed like "config.fs_name".
        self.backend_env = envs.flatten_environs(_env, ".")
        # Collected settings, keyed by the name exported to the job.
        self.cluster_env = {}

    def env_check(self):
        """Validate the shared backend.yaml settings into ``cluster_env``.

        Raises:
            ValueError: if ``config.fs_name``/``config.fs_ugi``,
                ``submit.ak``/``submit.sk`` or ``submit.files`` is missing or
                empty. Emptiness is tested by truthiness, so a key that is
                present but null in the YAML is rejected too.
        """
        backend_env = self.backend_env
        cluster_env = self.cluster_env

        # fs_name & fs_ugi (required)
        cluster_env["FS_NAME"] = backend_env.get("config.fs_name", "")
        cluster_env["FS_UGI"] = backend_env.get("config.fs_ugi", "")
        if not cluster_env["FS_NAME"] or not cluster_env["FS_UGI"]:
            raise ValueError(
                "No -- FS_UGI or FS_NAME -- found in your backend.yaml, please check."
            )

        # output_path (optional: only warn when missing)
        cluster_env["OUTPUT_PATH"] = backend_env.get("config.output_path", "")
        if not cluster_env["OUTPUT_PATH"]:
            warnings.warn(
                "Job output_path not set! Please check your backend yaml.",
                category=UserWarning,
                stacklevel=2)

        # paddle_version
        cluster_env["PADDLE_VERSION"] = backend_env.get(
            "config.paddle_version", "1.7.2")

        # python_version
        cluster_env["USE_PYTHON3"] = backend_env.get("config.use_python3",
                                                     "0")

        # communicator: every flag is read from "config.communicator.<flag>";
        # queue/merge/grad sizes fall back to the runtime max_thread_num.
        max_thread_num = int(envs.get_runtime_environ("max_thread_num"))
        communicator_defaults = (
            ("FLAGS_communicator_is_sgd_optimizer", 0),
            ("FLAGS_communicator_send_queue_size", max_thread_num),
            ("FLAGS_communicator_thread_pool_size", 32),
            ("FLAGS_communicator_max_merge_var_num", max_thread_num),
            ("FLAGS_communicator_max_send_grad_num_before_recv",
             max_thread_num),
            ("FLAGS_communicator_fake_rpc", 0),
            ("FLAGS_rpc_retry_times", 3), )
        for flag, default in communicator_defaults:
            cluster_env[flag] = backend_env.get("config.communicator." + flag,
                                                default)

        # ak & sk (required)
        cluster_env["AK"] = backend_env.get("submit.ak", "")
        cluster_env["SK"] = backend_env.get("submit.sk", "")
        if not cluster_env["AK"] or not cluster_env["SK"]:
            raise ValueError(
                "No -- AK or SK -- found in your backend.yaml, please check.")

        # priority
        cluster_env["PRIORITY"] = backend_env.get("submit.priority", "high")

        # job name
        cluster_env["JOB_NAME"] = backend_env.get("submit.job_name",
                                                  "PaddleRecClusterJob")

        # group
        cluster_env["GROUP_NAME"] = backend_env.get("submit.group", "paddle")

        # start_cmd
        cluster_env["START_CMD"] = backend_env.get(
            "submit.start_cmd", "python -m paddlerec.run -m config.yaml")

        # files (required)
        cluster_env["FILES"] = backend_env.get("submit.files", "")
        if not cluster_env["FILES"]:
            raise ValueError(
                "No -- files -- found in your backend.yaml, please check.")

    def env_set(self):
        """Publish ``cluster_env`` as runtime environs and log a summary."""
        envs.set_runtime_environs(self.cluster_env)
        flattens = envs.flatten_environs(self.cluster_env)
        logger.info(
            envs.pretty_print_envs(flattens, ("Cluster Envs", "Value")))
C
Chengmo 已提交
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247


class PaddleCloudMpiEnv(ClusterEnvBase):
    """PaddleCloud settings for parameter-server CPU jobs on MPI clusters."""

    def __init__(self):
        super(PaddleCloudMpiEnv, self).__init__()

    def env_check(self):
        """Run the common checks, then record MPI-specific settings.

        Raises:
            ValueError: if ``config.train_data_path`` is missing or empty
                (including a key present but null in the YAML).
        """
        super(PaddleCloudMpiEnv, self).env_check()

        self.cluster_env["DISTRIBUTE_MODE"] = "PS_CPU_MPI"

        # train_data_path (required)
        self.cluster_env["TRAIN_DATA_PATH"] = self.backend_env.get(
            "config.train_data_path", "")
        if not self.cluster_env["TRAIN_DATA_PATH"]:
            raise ValueError(
                "No -- TRAIN_DATA_PATH -- found in your backend.yaml, please add train_data_path in your backend yaml."
            )

        # test_data_path (optional: only warn when missing)
        self.cluster_env["TEST_DATA_PATH"] = self.backend_env.get(
            "config.test_data_path", "")
        if not self.cluster_env["TEST_DATA_PATH"]:
            warnings.warn(
                "Job test_data_path not set! Please check your backend yaml.",
                category=UserWarning,
                stacklevel=2)

        # thirdparty_path (optional: only warn when missing)
        self.cluster_env["THIRDPARTY_PATH"] = self.backend_env.get(
            "config.thirdparty_path", "")
        if not self.cluster_env["THIRDPARTY_PATH"]:
            warnings.warn(
                "Job thirdparty_path not set! Please check your backend yaml.",
                category=UserWarning,
                stacklevel=2)

        # nodes
        self.cluster_env["MPI_NODES"] = self.backend_env.get("submit.nodes", 1)


class PaddleCloudK8sEnv(ClusterEnvBase):
    """Settings shared by PaddleCloud jobs that run on K8S clusters."""

    def __init__(self):
        super(PaddleCloudK8sEnv, self).__init__()

    def env_check(self):
        """Run the common checks, then record the AFS remote mount point.

        A missing or empty ``config.afs_remote_mount_point`` only warns.
        """
        super(PaddleCloudK8sEnv, self).env_check()

        self.cluster_env["AFS_REMOTE_MOUNT_POINT"] = self.backend_env.get(
            "config.afs_remote_mount_point", "")
        if not self.cluster_env["AFS_REMOTE_MOUNT_POINT"]:
            warnings.warn(
                "Job afs_remote_mount_point not set! Please check your backend yaml.",
                category=UserWarning,
                stacklevel=2)
        else:
            # Only announce the mount when a mount point was actually given;
            # announcing it unconditionally was misleading.
            warnings.warn(
                "The remote afs path will be mounted to the ./afs/",
                category=UserWarning,
                stacklevel=2)


class CloudCollectiveEnv(PaddleCloudK8sEnv):
    """PaddleCloud-on-K8S settings for collective (GPU) training jobs."""

    def __init__(self):
        super(CloudCollectiveEnv, self).__init__()

    def env_check(self):
        """Run the K8S checks, then record collective-job resource settings.

        Every ``submit.*`` key absent from backend.yaml falls back to 1.
        """
        super(CloudCollectiveEnv, self).env_check()

        env = self.cluster_env
        env["DISTRIBUTE_MODE"] = "COLLECTIVE_GPU_K8S"
        # (exported name, backend.yaml key, fallback), stored in this order.
        for env_name, yaml_key, fallback in (
            ("K8S_TRAINERS", "submit.k8s_trainers", 1),
            ("K8S_GPU_CARD", "submit.k8s_gpu_card", 1),
            ("K8S_CPU_CORES", "submit.k8s_cpu_cores", 1), ):
            env[env_name] = self.backend_env.get(yaml_key, fallback)
C
Chengmo 已提交
307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324


class CloudPsCpuEnv(PaddleCloudK8sEnv):
    """PaddleCloud-on-K8S settings for parameter-server CPU jobs."""

    def __init__(self):
        super(CloudPsCpuEnv, self).__init__()

    def env_check(self):
        """Run the K8S checks, then record trainer / pserver resources.

        Missing ``submit.*`` keys fall back to 1 trainer and 1 pserver, with
        2 cores each.
        """
        super(CloudPsCpuEnv, self).env_check()

        env = self.cluster_env
        env["DISTRIBUTE_MODE"] = "PS_CPU_K8S"
        # (exported name, backend.yaml key, fallback), stored in this order.
        resource_specs = (
            ("K8S_TRAINERS", "submit.k8s_trainers", 1),
            ("K8S_CPU_CORES", "submit.k8s_cpu_cores", 2),
            ("K8S_PS_NUM", "submit.k8s_ps_num", 1),
            ("K8S_PS_CORES", "submit.k8s_ps_cores", 2), )
        for env_name, yaml_key, fallback in resource_specs:
            env[env_name] = self.backend_env.get(yaml_key, fallback)