# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import signal

from paddle.distributed.launch.job.job import Job
from paddle.distributed.launch.job.pod import Pod
from paddle.distributed.launch.job.container import Container

from .master import Master
from .watcher import Watcher

import time


class ControleMode:
    '''
    String constants naming the controller modes.
    '''

    COLLECTIVE = "collective"
    PS = "ps"
    IPU = "ipu"


class ControllerBase(object):
    '''
    Base controller: builds the master, watcher, job and pod from a launch
    context and drives the deploy / watch / stop / finalize life cycle.

    Subclasses are expected to provide build_job() and build_pod(), which
    run() calls before deploying the pod.
    '''

    def __init__(self, ctx):
        '''
        Create the controller from the launch context *ctx*.

        Installs signal_handler for SIGTERM, SIGABRT and SIGINT, then builds
        the master, watcher, job and pod and exports the pod name through
        the context environment as POD_NAME.
        '''
        # Bind the context before installing the handlers: signal_handler
        # reads self.ctx, so it must exist as soon as a signal can arrive.
        self.ctx = ctx

        signal.signal(signal.SIGTERM, self.signal_handler)
        signal.signal(signal.SIGABRT, self.signal_handler)
        signal.signal(signal.SIGINT, self.signal_handler)

        self.master = Master.factory(self.ctx)

        self.watcher = Watcher(self.ctx)

        self.job = Job(nnodes=self.ctx.args.nnodes,
                       mode=self.ctx.args.run_mode,
                       jid=self.ctx.args.job_id)
        self.pod = Pod()

        self.ctx.set_envs({"POD_NAME": self.pod.name})

        self.join_server = None

    def deploy_pod(self):
        '''
        Mark the context status as running and deploy the pod.

        Raises AssertionError if the pod has no container.
        '''
        # Raised explicitly (same type and message as the former assert) so
        # the check is not stripped when Python runs with -O.
        if len(self.pod.containers) == 0:
            raise AssertionError("No container in the pod")

        self.ctx.logger.info("Run {}".format(self.pod))
        self.ctx.logger.debug(self.pod.containers[0])

        self.ctx.status.run()
        self.pod.deploy()

    def run(self):
        '''
        Build the job and the pod, deploy the pod, then watch it.
        '''
        self.build_job()
        self.build_pod()

        self.deploy_pod()

        self.watch()

    def watch(self) -> bool:
        '''
        watch self and peer status, return true to exit

        Returns True when the pod completed, or when it failed and
        elastic_level <= 0. Returns False when the pod failed and
        elastic_level > 0, when the context status is restarting while the
        master status is not COMPLETED, or when the loop ends because the
        context status is done.
        '''
        #TODO(kuizhiqing) unify ctx.status and master status

        self.ctx.logger.info("Watching {}".format(self.pod))

        while not self.ctx.status.is_done():
            status = self.pod.watch(timeout=2)

            # default to print log (the ctx.continous_log() switch is
            # currently disabled)
            self.pod.logs()

            # completed
            if status == self.ctx.status.COMPLETED:
                self.ctx.status.complete()

                self.master.set_status(status)

                # keep calling logs() until it returns a falsy value
                while self.pod.logs():
                    pass

                self.ctx.logger.info("Pod {}".format(status))
                return True

            # self failure
            elif status == self.ctx.status.FAILED:
                self.ctx.status.fail()

                self.master.set_status(status)
                self.master.restart_peer()

                fc = self.pod.failed_container()
                self.ctx.logger.info("Pod {}".format(status))
                # Guard the lookup: an empty result must not turn the
                # failure path into an IndexError.
                if fc:
                    self.ctx.logger.error(
                        "Container failed !!!\n{}".format(fc[0]))
                    self.ctx.logger.info(
                        "------------------------- ERROR LOG DETAIL -------------------------"
                    )
                    fc[0].tail()
                else:
                    self.ctx.logger.error(
                        "Pod failed but no failed container was reported")

                if self.ctx.args.elastic_level <= 0:
                    self.pod.stop(timeout=3)
                    return True
                else:
                    self.pod.stop(timeout=30)
                    return False

            # peer failure
            if self.ctx.status.is_restarting(
            ) and self.master.get_status() != self.ctx.status.COMPLETED:
                self.pod.stop(timeout=30)
                return False

        # The loop ended because the context status is done. This used to
        # fall through and return None; return the equivalent falsy bool so
        # the result matches the annotation.
        return False

    def stop(self, sigint=None):
        '''
        Stop the watcher, the master and the pod (pod timeout: 30).

        *sigint* is accepted for the signal handler's call but is not used.
        '''
        self.ctx.logger.debug("Controller stop")

        self.watcher.stop()

        self.master.stop()
        self.pod.stop(timeout=30)

    def finalize(self):
        '''
        Join the pod, stop the master and exit with the pod's exit code.
        '''
        self.pod.join()
        self.master.stop()

        self.ctx.logger.info("Exit code {}".format(self.pod.exit_code))
        sys.exit(self.pod.exit_code)

    def signal_handler(self, sigint, frame):
        '''
        Handler installed for SIGTERM, SIGABRT and SIGINT.

        The first signal marks the context status as done, stops the
        controller and exits with the signal number. A signal received after
        one was already recorded stops the pod with a 10 second timeout and
        exits immediately.
        '''
        # self.sigint only exists once a previous signal has been handled
        if hasattr(self, 'sigint'):
            self.ctx.logger.info("Force quit in 10 seconds...")
            self.pod.stop(timeout=10)
            sys.exit(sigint)

        self.ctx.logger.info("Terminating with signal {}".format(sigint))

        self.sigint = sigint
        self.ctx.status.done()
        self.stop(sigint=sigint)
        self.ctx.logger.info("Exit with signal {}".format(sigint))
        sys.exit(sigint)


