collective.py 2.0 KB
Newer Older
K
kuizhiqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
import tempfile
K
kuizhiqing 已提交
16 17 18 19 20 21 22

from paddle.distributed.fleet.launch_utils import *

from paddle.distributed.fleet.elastic.manager import LauncherInterface


class CollectiveLauncher(LauncherInterface):
    """Launcher that starts, watches and stops the local trainer processes.

    The launcher keeps the parsed launch arguments, the list of started
    trainer processes and the path of a per-launch temporary directory
    that is handed to ``get_global_envs``.

    NOTE(review): ``_terminate_procs`` and ``_check_procs`` are not defined
    in this class; they are expected to come from ``LauncherInterface``.
    """

    def __init__(self, args):
        """Remember the launch arguments and start with no running processes.

        Args:
            args: Launch arguments; ``training_script``,
                ``training_script_args`` and ``log_dir`` are read from it
                when ``launch`` is called.
        """
        self.args = args
        self.procs = []
        # Assigned by launch(). Initialised here so that stop() can be
        # called safely even if launch() never ran or failed early.
        self.tmp_dir = None

    def launch(self):
        """Create a temporary directory and start the local trainers.

        The cluster/pod description and the global environment are obtained
        from ``paddle.distributed.fleet.launch`` and passed on to
        ``start_local_trainers``. The resulting process list is stored in
        ``self.procs`` and the pid of every started process is logged.
        """
        logger.info("collective lauchner launch ...")
        args = self.args
        self.tmp_dir = tempfile.mkdtemp()
        cluster, pod = paddle.distributed.fleet.launch.get_cluster_info(args)
        global_envs = paddle.distributed.fleet.launch.get_global_envs(
            args, self.tmp_dir)

        self.procs = start_local_trainers(
            cluster,
            pod,
            training_script=args.training_script,
            training_script_args=args.training_script_args,
            log_dir=args.log_dir,
            envs=global_envs)

        for idx, proc in enumerate(self.procs):
            logger.info("launch proc_id:{} idx:{}".format(proc.proc.pid, idx))

    def stop(self):
        """Terminate the trainer processes and remove the temporary directory.

        An error is logged when ``_terminate_procs`` reports failure; the
        temporary directory is cleaned up either way. Calling ``stop``
        before ``launch`` (no temporary directory yet) is a no-op for the
        directory cleanup instead of raising ``AttributeError``.
        """
        logger.info("collective lauchner stop ...")
        if not self._terminate_procs():
            logger.error("kill process failed")

        # getattr keeps this safe for instances whose __init__ was bypassed
        # or overridden without defining tmp_dir.
        tmp_dir = getattr(self, "tmp_dir", None)
        if tmp_dir is not None and os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        # Forget the path so a later stop() does not look at a stale one.
        self.tmp_dir = None

    def watch(self):
        """Pull the rank-0 worker log and report the process status.

        For every started process that has a log file handle (``log_fn``)
        and whose ``local_rank`` is 0, ``pull_worker_log`` is called.

        Returns:
            Whatever ``self._check_procs()`` returns.
        """
        logger.debug("collective lauchner watch ...")
        for p in self.procs:
            if p.log_fn and p.local_rank == 0:
                pull_worker_log(p)
        ret = self._check_procs()
        return ret