# local_engine.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import unicode_literals
import subprocess
import sys
import os
import copy


def start_procs(args, yaml):
    """Spawn local parameter server and trainer processes as configured in args."""
    worker_num = args["worker_num"]
    server_num = args["server_num"]
    start_port = args["start_port"]
    logs_dir = args["log_dir"]

    default_env = os.environ.copy()
    current_env = copy.copy(default_env)
    current_env["CLUSTER_INSTANCE"] = "1"
    current_env.pop("http_proxy", None)
    current_env.pop("https_proxy", None)
    procs = []
    log_fns = []
    # one endpoint (127.0.0.1:<port>) per parameter server, on consecutive ports
    ports = range(start_port, start_port + server_num, 1)
    user_endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])
    user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
    user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]

    # every process runs the same trainer factory module on the given yaml config;
    # its role (server or trainer) is selected through environment variables
    factory = "fleetrec.trainer.factory"
    cmd = [sys.executable, "-u", "-m", factory, yaml]

    # launch the parameter server processes
    for i in range(server_num):
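        # PADDLE_PSERVERS_IP_PORT_LIST lists every server endpoint; PADDLE_PORT,
        # POD_IP and TRAINING_ROLE=PSERVER mark this process as the i-th server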
        current_env.update({
            "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
            "PADDLE_PORT": user_endpoints_port[i],
            "TRAINING_ROLE": "PSERVER",
            "PADDLE_TRAINERS_NUM": str(worker_num),
            "POD_IP": user_endpoints_ips[i]
        })

        if logs_dir is not None:
            os.system("mkdir -p {}".format(logs_dir))
            fn = open("%s/server.%d" % (logs_dir, i), "w")
            log_fns.append(fn)
            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
        else:
            proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
        procs.append(proc)

    # launch the trainer (worker) processes
    for i in range(worker_num):
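        # TRAINING_ROLE=TRAINER plus PADDLE_TRAINER_ID mark this process as the i-th trainer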
        current_env.update({
            "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
            "PADDLE_TRAINERS_NUM": str(worker_num),
            "TRAINING_ROLE": "TRAINER",
            "PADDLE_TRAINER_ID": str(i)
        })

        if logs_dir is not None:
            os.system("mkdir -p {}".format(logs_dir))
            fn = open("%s/worker.%d" % (logs_dir, i), "w")
            log_fns.append(fn)
            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
        else:
            proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
        procs.append(proc)

    # wait only for the workers to finish; parameter servers keep running until then
    for i, proc in enumerate(procs):
        if i < server_num:
            continue
        procs[i].wait()
        if len(log_fns) > 0:
            log_fns[i].close()

    print("all workers exit, going to finish parameter server", file=sys.stderr)
    for i in range(server_num):
        if len(log_fns) > 0:
            log_fns[i].close()
        procs[i].terminate()
    print("all parameter server are killed", file=sys.stderr)

class Launch(object):
    """Thin wrapper that starts a local cluster from parsed envs and a trainer yaml."""

    def __init__(self, envs, trainer):
        self.envs = envs
        self.trainer = trainer

    def run(self):
        start_procs(self.envs, self.trainer)
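

# A minimal usage sketch (added for illustration, not part of the original module).
# The dict keys match what start_procs() reads above; the yaml path and the port
# below are placeholder values.
if __name__ == "__main__":
    example_envs = {
        "worker_num": 2,      # number of trainer processes
        "server_num": 2,      # number of parameter server processes
        "start_port": 36001,  # first port; servers take consecutive ports from here
        "log_dir": "logs",    # set to None to inherit the parent's stdout/stderr
    }
    Launch(example_envs, "trainer.yaml").run()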