local_cluster.py 3.7 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import unicode_literals
import subprocess
import sys
import os
import copy

22 23
from paddlerec.core.engine.engine import Engine
from paddlerec.core.utils import envs
T
tangwei 已提交
24 25 26 27 28 29


class LocalClusterEngine(Engine):
    """Engine that runs a parameter-server cluster as local subprocesses.

    Reads ``worker_num``, ``server_num``, ``start_port`` and ``log_dir`` from
    ``self.envs`` and launches ``self.trainer`` through the
    ``paddlerec.core.factory`` module, once per server and once per worker,
    all bound to ``127.0.0.1``.
    """

    def start_procs(self):
        """Start every server and worker process, then wait for the workers.

        Each process has its stdout and stderr redirected to
        ``<log_dir>/server.<i>`` or ``<log_dir>/worker.<i>``. Only the
        workers are waited on; once they have all exited the servers are
        terminated. If anything fails along the way, every process that is
        still running is terminated and every log file is closed before the
        error propagates.

        Raises:
            OSError: (``IOError`` on Python 2 for file errors) if the log
                directory or a log file cannot be created, or a process
                cannot be started.
        """
        worker_num = self.envs["worker_num"]
        server_num = self.envs["server_num"]
        ports = [self.envs["start_port"]]
        logs_dir = self.envs["log_dir"]

        # os.environ.copy() already returns an independent dict, so the
        # children's environment can be edited without touching our own.
        current_env = os.environ.copy()
        current_env["CLUSTER_INSTANCE"] = "1"
        current_env.pop("http_proxy", None)
        current_env.pop("https_proxy", None)

        # The first server uses the configured start port; every additional
        # server gets a distinct free port.
        for _ in range(server_num - 1):
            while True:
                new_port = envs.find_free_port()
                if new_port not in ports:
                    ports.append(new_port)
                    break

        endpoints = ["127.0.0.1:" + str(port) for port in ports]
        user_endpoints = ",".join(endpoints)

        factory = "paddlerec.core.factory"
        cmd = [sys.executable, "-u", "-m", factory, self.trainer]

        # Create the log directory once, without going through a shell, so
        # paths containing spaces or shell metacharacters work and failures
        # are reported instead of being silently ignored.
        if not os.path.isdir(logs_dir):
            try:
                os.makedirs(logs_dir)
            except OSError:
                # Another process may have created it in the meantime.
                if not os.path.isdir(logs_dir):
                    raise

        procs = []
        log_fns = []

        def launch(log_name):
            # Register the log file before spawning so it is still closed
            # during cleanup if Popen itself fails.
            log_file = open("%s/%s" % (logs_dir, log_name), "w")
            log_fns.append(log_file)
            procs.append(
                subprocess.Popen(
                    cmd,
                    env=current_env,
                    stdout=log_file,
                    stderr=log_file,
                    cwd=os.getcwd()))

        try:
            for i in range(server_num):
                endpoint_parts = endpoints[i].split(":")
                current_env.update({
                    "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
                    "PADDLE_PORT": endpoint_parts[1],
                    "TRAINING_ROLE": "PSERVER",
                    "PADDLE_TRAINERS_NUM": str(worker_num),
                    "POD_IP": endpoint_parts[0]
                })
                launch("server.%d" % i)

            for i in range(worker_num):
                # NOTE: the same environment dict is reused, so workers also
                # see the PADDLE_PORT / POD_IP values left by the last server.
                current_env.update({
                    "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
                    "PADDLE_TRAINERS_NUM": str(worker_num),
                    "TRAINING_ROLE": "TRAINER",
                    "PADDLE_TRAINER_ID": str(i)
                })
                launch("worker.%d" % i)

            # Servers occupy the first `server_num` slots of `procs`; only
            # the workers after them are waited on.
            for worker in procs[server_num:]:
                worker.wait()
        finally:
            # On the normal path only the servers are still alive here; on an
            # error path this also stops any workers that were left running.
            for proc in procs:
                if proc.poll() is None:
                    proc.terminate()
            for log_file in log_fns:
                log_file.close()

        print("all workers already completed, you can view logs under the `{}` directory".format(logs_dir),
              file=sys.stderr)

    def run(self):
        """Run the engine by starting the local cluster processes."""
        self.start_procs()