# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import logging import hashlib import json import os import six import time import collections from threading import Thread, current_thread from contextlib import contextmanager from paddle.fluid import unique_name, compiler from .checkpoint_saver import SerializableBase, CheckpointSaver, PaddleModel from paddle.fluid.framework import in_dygraph_mode, Program g_train_epoch_range = None g_checker = None logger = None generator = unique_name.UniqueNameGenerator() CONST_CHECKPOINT = "checkpoint" CONST_MEMORYINIT = "memory_init" # auto checkpoint by dataloader event. CONST_DACP_TYPE = "dacp" # auto checkpoint by loop range. 
CONST_ACP_TYPE = "acp"

# Active auto-checkpoint flavor for this process. None until set;
# train_epoch_range() sets it to CONST_ACP_TYPE.
g_acp_type = None
g_program_attr = {}  # program_name->can_be_auto_checkpoint


def _get_logger(log_level, name="auto_checkpoint"):
    """Return the module logger, creating and configuring it on first use.

    The logger is cached in the module-global ``logger``, so ``log_level``
    and ``name`` only take effect on the very first call.
    """
    global logger
    if logger is not None:
        return logger

    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    # Keep records away from the root logger's handlers.
    logger.propagate = False
    log_handler = logging.StreamHandler()
    log_format = logging.Formatter(
        '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
    log_handler.setFormatter(log_format)
    logger.addHandler(log_handler)
    return logger


def _thread_checker():
    """Assert that the caller runs on the main thread."""
    assert current_thread().name == "MainThread", \
        "auto checkpoint must run under main thread"


class AutoCheckpointChecker(object):
    """Auto-checkpoint configuration read from environment variables.

    Settings are only read when ``PADDLE_RUNNING_ENV`` equals
    ``"PADDLE_EDL_AUTO_CHECKPOINT"``; otherwise all fields stay ``None`` and
    :meth:`valid` returns ``False``. Inside that environment, a missing or
    malformed variable is logged as fatal and the process exits with code 1.
    """

    def __init__(self):
        self._run_env = None
        self._platform = None
        self._job_id = None
        self._hdfs_home = None
        self._hdfs_name = None
        self._hdfs_ugi = None
        self._hdfs_checkpoint_path = None
        self._trainer_id = None
        self._ce_test = None
        # Give these a value up front; they used to exist only when the env
        # below was fully configured, so reading them could raise
        # AttributeError.
        self._fs_cache = None
        self._save_checkpoint_inter = None

        self._run_env = os.getenv("PADDLE_RUNNING_ENV")
        if self._run_env != "PADDLE_EDL_AUTO_CHECKPOINT":
            return

        try:
            self._platform = os.environ["PADDLE_RUNNING_PLATFORM"]
            self._job_id = os.environ["PADDLE_JOB_ID"]
            self._hdfs_home = os.environ["PADDLE_EDL_HDFS_HOME"]
            self._hdfs_name = os.environ["PADDLE_EDL_HDFS_NAME"]
            self._hdfs_ugi = os.environ["PADDLE_EDL_HDFS_UGI"]
            self._hdfs_checkpoint_path = os.environ[
                "PADDLE_EDL_HDFS_CHECKPOINT_PATH"]
            self._trainer_id = int(os.environ["PADDLE_TRAINER_ID"])

            self._ce_test = int(os.getenv("PADDLE_EDL_ONLY_FOR_CE_TEST", "0"))
            self._fs_cache = os.getenv("PADDLE_EDL_FS_CACHE", ".cache")

            # Seconds between checkpoints (compared against time.time()).
            self._save_checkpoint_inter = int(
                os.getenv("PADDLE_EDL_SAVE_CHECKPOINT_INTER", "900"))

            if not self._ce_test:
                hdfs_ok = (len(self._hdfs_home) > 3 and
                           len(self._hdfs_name) > 6 and
                           len(self._hdfs_ugi) > 3 and
                           len(self._hdfs_checkpoint_path) > 0)
            else:
                # CE tests run against a local FS: name/ugi are not needed.
                hdfs_ok = (len(self._hdfs_home) > 3 and
                           len(self._hdfs_checkpoint_path) > 0)
            # Explicit raise rather than assert, which is stripped under -O.
            if not hdfs_ok:
                raise ValueError("hdfs environ must set")
        except Exception as e:
            # _get_logger() guarantees a logger even if none was set up yet.
            _get_logger(20).fatal("exception:{}".format(e))
            sys.exit(1)

    def get_range_checkpoint_path(self, name):
        """Return the checkpoint directory of the train_epoch_range ``name``."""
        return "{}/{}/range/{}".format(self.hdfs_checkpoint_path, self.job_id,
                                       name)

    def get_exe_checkpoint_path(self, name):
        """Return the checkpoint directory of the executor keyed by ``name``."""
        return "{}/{}/exe/{}".format(self.hdfs_checkpoint_path, self.job_id,
                                     name)

    def get_job_path(self):
        """Return the root checkpoint directory of this job."""
        return "{}/{}".format(self.hdfs_checkpoint_path, self.job_id)

    @property
    def save_checkpoint_inter(self):
        return self._save_checkpoint_inter

    def valid(self):
        """Return True if all required settings are present (never in dygraph)."""
        if in_dygraph_mode():
            return False

        required = (self._run_env, self._platform, self._job_id,
                    self._hdfs_home, self._hdfs_name, self._hdfs_ugi,
                    self._hdfs_checkpoint_path, self._trainer_id)
        return all(v is not None for v in required)

    def __str__(self):
        # Previously job_id was missing from the arguments, which shifted
        # every later value under the wrong label.
        return ("run_env:{} platform:{} job_id:{} "
                "hdfs_home:{} hdfs_name:{} hdfs_ugi:{} "
                "hdfs_checkpoint_path:{} trainer_id:{} ce_test:{}").format(
                    self._run_env, self._platform, self._job_id,
                    self._hdfs_home, self._hdfs_name, self._hdfs_ugi,
                    self._hdfs_checkpoint_path, self._trainer_id,
                    self._ce_test)

    @property
    def trainer_id(self):
        return self._trainer_id

    @property
    def run_env(self):
        return self._run_env

    @property
    def platform(self):
        return self._platform

    @property
    def job_id(self):
        return self._job_id

    @property
    def hdfs_home(self):
        return self._hdfs_home

    @property
    def hdfs_name(self):
        return self._hdfs_name

    @property
    def ce_test(self):
        return self._ce_test

    @property
    def hdfs_ugi(self):
        return self._hdfs_ugi

    @property
    def hdfs_checkpoint_path(self):
        return self._hdfs_checkpoint_path

    @staticmethod
    def generate_range_name():
        """Return a new unique name for a train_epoch_range."""
        return generator("_range_")


class ExeTrainStatus(SerializableBase):
    """Training status of one (executor, program) pair.

    It is serialized into the owning TrainEpochRange checkpoint. ``_exe``,
    ``_program`` and ``_restored_from`` are runtime-only and are not
    persisted.
    """

    def __init__(self):
        self._epoch_no = -1  # start epoch_no
        self._hash_key = None
        self._key = None
        self._checkpoint_path = None
        self._checkpoint_no = None
        self._restored_from = None
        self._exe = None
        self._program = None
        self._exe_name = None
        self._program_name = None

        self._file_name = "exe_train_status"

    def __eq__(self, t):
        return self._epoch_no == t._epoch_no and \
            self._hash_key == t._hash_key and \
            self._key == t._key and \
            self._checkpoint_path == t._checkpoint_path and \
            self._checkpoint_no == t._checkpoint_no and \
            self._exe_name == t._exe_name and \
            self._program_name == t._program_name

    def __ne__(self, t):
        return not self == t

    def serialize(self, path):
        """Write this status as JSON to ``<path>/exe_train_status``."""
        file_name = "{}/{}".format(path, self._file_name)
        with open(file_name, 'w') as f:
            s = self._serialize()
            f.write(s)

    def _serialize(self, pop_keys=("restored_from",)):
        """Return a JSON string of the status without the keys in ``pop_keys``."""
        d = self._to_dict()
        for k in pop_keys:
            d.pop(k, None)
        return json.dumps(d)

    def deserialize(self, path):
        """Load this status from ``<path>/exe_train_status``."""
        file_name = "{}/{}".format(path, self._file_name)
        with open(file_name, 'r') as f:
            s = f.read()
            self._deserialize(s)

    def _deserialize(self, s):
        d = json.loads(s)
        self._epoch_no = d["epoch_no"]
        self._key = d["key"]
        self._hash_key = d["hash_key"]
        self._checkpoint_path = d["checkpoint_path"]
        self._checkpoint_no = d["checkpoint_no"]
        self._exe_name = d["exe_name"]
        self._program_name = d["program_name"]

    def _to_dict(self):
        return {
            "epoch_no": self._epoch_no,
            "key": self._key,
            "hash_key": self._hash_key,
            "checkpoint_path": self._checkpoint_path,
            "restored_from": self._restored_from,
            "exe_name": self._exe_name,
            "program_name": self._program_name,
            "checkpoint_no": self._checkpoint_no
        }

    def __str__(self):
        return self._serialize([])


class TrainEpochRange(SerializableBase):
    """Checkpointable epoch loop that also tracks registered executors.

    On construction (when the checker is valid and ``restored`` is True) it
    looks up the last usable checkpoint under the range's checkpoint path and
    restores its epoch counter and executor statuses from it.
    """

    def __init__(self,
                 max_epoch_num,
                 name,
                 checkpoint_inter=None,
                 restored=True):
        self._max_epoch_num = max_epoch_num
        self._epoch_no = -1  # current epoch_no
        self._name = name
        self._restored_from = None
        self._exe_status = {}
        self._flag_generated = False

        self._checker = g_checker
        if checkpoint_inter is not None:
            self._save_checkpoint_inter = checkpoint_inter
        else:
            self._save_checkpoint_inter = self._checker.save_checkpoint_inter
        assert self._save_checkpoint_inter >= 0, "checkpointer:{} must >=0".format(
            self._save_checkpoint_inter)
        self._last_checkpoint_time = time.time()

        self._load_cp_nos = None
        self._checkpoint_epoch_no = None

        # Defaults for attributes that the early returns below would
        # otherwise leave undefined.
        self._file_name = "range_train_status"
        self._checkpoint_path = None
        self._hdfs = None
        self._cper = None

        if not self._checker.valid():
            return

        if not restored:
            return

        self._checkpoint_path = self._checker.get_range_checkpoint_path(name)

        config = {
            "fs.default.name": self._checker.hdfs_name,
            "hadoop.job.ugi": self._checker.hdfs_ugi
        }

        if self._checker.ce_test:
            config = None

        from paddle.distributed.fleet.utils.fs import HDFSClient
        self._hdfs = HDFSClient(self._checker.hdfs_home, config)

        self._cper = CheckpointSaver(self._hdfs)

        _thread_checker()

        self._get_last_valid_checkpoint()

    def _look_for_valid(self, cp_nos):
        """Find a checkpoint older than the newest one's epoch.

        Scan ``cp_nos`` newest-first. Return ``(range, checkpoint_no)`` for the
        first checkpoint whose epoch_no is strictly lower than the newest
        checkpoint's epoch_no, or ``(None, None)`` when there is none.
        """
        epoch_no = -1
        for i in cp_nos[::-1]:
            t = TrainEpochRange(self._max_epoch_num, self.name, restored=False)
            self._cper.load_checkpoint(
                self._checkpoint_path, [t],
                self._checker.trainer_id,
                checkpoint_no=i,
                local_cache_path=self._checker._fs_cache)
            logger.debug("look for valid:{} t:{}".format(i, t._serialize()))
            if epoch_no < 0:
                epoch_no = t._epoch_no
            elif epoch_no - t._epoch_no >= 1:
                return t, i
        return None, None

    def _get_last_valid_checkpoint(self):
        """Restore this range from the last usable checkpoint, if any."""
        self._load_cp_nos = self._cper.get_checkpoint_no(self._checkpoint_path)
        logger.info("find checkpoint nos:{}".format(self._load_cp_nos))

        # `not` also covers a None result, which len() would crash on.
        if not self._load_cp_nos:
            self._restored_from = CONST_MEMORYINIT
            return

        if g_acp_type == CONST_ACP_TYPE:
            # get the last one
            self._cper.load_checkpoint(
                self._checkpoint_path, [self],
                self._checker.trainer_id,
                local_cache_path=self._checker._fs_cache)
            self._restored_from = CONST_CHECKPOINT
            self._checkpoint_epoch_no = self._epoch_no

            logger.info("load tain_epoch_range checkpoint:{}".format(
                self._serialize()))

        elif g_acp_type == CONST_DACP_TYPE:
            t, i = self._look_for_valid(self._load_cp_nos)
            if t is None:
                self._restored_from = CONST_MEMORYINIT
                return

            self._cper.load_checkpoint(
                self._checkpoint_path, [self],
                self._checker.trainer_id,
                checkpoint_no=i,
                local_cache_path=self._checker._fs_cache)

            self._restored_from = CONST_CHECKPOINT
            self._checkpoint_epoch_no = self._epoch_no
            logger.info("load tain_epoch_range checkpoint:{}".format(
                self._serialize()))
        else:
            assert False, "not supported acp_type:{}".format(g_acp_type)

    def _to_dict(self):
        d = {
            "max_epoch_num": self._max_epoch_num,
            "epoch_no": self._epoch_no,
            "name": self._name,
            "checkpoint_path": self._checkpoint_path,
            "restored_from": self._restored_from,
            "checkpoint_epoch_no": self._checkpoint_epoch_no
        }
        return d

    def __str__(self):
        return self._serialize([])

    @property
    def name(self):
        return self._name

    def serialize(self, path):
        """Write this range as JSON to ``<path>/range_train_status``."""
        file_name = "{}/{}".format(path, self._file_name)
        with open(file_name, 'w') as f:
            s = self._serialize()
            f.write(s)

    def _serialize(self, pop_keys=("restored_from", "checkpoint_epoch_no")):
        """Return JSON of this range plus every registered executor status."""
        # self
        d = self._to_dict()
        for k in pop_keys:
            d.pop(k, None)

        # registered exes
        d["exe_status"] = {}
        e = d["exe_status"]
        for t in six.itervalues(self._exe_status):
            e[t._key] = t._serialize()
        return json.dumps(d)

    @property
    def restored_from(self):
        return self._restored_from

    def deserialize(self, path):
        """Load this range and its executor statuses from ``<path>``."""
        file_name = "{}/{}".format(path, self._file_name)
        with open(file_name, 'r') as f:
            d = json.load(f)

            # self
            self._max_epoch_num = d["max_epoch_num"]
            self._epoch_no = d["epoch_no"]
            self._name = d["name"]
            self._checkpoint_path = d["checkpoint_path"]

            # exes status
            e = d["exe_status"]
            for k, v in six.iteritems(e):
                t = ExeTrainStatus()
                t._deserialize(v)
                self._exe_status[k] = t

    def next(self):
        """Yield epoch numbers, resuming after the restored epoch.

        Iteration starts at the restored ``epoch_no + 1`` (0 when nothing was
        restored). A negative ``max_epoch_num`` means an unbounded loop. A
        checkpoint save is attempted after every yielded epoch.
        """
        _thread_checker()

        if self._max_epoch_num < 0:
            # sys.maxint does not exist on Python 3; maxsize works on both.
            self._max_epoch_num = sys.maxsize

        assert self._epoch_no >= -1, "self._epoch_no:{} must >=-1".format(
            self._epoch_no)

        self._last_checkpoint_time = time.time()
        start = self._epoch_no + 1
        logger.info("started epoch_no:{} max_epoch_num:{}".format(
            start, self._max_epoch_num))

        # six.moves.range is lazy on Python 2 too (no giant list for maxsize).
        for i in six.moves.range(start, self._max_epoch_num):
            self._epoch_no = i
            yield i

            self.save_checkpoint()

    def get(self):
        """Return the current epoch number."""
        return self._epoch_no

    def save_checkpoint(self):
        """Save a checkpoint from trainer 0 once the save interval elapsed."""
        # not save last one because exe and program can't be restored.
        if self._checker.trainer_id == 0:

            if time.time() - self._last_checkpoint_time >= \
                    self._save_checkpoint_inter:
                if g_acp_type == CONST_ACP_TYPE:
                    # not save the last one
                    if self._max_epoch_num > 0 and \
                            self._epoch_no != self._max_epoch_num - 1:
                        self._save_checkpoint()
                elif g_acp_type == CONST_DACP_TYPE:
                    self._save_checkpoint()
                else:
                    assert False, "not supported acp_type:{}".format(g_acp_type)
            self._last_checkpoint_time = time.time()

    def _save_checkpoint(self):
        """
        status => /jobid/xxx_range_xx/range/
        model =>                       /exe/
        """
        if not self._checker.valid():
            return

        for t in six.itervalues(self._exe_status):
            m = PaddleModel(t._exe, t._program)
            p = self._checker.get_exe_checkpoint_path(t._hash_key)
            t._epoch_no = self.get()
            path, checkpoint_no = self._cper.save_checkpoint(
                p, [m],
                self._checker.trainer_id,
                local_cache_path=self._checker._fs_cache)
            # index info
            t._checkpoint_path = path
            t._checkpoint_no = checkpoint_no

            logger.debug("save executor checkpoint:{}".format(t._serialize()))

        if len(self._exe_status) > 0:
            self._cper.save_checkpoint(
                self._checkpoint_path, [self],
                local_cache_path=self._checker._fs_cache)
            logger.info("save train_epoch_range checkpoint:{}".format(
                self._serialize()))

            self._generate_flag()

    def _generate_flag(self):
        """Touch the job-level flag file marking the job as checkpointable (once)."""
        if self._flag_generated:
            return

        name = "can_be_auto_checkpoint.flag"
        path = self._checker.get_job_path() + "/" + name
        logger.info("this job can_be_auto_checkpoint")
        self._hdfs.mkdirs(self._checker.get_job_path())
        self._hdfs.touch(path, exist_ok=True)

        self._flag_generated = True


def _get_train_epoch_range():
    """Return the active TrainEpochRange, or None outside train_epoch_range()."""
    return g_train_epoch_range


def _check_program_oprole(program):
    """Return True if the global block contains both backward and optimize ops."""
    has_backward = False
    has_opt = False
    for op in program.global_block().ops:
        if op._is_backward_op():
            has_backward = True

        if op._is_optimize_op():
            has_opt = True

        if has_backward and has_opt:
            return True

    return False


def _can_auto_checkpoint(prog):
    """Return True if ``prog`` should be checkpointed by the active range.

    Only non-distributed ``Program``/``CompiledProgram`` objects containing
    both backward and optimize ops qualify. The op check is cached per
    program name in ``g_program_attr``.
    """
    if not isinstance(prog, (compiler.CompiledProgram, Program)):
        return False

    if isinstance(prog, compiler.CompiledProgram):
        if prog._program is None or \
                prog._program._is_distributed:
            return False
    elif prog._is_distributed:
        return False

    program = _get_valid_program(prog)
    prog_name = program._auto_checkpoint_name

    if prog_name in g_program_attr:
        if not g_program_attr[prog_name]:
            return False
    else:
        # _get_valid_program already unwrapped any CompiledProgram.
        ret = _check_program_oprole(program)
        g_program_attr[prog_name] = ret
        if not ret:
            logger.debug("program {} need't to auto checkpoint".format(
                prog_name))
            return False

    return g_checker.valid() and g_train_epoch_range is not None


def _get_running_key(exe_name, program_name):
    """Return the key identifying one (executor, program) pair."""
    return "{}_{}".format(exe_name, program_name)


def _get_checker():
    """Return the AutoCheckpointChecker singleton, creating it on first use."""
    _get_logger(20)
    global g_checker
    if g_checker is None:
        g_checker = AutoCheckpointChecker()

    return g_checker


def _normal_yield(max_epoch_num):
    """Yield 0..max_epoch_num-1 (unbounded when negative) without checkpoints."""
    if max_epoch_num < 0:
        # sys.maxint does not exist on Python 3; maxsize works on both.
        max_epoch_num = sys.maxsize
    for i in six.moves.range(0, max_epoch_num):
        yield i

    return


def train_epoch_range(max_epoch_num, save_checkpoint_inter=None):
    """Iterate over epochs with automatic checkpoint/restore where available.

    Outside a valid auto-checkpoint environment, or when dataloader-driven
    checkpointing (CONST_DACP_TYPE) is already active, this is a plain
    ``range``. Otherwise it drives a TrainEpochRange that resumes from the
    last checkpoint and periodically saves new ones.

    Args:
        max_epoch_num: number of epochs; negative means unbounded.
        save_checkpoint_inter: seconds between checkpoints; None uses the
            checker's configured interval.
    """
    global g_acp_type
    if not _get_checker().valid():
        logger.warning(
            "auto checkpoint will take effect automaticly on PaddleCloud")
        for i in _normal_yield(max_epoch_num):
            yield i

        return

    if g_acp_type == CONST_DACP_TYPE:
        for i in _normal_yield(max_epoch_num):
            yield i

        return

    g_acp_type = CONST_ACP_TYPE
    logger.info("acp_type:{}".format(g_acp_type))

    global g_train_epoch_range
    try:
        g_train_epoch_range = TrainEpochRange(
            max_epoch_num,
            g_checker.generate_range_name(),
            checkpoint_inter=save_checkpoint_inter)

        for i in g_train_epoch_range.next():
            yield i
    finally:
        g_train_epoch_range = None


def _get_valid_program(prog):
    """Unwrap a CompiledProgram to its underlying Program."""
    if isinstance(prog, compiler.CompiledProgram):
        return prog._program

    return prog


def _auto_checkpoint(exe, prog):
    """Register ``exe``/``prog`` with the active range and restore them if possible.

    If the range holds a saved status for this (executor, program) pair that
    has not been restored yet, its model checkpoint is loaded into ``exe``.
    Otherwise a fresh status is registered.
    """
    _get_checker()

    assert exe._auto_checkpoint_name is not None
    if not _can_auto_checkpoint(prog):
        return

    program = _get_valid_program(prog)
    assert program._auto_checkpoint_name is not None

    exe_status = g_train_epoch_range._exe_status
    key = _get_running_key(exe._auto_checkpoint_name,
                           program._auto_checkpoint_name)

    if g_train_epoch_range.restored_from == CONST_CHECKPOINT:
        assert key in exe_status, "when restored key:{} must be in train_epoch_range:{}".format(
            key, g_train_epoch_range)

    if key in exe_status:
        t = exe_status[key]
        if t._restored_from is None:
            a = CheckpointSaver(g_train_epoch_range._hdfs)
            m = PaddleModel(exe, program)
            a.load_checkpoint(
                g_checker.get_exe_checkpoint_path(key), [m],
                trainer_id=g_checker.trainer_id,
                checkpoint_no=t._checkpoint_no,
                local_cache_path=g_checker._fs_cache)
            t._restored_from = CONST_CHECKPOINT
            logger.info("load executor checkpoint {}".format(t))
        t._exe = exe
        t._program = program
        t._epoch_no = g_train_epoch_range.get()
    else:
        t = ExeTrainStatus()
        t._epoch_no = g_train_epoch_range.get()
        t._hash_key = key
        t._key = key
        t._restored_from = CONST_MEMORYINIT
        t._exe = exe
        t._program = program
        t._exe_name = exe._auto_checkpoint_name
        t._program_name = program._auto_checkpoint_name

        # register this <exe,program,io>
        exe_status[key] = t

        logger.info("not found checkpoint, so train from epoch 0")

    _thread_checker()