class Controller(ControllerBase):
    '''
    Controller API for customization
    '''

    def build_job(self):
        '''
        build job fill the job info.

        This default implementation only logs the job.
        '''
        self.ctx.logger.info(self.job)

    def build_pod(self) -> bool:
        '''
        build pod includes creating containers etc.

        Return True if succeed
        '''
        raise NotImplementedError

    def _get_entrypoint(self):
        '''
        Build the command line for the training script.

        A script ending in '.py' is run with the current interpreter in
        unbuffered mode (-u); anything else is executed directly. The
        training script arguments are appended in both cases.
        '''
        if self.ctx.args.training_script.endswith('.py'):
            entrypoint = [sys.executable, "-u", self.ctx.args.training_script]
        else:
            entrypoint = [self.ctx.args.training_script]

        entrypoint.extend(self.ctx.args.training_script_args)
        return entrypoint

    def _get_out_err_file(self, out=None, err=None):
        '''
        Resolve the stdout / stderr file names against the log directory.

        Returns a tuple (out, err); err falls back to out when it is not
        given. Names are returned unchanged when no log directory is set.
        '''
        log_dir = self.ctx.args.log_dir
        # Truthiness check (consistent with save_pod_log) so that a log_dir
        # of None is treated like "" instead of crashing in os.path.join.
        if log_dir:
            if out:
                out = os.path.join(log_dir, out)
            if err:
                err = os.path.join(log_dir, err)
        return out, (err or out)

    def new_container(self,
                      entrypoint=None,
                      envs=None,
                      use_ctx_env=True,
                      out=None,
                      err=None):
        '''
        Create a container without adding it to the pod.

        entrypoint: command to run, defaults to the training script command.
        envs: extra environment applied on top of the base environment.
        use_ctx_env: start from the context environment when True, from an
            empty one otherwise.
        out, err: log file names, resolved by _get_out_err_file.
        '''
        # None replaces the former shared mutable default ({}).
        if envs is None:
            envs = {}

        c = Container(
            entrypoint=(entrypoint or self._get_entrypoint()),
            env=(self.ctx.get_envs() if use_ctx_env else {}),
        )
        c.outfile, c.errfile = self._get_out_err_file(out, err)
        c.update_env(envs)
        return c

    def add_container(self,
                      container=None,
                      entrypoint=None,
                      envs=None,
                      log_file=None,
                      is_init=False):
        '''
        Add a container to the pod, creating it first when none is given.

        When *container* is falsy a new one is built from entrypoint / envs,
        with log_file used for both stdout and stderr. The container is
        registered as an init container when is_init is True.
        '''
        if not container:
            # Always hand a real dict to new_container, so overrides that
            # still expect the old {} default keep working.
            container = self.new_container(
                entrypoint=entrypoint,
                envs=({} if envs is None else envs),
                out=log_file,
                err=log_file)

        if is_init:
            self.pod.add_init_container(container)
        else:
            self.pod.add_container(container)

    def pod_replicas(self):
        '''
        how many process/container should be run in pod

        Taken from nproc_per_node when set, else from the number of
        comma-separated entries in devices, else from the node device count.
        '''

        if self.ctx.args.nproc_per_node:
            return int(self.ctx.args.nproc_per_node)
        elif self.ctx.args.devices:
            return len(self.ctx.args.devices.split(','))
        else:
            return self.ctx.node.device.count

    def save_pod_log(self, info):
        '''
        save_pod_log append *info* to the log file of pod.name

        Does nothing when no log directory is configured. Failures are
        logged and swallowed.
        '''
        if not self.ctx.args.log_dir:
            return

        f = os.path.join(self.ctx.args.log_dir,
                         '{}.{}.log'.format(self.job.id, self.pod.name))
        try:
            os.makedirs(os.path.dirname(f), exist_ok=True)
            with open(f, 'a+') as fd:
                fd.write(str(info))
        except Exception as e:
            # best effort: a log-write failure must not stop the controller
            self.ctx.logger.error("save log failed because {}".format(e))