# collective.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import os
import shutil
import paddle

from paddle.distributed.fleet.launch_utils import logger, pull_worker_log, start_local_trainers
from paddle.distributed.fleet.elastic.manager import LauncherInterface


class CollectiveLauncher(LauncherInterface):
    """Launcher that starts, watches and stops local collective trainers.

    The trainer processes are created by ``start_local_trainers`` and kept
    in ``self.procs``. A temporary directory is created per ``launch()``
    call and removed again by ``stop()``.
    """

    def __init__(self, args):
        """Store the launch arguments; nothing is started yet.

        Args:
            args: Parsed launch arguments. This class reads
                ``training_script``, ``training_script_args`` and
                ``log_dir`` from it and forwards the whole object to the
                ``paddle.distributed.fleet.launch`` helpers.
        """
        self.args = args
        self.procs = []
        # Only set by launch(); initialised here so stop() is safe to call
        # even when launch() never ran.
        self.tmp_dir = None

    def launch(self):
        """Start the local trainer processes and record them in ``self.procs``."""
        logger.info("collective lauchner launch ...")
        args = self.args
        self.tmp_dir = tempfile.mkdtemp()
        cluster, pod = paddle.distributed.fleet.launch.get_cluster_info(args)
        global_envs = paddle.distributed.fleet.launch.get_global_envs(
            args, self.tmp_dir)

        self.procs = start_local_trainers(
            cluster,
            pod,
            training_script=args.training_script,
            training_script_args=args.training_script_args,
            log_dir=args.log_dir,
            envs=global_envs)

        for idx, proc in enumerate(self.procs):
            logger.info("launch proc_id:{} idx:{}".format(proc.proc.pid, idx))

    def stop(self):
        """Terminate the trainer processes and remove the temporary directory.

        A failed termination is logged but does not prevent the directory
        cleanup. Calling this before ``launch()`` skips the cleanup instead
        of raising.
        """
        logger.info("collective lauchner stop ...")
        # NOTE(review): _terminate_procs is not defined in this class;
        # presumably inherited from LauncherInterface - confirm.
        if not self._terminate_procs():
            logger.error("kill process failed")
        # tmp_dir stays None until launch() has run; os.path.exists(None)
        # would raise TypeError, hence the explicit check.
        if self.tmp_dir is not None and os.path.exists(self.tmp_dir):
            shutil.rmtree(self.tmp_dir)

    def watch(self):
        """Pull the rank-0 worker log and return the process check result.

        Returns:
            Whatever ``self._check_procs()`` returns.
        """
        logger.debug("collective lauchner watch ...")
        for p in self.procs:
            # Only processes that have a log file and local rank 0 are pulled.
            if p.log_fn and p.local_rank == 0:
                pull_worker_log(p)
        # NOTE(review): _check_procs is not defined in this class;
        # presumably inherited from LauncherInterface - confirm.
        ret = self._check_procs()
        return ret