# local_mpi.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import unicode_literals
import subprocess
import sys
import os
import copy

from paddlerec.core.engine.engine import Engine


class LocalMPIEngine(Engine):
    """Engine that runs the trainer on the local machine through ``mpirun``."""

    def start_procs(self):
        """Launch the trainer under ``mpirun`` and block until it exits.

        Reads ``self.envs["log_dir"]`` and ``self.trainer``. When the log
        directory is not ``None`` it is created if missing, and the job's
        stdout and stderr are both redirected to ``<log_dir>/job.log``;
        otherwise the child inherits this process's streams.

        Raises:
            OSError: if the log directory cannot be created, the log file
                cannot be opened, or ``mpirun`` cannot be started.
        """
        logs_dir = self.envs["log_dir"]

        # The job runs with a copy of the current environment from which the
        # proxy variables have been removed.
        current_env = os.environ.copy()
        current_env.pop("http_proxy", None)
        current_env.pop("https_proxy", None)

        factory = "paddlerec.core.factory"
        cmd = [
            "mpirun", "-npernode", "2", "-timestamp-output", "-tag-output",
            sys.executable, "-u", "-m", factory, self.trainer,
        ]

        log_file = None
        if logs_dir is not None:
            # Equivalent of `mkdir -p` without going through a shell: an
            # already-existing directory is fine, any other failure is not.
            try:
                os.makedirs(logs_dir)
            except OSError:
                if not os.path.isdir(logs_dir):
                    raise
            log_file = open(os.path.join(logs_dir, "job.log"), "w")

        try:
            # stdout/stderr of None means "inherit from this process".
            proc = subprocess.Popen(
                cmd,
                env=current_env,
                stdout=log_file,
                stderr=log_file,
                cwd=os.getcwd())
        finally:
            # The child holds its own copy of the descriptor, so the parent's
            # handle can be released right away, including when Popen fails.
            if log_file is not None:
                log_file.close()

        returncode = proc.wait()
        if returncode != 0:
            print(
                "mpirun exited with non-zero status {}".format(returncode),
                file=sys.stderr)
        print(
            "all workers and parameter servers already completed",
            file=sys.stderr)

    def run(self):
        """Run the engine by starting the MPI job and waiting for it."""
        self.start_procs()