未验证 提交 0067a2e4 编写于 作者: G gongweibao 提交者: GitHub

Save checkpoint automatically (#25917)

上级 e853ece0
......@@ -25,11 +25,13 @@ import six
from .data_feeder import convert_dtype
from .framework import Program, default_main_program, Variable, Operator, convert_np_dtype_to_dtype_
from . import core
from . import unique_name
from . import compiler
from .. import compat as cpt
from .trainer_factory import TrainerFactory
from .trainer_factory import FetchHandlerMonitor
import copy
from .incubate.checkpoint import auto_checkpoint as acp
__all__ = ['Executor', 'global_scope', 'scope_guard']
......@@ -559,6 +561,9 @@ class Executor(object):
self._closed = False
self.pruned_program_scope_caches = dict()
self._auto_checkpoint_name = unique_name.generate(
"__auto_checkpoint_executor__")
def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)
......@@ -1152,6 +1157,8 @@ class Executor(object):
compiled = isinstance(program, compiler.CompiledProgram)
acp._auto_checkpoint(self, program)
# For backward compatibility, run directly.
if not compiled:
# In distributed training, the compiled program is saved in Program._graph
......
......@@ -2385,12 +2385,29 @@ class Operator(object):
def _is_optimize_op(self):
op_maker = core.op_proto_and_checker_maker
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
if not self.desc.has_attr(op_maker.kOpRoleAttrName()):
return False
op_role = self.desc.attr(op_maker.kOpRoleAttrName())
if op_role & int(OPTIMIZE):
return True
else:
return False
def _is_backward_op(self):
op_maker = core.op_proto_and_checker_maker
BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
if not self.desc.has_attr(op_maker.kOpRoleAttrName()):
return False
op_role = self.desc.attr(op_maker.kOpRoleAttrName())
if op_role & int(BACKWARD):
return True
return False
class Block(object):
"""
......@@ -3942,6 +3959,10 @@ class Program(object):
# appending gradients times
self._appending_grad_times = 0
# identifier for auto checkpoint
self._auto_checkpoint_name = unique_name.generate(
"__auto_checkpoint_program__")
# compiled program, i.e. Graph
self._graph = None
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import logging
import hashlib
import json
import os
import six
import time
import collections
from threading import Thread, current_thread
from contextlib import contextmanager
from paddle.fluid import unique_name, compiler
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from .checkpoint_saver import SerializableBase, CheckpointSaver, PaddleModel
from paddle.fluid.framework import in_dygraph_mode, Program
g_train_epoch_range = None
g_checker = None
logger = None
generator = unique_name.UniqueNameGenerator()
CONST_CHECKPOINT = "checkpoint"
CONST_MEMORYINIT = "memory_init"
# auto checkpoint by dataloader event.
CONST_DACP_TYPE = "dacp"
# auto checkpoint by loop range.
CONST_ACP_TYPE = "acp"
g_acp_type = None
g_program_attr = {} # program_name->can_be_auto_checkpoint
def _get_logger(log_level, name="auto_checkpoint"):
global logger
if logger != None:
return logger
logger = logging.getLogger(name)
logger.setLevel(log_level)
logger.propagate = False
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
return logger
def _thread_checker():
assert current_thread().name == "MainThread", \
"auto checkpoint must run under main thread"
class AutoCheckpointChecker(object):
def __init__(self):
self._run_env = None
self._platform = None
self._job_id = None
self._hdfs_home = None
self._hdfs_name = None
self._hdfs_ugi = None
self._hdfs_checkpoint_path = None
self._trainer_id = None
self._ce_test = None
self._run_env = os.getenv("PADDLE_RUNNING_ENV")
if self._run_env != "PADDLE_EDL_AUTO_CHECKPOINT":
return
try:
self._platform = os.environ["PADDLE_RUNNING_PLATFORM"]
self._job_id = os.environ["PADDLE_JOB_ID"]
self._hdfs_home = os.environ["PADDLE_EDL_HDFS_HOME"]
self._hdfs_name = os.environ["PADDLE_EDL_HDFS_NAME"]
self._hdfs_ugi = os.environ["PADDLE_EDL_HDFS_UGI"]
self._hdfs_checkpoint_path = os.environ[
"PADDLE_EDL_HDFS_CHECKPOINT_PATH"]
self._trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
self._ce_test = int(os.getenv("PADDLE_EDL_ONLY_FOR_CE_TEST", "0"))
self._fs_cache = os.getenv("PADDLE_EDL_FS_CACHE", ".cache")
self._save_checkpoint_inter = int(
os.getenv("PADDLE_EDL_SAVE_CHECKPOINT_INTER", "900")) #s
if not self._ce_test:
assert len(self._hdfs_home) > 3 and \
len(self._hdfs_name) > 6 and \
len(self._hdfs_ugi) > 3 and \
len(self._hdfs_checkpoint_path) > 0, "hdfs environ must set"
else:
assert len(self._hdfs_home) > 3 and \
len(self._hdfs_checkpoint_path) > 0, "hdfs environ must set"
except Exception as e:
logger.fatal("exception:{}".format(e))
sys.exit(1)
def get_range_checkpoint_path(self, name):
return "{}/{}/range/{}".format(self.hdfs_checkpoint_path, self.job_id,
name)
def get_exe_checkpoint_path(self, name):
return "{}/{}/exe/{}".format(self.hdfs_checkpoint_path, self.job_id,
name)
def get_job_path(self):
return "{}/{}".format(self.hdfs_checkpoint_path, self.job_id)
@property
def save_checkpoint_inter(self):
return self._save_checkpoint_inter
def valid(self):
if in_dygraph_mode():
return False
return self._run_env is not None and \
self._platform is not None and \
self._job_id is not None and \
self._hdfs_home is not None and \
self._hdfs_name is not None and \
self._hdfs_ugi is not None and \
self._hdfs_checkpoint_path is not None and \
self._trainer_id is not None
def __str__(self):
return "run_env:{} platform:{} job_id:{} \
hdfs_home:{} hdfs_name:{} hdfs_ugi:{} \
hdfs_checkpoint_path:{} trainer_id:{} ce_test".format(
self._run_env, self._platform, self._hdfs_home, self._hdfs_name,
self._hdfs_ugi, self._hdfs_checkpoint_path, self._trainer_id,
self._ce_test)
@property
def trainer_id(self):
return self._trainer_id
@property
def run_env(self):
return self._run_env
@property
def platform(self):
return self._platform
@property
def job_id(self):
return self._job_id
@property
def hdfs_home(self):
return self._hdfs_home
@property
def hdfs_name(self):
return self._hdfs_name
@property
def ce_test(self):
return self._ce_test
@property
def hdfs_ugi(self):
return self._hdfs_ugi
@property
def hdfs_checkpoint_path(self):
return self._hdfs_checkpoint_path
@staticmethod
def generate_range_name():
return generator("_range_")
class ExeTrainStatus(SerializableBase):
def __init__(self):
self._epoch_no = -1 # start epoch_no
self._hash_key = None
self._key = None
self._checkpoint_path = None
self._checkpoint_no = None
self._restored_from = None
self._exe = None
self._program = None
self._exe_name = None
self._program_name = None
self._file_name = "exe_train_status"
def __eq__(self, t):
return self._epoch_no == t._epoch_no and \
self._hash_key == t._hash_key and \
self._key == t._key and \
self._checkpoint_path == t._checkpoint_path and \
self._checkpoint_no == t._checkpoint_no and \
self._exe_name == t._exe_name and \
self._program_name == t._program_name
def __ne__(self, t):
return not self == t
def serialize(self, path):
file_name = "{}/{}".format(path, self._file_name)
with open(file_name, 'w') as f:
s = self._serialize()
f.write(s)
def _serialize(self, pop_keys=["restored_from"]):
d = self._to_dict()
for k in pop_keys:
d.pop(k, None)
return json.dumps(d)
def deserialize(self, path):
d = None
file_name = "{}/{}".format(path, self._file_name)
with open(file_name, 'r') as f:
s = f.read()
self._deserialize(s)
def _deserialize(self, s):
d = json.loads(s)
self._epoch_no = d["epoch_no"]
self._key = d["key"]
self._hash_key = d["hash_key"]
self._checkpoint_path = d["checkpoint_path"]
self._checkpoint_no = d["checkpoint_no"]
self._exe_name = d["exe_name"]
self._program_name = d["program_name"]
def _to_dict(self):
return {
"epoch_no": self._epoch_no,
"key": self._key,
"hash_key": self._hash_key,
"checkpoint_path": self._checkpoint_path,
"restored_from": self._restored_from,
"exe_name": self._exe_name,
"program_name": self._program_name,
"checkpoint_no": self._checkpoint_no
}
def __str__(self):
return self._serialize([])
class TrainEpochRange(SerializableBase):
def __init__(self,
max_epoch_num,
name,
checkpoint_inter=None,
restored=True):
self._max_epoch_num = max_epoch_num
self._epoch_no = -1 # current epoch_no
self._name = name
self._restored_from = None
self._exe_status = {}
self._flag_generated = False
self._checker = g_checker
if checkpoint_inter is not None:
self._save_checkpoint_inter = checkpoint_inter
else:
self._save_checkpoint_inter = self._checker.save_checkpoint_inter
assert self._save_checkpoint_inter >= 0, "checkpointer:{} must >=0".format(
self._save_checkpoint_inter)
self._last_checkpoint_time = time.time()
self._load_cp_nos = None
self._checkpoint_epoch_no = None
if not self._checker.valid():
return
self._file_name = "range_train_status"
if not restored:
return
self._checkpoint_path = self._checker.get_range_checkpoint_path(name)
config = {
"fs.default.name": self._checker.hdfs_name,
"hadoop.job.ugi": self._checker.hdfs_ugi
}
if self._checker.ce_test:
config = None
self._hdfs = HDFSClient(self._checker.hdfs_home, config)
self._cper = CheckpointSaver(self._hdfs)
_thread_checker()
self._get_last_valid_checkpoint()
def _look_for_valid(self, cp_nos):
cps = []
epoch_no = -1
for i in cp_nos[::-1]:
t = TrainEpochRange(self._max_epoch_num, self.name, restored=False)
self._cper.load_checkpoint(
self._checkpoint_path, [t],
self._checker.trainer_id,
checkpoint_no=i,
local_cache_path=self._checker._fs_cache)
cps.append(t)
logger.debug("look for valid:{} t:{}".format(i, t._serialize()))
if epoch_no < 0:
epoch_no = t._epoch_no
else:
if epoch_no - t._epoch_no >= 1:
return t, i
return None, None
def _get_last_valid_checkpoint(self):
self._load_cp_nos = self._cper.get_checkpoint_no(self._checkpoint_path)
logger.info("find checkpoint nos:{}".format(self._load_cp_nos))
if len(self._load_cp_nos) < 1:
self._restored_from = CONST_MEMORYINIT
return
if g_acp_type == CONST_ACP_TYPE:
# get the last one
self._cper.load_checkpoint(
self._checkpoint_path, [self],
self._checker.trainer_id,
local_cache_path=self._checker._fs_cache)
self._restored_from = CONST_CHECKPOINT
self._checkpoint_epoch_no = self._epoch_no
logger.info("load tain_epoch_range checkpoint:{}".format(
self._serialize()))
elif g_acp_type == CONST_DACP_TYPE:
t, i = self._look_for_valid(self._load_cp_nos)
if t is None:
self._restored_from = CONST_MEMORYINIT
return
self._cper.load_checkpoint(
self._checkpoint_path, [self],
self._checker.trainer_id,
checkpoint_no=i,
local_cache_path=self._checker._fs_cache)
self._restored_from = CONST_CHECKPOINT
self._checkpoint_epoch_no = self._epoch_no
logger.info("load tain_epoch_range checkpoint:{}".format(
self._serialize()))
else:
assert False, "not supported acp_type:{}".format(g_acp_type)
def _to_dict(self):
d = {
"max_epoch_num": self._max_epoch_num,
"epoch_no": self._epoch_no,
"name": self._name,
"checkpoint_path": self._checkpoint_path,
"restored_from": self._restored_from,
"checkpoint_epoch_no": self._checkpoint_epoch_no
}
return d
def __str__(self):
return self._serialize([])
@property
def name(self):
return self._name
def serialize(self, path):
file_name = "{}/{}".format(path, self._file_name)
with open(file_name, 'w') as f:
s = self._serialize()
f.write(s)
def _serialize(self, pop_keys=["restored_from", "checkpoint_epoch_no"]):
# self
d = self._to_dict()
for k in pop_keys:
d.pop(k, None)
# registerd exes
d["exe_status"] = {}
e = d["exe_status"]
for k, t in six.iteritems(self._exe_status):
e[t._key] = t._serialize()
return json.dumps(d)
@property
def restored_from(self):
return self._restored_from
def deserialize(self, path):
d = None
file_name = "{}/{}".format(path, self._file_name)
with open(file_name, 'r') as f:
d = json.load(f)
# self
self._max_epoch_num = d["max_epoch_num"]
self._epoch_no = d["epoch_no"]
self._name = d["name"]
self._checkpoint_path = d["checkpoint_path"]
# exes status
e = d["exe_status"]
for k, v in six.iteritems(e):
t = ExeTrainStatus()
t._deserialize(v)
self._exe_status[k] = t
def next(self):
_thread_checker()
if self._max_epoch_num < 0:
self._max_epoch_num = sys.maxint
assert self._epoch_no >= -1, "self._epoch_no:{} must >=-1".format(
self._epoch_no)
self._last_checkpoint_time = time.time()
start = self._epoch_no + 1
logger.info("started epoch_no:{} max_epoch_num:{}".format(
start, self._max_epoch_num))
for i in range(start, self._max_epoch_num):
self._epoch_no = i
yield i
self.save_checkpoint()
def get(self):
return self._epoch_no
def save_checkpoint(self):
# not save last one because exe and program can't be restored.
if self._checker.trainer_id == 0:
if time.time() - self._last_checkpoint_time >= \
self._save_checkpoint_inter:
if g_acp_type == CONST_ACP_TYPE:
# not save the last one
if self._max_epoch_num > 0 and self._epoch_no != self._max_epoch_num - 1:
self._save_checkpoint()
elif g_acp_type == CONST_DACP_TYPE:
self._save_checkpoint()
else:
assert False, "not supported acp_type:{}".format(g_acp_type)
self._last_checkpoint_time = time.time()
def _save_checkpoint(self):
"""
status => /jobid/xxx_range_xx/range/
model => /exe/
"""
if not self._checker.valid():
return
e = self._exe_status
for k, t in six.iteritems(self._exe_status):
m = PaddleModel(t._exe, t._program)
p = self._checker.get_exe_checkpoint_path(t._hash_key)
t._epoch_no = self.get()
path, checkpoint_no = self._cper.save_checkpoint(
p, [m],
self._checker.trainer_id,
local_cache_path=self._checker._fs_cache)
# index info
t._checkpoint_path = path
t._checkpoint_no = checkpoint_no
e[t._key] = t
logger.debug("save executor checkpoint:{}".format(t._serialize()))
if len(self._exe_status) > 0:
self._cper.save_checkpoint(
self._checkpoint_path, [self],
local_cache_path=self._checker._fs_cache)
logger.info("save train_epoch_range checkpoint:{}".format(
self._serialize()))
self._generate_flag()
def _generate_flag(self):
if self._flag_generated:
return
name = "can_be_auto_checkpoint.flag"
path = self._checker.get_job_path() + "/" + name
logger.info("this job can_be_auto_checkpoint")
self._hdfs.mkdirs(self._checker.get_job_path())
self._hdfs.touch(path, exist_ok=True)
self._flag_generated = True
def _get_train_epoch_range():
return g_train_epoch_range
def _check_program_oprole(program):
global_block = program.global_block()
has_backward = False
has_opt = False
for idx, op in enumerate(global_block.ops):
if op._is_backward_op():
has_backward = True
if op._is_optimize_op():
has_opt = True
if has_backward and has_opt:
return True
return False
def _can_auto_checkpoint(prog):
if not isinstance(prog, compiler.CompiledProgram) and \
not isinstance(prog, Program):
return False
if isinstance(prog, compiler.CompiledProgram):
if prog._program is None or \
prog._program._is_distributed:
return False
else:
if prog._is_distributed:
return False
program = _get_valid_program(prog)
if program._auto_checkpoint_name in g_program_attr:
if not g_program_attr[program._auto_checkpoint_name]:
return False
else:
ret = False
if isinstance(program, compiler.CompiledProgram):
ret = _check_program_oprole(program._program)
else:
ret = _check_program_oprole(program)
g_program_attr[program._auto_checkpoint_name] = ret
if not ret:
logger.debug("program {} need't to auto checkpoint".format(
program._auto_checkpoint_name))
return False
return g_checker.valid() and g_train_epoch_range is not None
def _get_running_key(exe_name, program_name):
return "{}_{}".format(exe_name, program_name)
def _get_checker():
_get_logger(20)
global g_checker
if g_checker is None:
g_checker = AutoCheckpointChecker()
return g_checker
def _normal_yield(max_epoch_num):
if max_epoch_num < 0:
max_epoch_num = sys.maxint
for i in range(0, max_epoch_num):
yield i
return
def train_epoch_range(max_epoch_num, save_checkpoint_inter=None):
global g_acp_type
if not _get_checker().valid():
logger.warning(
"auto checkpoint will take effect automaticly on PaddleCloud")
for i in _normal_yield(max_epoch_num):
yield i
return
if g_acp_type == CONST_DACP_TYPE:
for i in _normal_yield(max_epoch_num):
yield i
return
g_acp_type = CONST_ACP_TYPE
logger.info("acp_type:{}".format(g_acp_type))
global g_train_epoch_range
try:
g_train_epoch_range = TrainEpochRange(
max_epoch_num,
g_checker.generate_range_name(),
checkpoint_inter=save_checkpoint_inter)
for i in g_train_epoch_range.next():
yield i
finally:
g_train_epoch_range = None
def _get_valid_program(prog):
if isinstance(prog, compiler.CompiledProgram):
return prog._program
return prog
def _auto_checkpoint(exe, prog):
_get_checker()
assert exe._auto_checkpoint_name != None
if not _can_auto_checkpoint(prog):
return
program = _get_valid_program(prog)
assert program._auto_checkpoint_name != None
exe_status = g_train_epoch_range._exe_status
key = _get_running_key(exe._auto_checkpoint_name,
program._auto_checkpoint_name)
if g_train_epoch_range.restored_from == CONST_CHECKPOINT:
assert key in exe_status, "when restored key:{} must be in train_epoch_range:{}".format(
key, g_train_epoch_range)
t = None
if key in exe_status:
t = exe_status[key]
if t._restored_from is None:
a = CheckpointSaver(g_train_epoch_range._hdfs)
m = PaddleModel(exe, program)
a.load_checkpoint(
g_checker.get_exe_checkpoint_path(key), [m],
trainer_id=g_checker.trainer_id,
checkpoint_no=t._checkpoint_no,
local_cache_path=g_checker._fs_cache)
t._restored_from = CONST_CHECKPOINT
logger.info("load executor checkpoint {}".format(t))
t._exe = exe
t._program = program
t._epoch_no = g_train_epoch_range.get()
else:
t = ExeTrainStatus()
t._epoch_no = g_train_epoch_range.get()
t._hash_key = key
t._key = key
t._restored_from = CONST_MEMORYINIT
t._exe = exe
t._program = program
t._exe_name = exe._auto_checkpoint_name
t._program_name = program._auto_checkpoint_name
# register this <exe,program,io>
exe_status[key] = t
logger.info("not found checkpoint, so train from epoch 0")
_thread_checker()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..fleet.utils.fs import FS, LocalFS
from ..fleet.utils.hdfs import HDFSClient
from ...compiler import CompiledProgram
class SerializableBase(object):
def serialize(self, path):
raise NotImplementedError
def deserialize(self, path):
raise NotImplementedError
class PaddleModel(SerializableBase):
def __init__(self, exe, program):
self._exe = exe
self._origin_program = program
self._program = program
if isinstance(program, CompiledProgram):
self._program = program._program
self._file_name = "_paddle_fleet_param__"
def serialize(self, path):
from ...io import save_persistables
save_persistables(
executor=self._exe,
dirname=path,
main_program=self._program,
filename=self._file_name)
def deserialize(self, path):
from ...io import load_persistables
load_persistables(
executor=self._exe,
dirname=path,
main_program=self._program,
filename=self._file_name)
class CheckpointSaver(object):
def __init__(self, fs):
self._fs = fs
self._checkpoint_prefix = "__paddle_checkpoint__"
def save_checkpoint(self,
path,
slists,
trainer_id=None,
local_cache_path=".cache"):
"""
Serialize objects in slists to path
Return really saved path and checkpoint_no
"""
if not self._fs.is_exist(path):
self._fs.mkdirs(path)
else:
assert self._fs.is_dir(path), "path:{} must be a directory".format(
path)
max_no = self._get_last_checkpoint_no(path)
if max_no < 0:
max_no = -1
max_no += 1
real_path = "{}/{}.{}".format(path, self._checkpoint_prefix, max_no)
tmp_path = "{}.tmp".format(real_path)
saved_path = tmp_path
local_fs = LocalFS()
cache_path = None
if self._fs.need_upload_download():
cache_path = "{}/{}.{}.saved_cache".format(
local_cache_path, self._checkpoint_prefix, max_no)
if trainer_id is not None:
cache_path = "{}.{}".format(cache_path, trainer_id)
if not local_fs.is_exist(cache_path):
local_fs.mkdirs(cache_path)
else:
assert local_fs.is_dir(cache_path), \
"cache path:{} must be a directory".format(cache_path)
saved_path = cache_path
for s in slists:
s.serialize(saved_path)
if self._fs.need_upload_download():
self._fs.delete(tmp_path)
self._fs.upload(cache_path, tmp_path)
local_fs.delete(cache_path)
self._fs.mv(tmp_path, real_path)
return real_path, max_no
def load_checkpoint(self,
path,
slists,
trainer_id,
local_cache_path=".cache",
checkpoint_no=None,
ignore_empty=True):
"""
Deserialize objects in slists from path
Return really load path
"""
if checkpoint_no is None:
max_no = self._get_last_checkpoint_no(path)
if not ignore_empty:
assert max_no >= 0, "Can't find checkpoint"
if max_no < 0:
return None
checkpoint_no = max_no
else:
assert isinstance(checkpoint_no, int)
assert checkpoint_no >= 0
local_fs = LocalFS()
if self._fs.need_upload_download():
cache_path = "{}/{}.{}.load_cache".format(
local_cache_path, self._checkpoint_prefix, checkpoint_no)
if trainer_id is not None:
cache_path = "{}.{}".format(cache_path, trainer_id)
if not local_fs.is_exist(local_cache_path):
local_fs.mkdirs(local_cache_path)
if local_fs.is_exist(cache_path):
local_fs.delete(cache_path)
real_path = "{}/{}.{}".format(path, self._checkpoint_prefix,
checkpoint_no)
load_path = real_path
if self._fs.need_upload_download():
self._fs.download(real_path, cache_path)
load_path = cache_path
for s in slists:
s.deserialize(load_path)
if self._fs.need_upload_download() and cache_path:
local_fs.delete(cache_path)
return real_path
def get_checkpoint_no(self, root_path):
a = []
dirs = self._fs.list_dirs(root_path)
for d in dirs:
g = d.split(".")
if len(g) != 2:
continue
if g[0] != self._checkpoint_prefix:
continue
try:
n = int(g[1])
a.append(n)
except:
continue
a.sort()
return a
def _get_last_checkpoint_no(self, root_path):
"""
only get the first depth
"""
a = self.get_checkpoint_no(root_path)
if len(a) > 0:
return a[-1]
return -1
def clean_redundant_checkpoints(self, root_path, reserved=[]):
max_no = self._get_last_checkpoint_no(root_path)
if max_no < 0:
return
s = set(reserved)
if len(s) == 0:
s.add(max_no)
dirs = self._fs.list_dirs(root_path)
for d in dirs:
g = d.split(".")
if len(g) != 2:
continue
if g[0] != self._checkpoint_prefix:
continue
try:
n = int(g[1])
if n not in s:
path = "{}/{}.{}".format(root_path, self._checkpoint_prefix,
n)
self._fs.delete(path)
except Exception as e:
print(e)
continue
......@@ -27,6 +27,7 @@ from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
from paddle.fluid import compiler
from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel, CheckpointSaver
import os
import sys
......@@ -46,21 +47,6 @@ class DistFCConfig(object):
pass
class TrainStatus(object):
def __init__(self, epoch_no=-1):
# completed epoch
self._epoch_no = epoch_no
def next(self):
return self._epoch_no + 1
def __eq__(self, t):
return self._epoch_no == t._epoch_no
def __ne__(self, t):
return not self == t
class Collective(Fleet):
def __init__(self):
super(Collective, self).__init__(Mode.COLLECTIVE)
......@@ -152,90 +138,10 @@ class Collective(Fleet):
io.save_persistables(executor, dirname, main_program, filename=filename)
def _save_train_status(self, path, train_status):
d = {}
d["epoch_no"] = train_status._epoch_no
file_name = "{}/fleet_train_status".format(path)
with open(file_name, 'w') as f:
json.dump(d, f)
def _load_train_status(self, path):
file_name = "{}/fleet_train_status".format(path)
r = TrainStatus()
if not os.path.isfile(file_name):
return r
d = {}
with open(file_name, 'r') as f:
d = json.load(f)
assert "epoch_no" in d, "Can't find epoch_no in dict from train_status file:{}".format(
d)
r._epoch_no = d["epoch_no"]
assert r._epoch_no >= 0, "Data in checkpoint file is not valid:{}".format(
d)
return r
def _get_last_checkpoint_no(self, root_path, fs):
"""
only get the first depth
"""
max_no = -1
d = {}
dirs = fs.list_dirs(root_path)
for d in dirs:
g = d.split(".")
if len(g) != 2:
continue
if g[0] != "__paddle_fleet_checkpoint__":
continue
try:
n = int(g[1])
if n > max_no:
max_no = n
except:
continue
return max_no
def clean_redundant_checkpoints(self,
root_path,
fs=LocalFS(),
checkpoint_num=1):
max_no = self._get_last_checkpoint_no(root_path, fs)
if max_no < 0:
return
if checkpoint_num < 1:
checkpoint_num = 1
dirs = fs.list_dirs(root_path)
for d in dirs:
g = d.split(".")
if len(g) != 2:
continue
if g[0] != self._checkpoint_prefix:
continue
try:
n = int(g[1])
if n <= max_no - checkpoint_num:
path = "{}/{}.{}".format(root_path, self._checkpoint_prefix,
n)
fs.delete(path)
except Exception as e:
print(e)
continue
def save_checkpoint(self,
executor,
path,
trainer_id,
train_status,
main_program=None,
fs=LocalFS(),
......@@ -248,53 +154,25 @@ class Collective(Fleet):
if main_program == None:
main_program = self._transpiled_program
if not fs.is_exist(path):
fs.mkdirs(path)
else:
assert fs.is_dir(path), "path:%s must be a directory".format(path)
max_no = self._get_last_checkpoint_no(path, fs=fs)
if max_no < 0:
max_no = -1
real_path = "{}/{}.{}".format(path, self._checkpoint_prefix, max_no + 1)
tmp_path = "{}.tmp".format(real_path)
saved_path = tmp_path
local_fs = LocalFS()
cache_path = None
if fs.need_upload_download():
cache_path = "{}/{}.{}.saved_cache".format(
local_cache_path, self._checkpoint_prefix, max_no + 1)
if not local_fs.is_exist(cache_path):
local_fs.mkdirs(cache_path)
else:
assert fs.is_dir(
path), "cache path:{} must be a directory".format(
cache_path)
saved_path = cache_path
self.save_persistables(
executor=executor,
dirname=saved_path,
main_program=main_program,
filename=self._param_file_name)
self._save_train_status(path=saved_path, train_status=train_status)
if fs.need_upload_download():
fs.delete(tmp_path)
fs.upload(cache_path, tmp_path)
fs.mv(tmp_path, real_path)
m = PaddleModel(executor, main_program)
t = train_status
c = CheckpointSaver(fs)
real_path, checkpoint_no = c.save_checkpoint(
path=path,
slists=[m, t],
trainer_id=trainer_id,
local_cache_path=local_cache_path)
if not remain_all_checkpoint:
self.clean_redundant_checkpoints(path)
c.clean_redundant_checkpoints(path)
return real_path, checkpoint_no
def load_checkpoint(self,
executor,
path,
trainer_id,
train_status,
main_program=None,
fs=LocalFS(),
local_cache_path=".cache",
......@@ -302,39 +180,17 @@ class Collective(Fleet):
"""
This function load persistables and current epoch num from path.
"""
max_no = self._get_last_checkpoint_no(path, fs)
if not ignore_empty:
assert max_no >= 0, "Can't find checkpoint"
if max_no < 0:
return None
local_fs = LocalFS()
if fs.need_upload_download():
cache_path = "{}/{}.{}.load_cache.{}".format(
local_cache_path, self._checkpoint_prefix, max_no, trainer_id)
if not local_fs.is_exist(local_cache_path):
local_fs.mkdirs(local_cache_path)
if local_fs.is_exist(cache_path):
local_fs.delete(cache_path)
real_path = "{}/{}.{}".format(path, self._checkpoint_prefix, max_no)
load_path = real_path
if fs.need_upload_download():
fs.download(real_path, cache_path)
load_path = cache_path
if main_program == None:
main_program = self._transpiled_program
io.load_persistables(
executor=executor,
dirname=load_path,
main_program=main_program,
filename=self._param_file_name)
return self._load_train_status(load_path)
m = PaddleModel(executor, main_program)
c = CheckpointSaver(fs)
return c.load_checkpoint(
path, [m, train_status],
trainer_id=trainer_id,
ignore_empty=ignore_empty,
local_cache_path=local_cache_path)
fleet = Collective()
......
......@@ -45,6 +45,10 @@ class FSTimeOut(Exception):
pass
class FSShellCmdAborted(ExecuteError):
pass
class FS(object):
@abc.abstractmethod
def ls_dir(self, fs_path):
......@@ -87,7 +91,7 @@ class FS(object):
raise NotImplementedError
@abc.abstractmethod
def mv(self, fs_src_path, fs_dst_path):
def mv(self, fs_src_path, fs_dst_path, overwrite=False, test_exists=False):
raise NotImplementedError
@abc.abstractmethod
......@@ -98,6 +102,10 @@ class FS(object):
def list_dirs(self, fs_path):
raise NotImplementedError
@abc.abstractmethod
def touch(self, fs_path, exist_ok=True):
raise NotImplementedError
class LocalFS(FS):
def ls_dir(self, fs_path):
......@@ -138,13 +146,21 @@ class LocalFS(FS):
def is_exist(self, fs_path):
return os.path.exists(fs_path)
def touch(self, fs_path):
return Path(fs_path).touch()
def touch(self, fs_path, exist_ok=True):
if self.is_exist(fs_path):
if exist_ok:
return
raise FSFileExistsError
return Path(fs_path).touch(exist_ok=True)
def mv(self, src_path, dst_path):
def mv(self, src_path, dst_path, overwrite=False, test_exists=False):
if not self.is_exist(src_path):
raise FSFileNotExistsError
if overwrite and self.is_exist(dst_path):
self.delete(dst_path)
if self.is_exist(dst_path):
raise FSFileExistsError
......
......@@ -26,8 +26,8 @@ import time
import logging
import six
from . import fs
from .fs import FS, LocalFS, FSFileExistsError, FSFileNotExistsError, ExecuteError, FSTimeOut
import paddle.fluid as fluid
from .fs import FS, LocalFS, FSFileExistsError, FSFileNotExistsError, ExecuteError, FSTimeOut, FSShellCmdAborted
from paddle.fluid import core
import functools
from pathlib import PurePosixPath, Path
......@@ -36,21 +36,39 @@ import shutil
__all__ = ["HDFSClient"]
def _handle_errors(f):
def handler(*args, **kwargs):
start = time.time()
while True:
try:
return f(*args, **kwargs)
except ExecuteError as e:
o = args[0]
def _handle_errors(max_time_out=None):
def decorator(f):
@functools.wraps(f)
def handler(*args, **kwargs):
o = args[0]
time_out = max_time_out
if time_out is None:
time_out = float(o._time_out) / 1000.0
inter = float(o._sleep_inter) / 1000.0
if time.time() - start >= time_out:
raise FSTimeOut
time.sleep(inter)
else:
time_out /= 1000.0
inter = float(o._sleep_inter) / 1000.0
start = time.time()
last_print_time = start
while True:
try:
return f(*args, **kwargs)
#important: only ExecuteError need to retry
except ExecuteError as e:
if time.time() - start >= time_out:
raise FSTimeOut("args:{} timeout:{}".format(
args, time.time() - start))
time.sleep(inter)
return functools.wraps(f)(handler)
if time.time() - last_print_time > 30:
print("hadoop operator timeout:args:{} timeout:{}".format(
args, time.time() - start))
last_print_time = time.time()
return handler
return decorator
class HDFSClient(FS):
......@@ -72,6 +90,7 @@ class HDFSClient(FS):
if configs:
for k, v in six.iteritems(configs):
config_command = '-D%s=%s' % (k, v)
self.pre_commands.append(config_command)
self._time_out = time_out
self._sleep_inter = sleep_inter
......@@ -80,17 +99,22 @@ class HDFSClient(FS):
r'\s?responseErrorMsg\s?\:.*, errorCode\:\s?[0-9]+, path\:')
def _run_cmd(self, cmd, redirect_stderr=False):
ret, output = fluid.core.shell_execute_cmd(cmd, 0, 0, redirect_stderr)
return int(ret), output.splitlines()
exe_cmd = "{} -{}".format(self._base_cmd, cmd)
ret, output = core.shell_execute_cmd(exe_cmd, 0, 0, redirect_stderr)
ret = int(ret)
if ret == 134:
raise FSShellCmdAborted(cmd)
return ret, output.splitlines()
@_handle_errors()
def list_dirs(self, fs_path):
if not self.is_exist(fs_path):
return []
dirs, _ = self.ls_dir(fs_path)
dirs, files = self._ls_dir(fs_path)
return dirs
@_handle_errors
@_handle_errors()
def ls_dir(self, fs_path):
"""
list directory under fs_path, and only give the pure name, not include the fs_path
......@@ -98,11 +122,14 @@ class HDFSClient(FS):
if not self.is_exist(fs_path):
return [], []
cmd = "{} -ls {}".format(self._base_cmd, fs_path)
return self._ls_dir(fs_path)
def _ls_dir(self, fs_path):
cmd = "ls {}".format(fs_path)
ret, lines = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError
raise ExecuteError(cmd)
dirs = []
files = []
......@@ -111,9 +138,6 @@ class HDFSClient(FS):
if len(arr) != 8:
continue
if fs_path not in arr[7]:
continue
p = PurePosixPath(arr[7])
if arr[0][0] == 'd':
dirs.append(p.name)
......@@ -130,18 +154,20 @@ class HDFSClient(FS):
return None
@_handle_errors
@_handle_errors()
def is_dir(self, fs_path):
if not self.is_exist(fs_path):
return False
cmd = "{} -test -d {}".format(
self._base_cmd, fs_path, redirect_stderr=True)
return self._is_dir(fs_path)
def _is_dir(self, fs_path):
cmd = "test -d {}".format(fs_path, redirect_stderr=True)
ret, lines = self._run_cmd(cmd)
if ret:
# other error
if self._test_match(lines) != None:
raise ExecuteError
if self._test_match(lines):
raise ExecuteError(cmd)
return False
......@@ -151,94 +177,155 @@ class HDFSClient(FS):
if not self.is_exist(fs_path):
return False
return not self.is_dir(fs_path)
return not self._is_dir(fs_path)
@_handle_errors
@_handle_errors()
def is_exist(self, fs_path):
cmd = "{} -ls {} ".format(self._base_cmd, fs_path)
cmd = "ls {} ".format(fs_path)
ret, out = self._run_cmd(cmd, redirect_stderr=True)
if ret != 0:
for l in out:
if "No such file or directory" in l:
return False
raise ExecuteError
raise ExecuteError(cmd)
return True
@_handle_errors
# can't retry
def upload(self, local_path, fs_path):
if self.is_exist(fs_path):
raise FSFileExistsError
raise FSFileExistsError("{} exists".format(fs_path))
local = LocalFS()
if not local.is_exist(local_path):
raise FSFileNotExistsError
cmd = "{} -put {} {}".format(self._base_cmd, local_path, fs_path)
ret, lines = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError
@_handle_errors
raise FSFileNotExistsError("{} not exists".format(local_path))
return self._try_upload(local_path, fs_path)
@_handle_errors()
def _try_upload(self, local_path, fs_path):
cmd = "put {} {}".format(local_path, fs_path)
ret = 0
try:
ret, lines = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError(cmd)
except Exception as e:
self.delete(fs_path)
raise e
# can't retry
def download(self, fs_path, local_path):
if self.is_exist(local_path):
raise FSFileExistsError
raise FSFileExistsError("{} exists".format(local_path))
if not self.is_exist(fs_path):
raise FSFileNotExistsError
cmd = "{} -get {} {}".format(self._base_cmd, fs_path, local_path)
ret, lines = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError
@_handle_errors
raise FSFileNotExistsError("{} not exits".format(fs_path))
return self._try_download(fs_path, local_path)
@_handle_errors()
def _try_download(self, fs_path, local_path):
cmd = "get {} {}".format(fs_path, local_path)
ret = 0
try:
ret, lines = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError(cmd)
except Exception as e:
local_fs = LocalFS()
local_fs.delete(local_path)
raise e
@_handle_errors()
def mkdirs(self, fs_path):
if self.is_exist(fs_path):
return
cmd = "{} -mkdir {}".format(self._base_cmd, fs_path)
ret, lines = self._run_cmd(cmd)
out_hdfs = False
cmd = "mkdir {} ".format(fs_path)
ret, out = self._run_cmd(cmd, redirect_stderr=True)
if ret != 0:
raise ExecuteError
for l in out:
if "No such file or directory" in l:
out_hdfs = True
break
if not out_hdfs:
raise ExecuteError(cmd)
if out_hdfs and not self.is_exist(fs_path):
cmd = "mkdir -p {}".format(fs_path)
ret, lines = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError(cmd)
def mv(self, fs_src_path, fs_dst_path, overwrite=False, test_exists=True):
if overwrite and self.is_exist(fs_dst_path):
self.delete(fs_dst_path)
@_handle_errors
def mv(self, fs_src_path, fs_dst_path, test_exists=True):
if test_exists:
if not self.is_exist(fs_src_path):
raise FSFileNotExistsError
raise FSFileNotExistsError("{} is not exists".format(
fs_src_path))
if self.is_exist(fs_dst_path):
raise FSFileExistsError
raise FSFileExistsError("{} exists already".format(
fs_src_path, fs_dst_path, fs_dst_path))
return self._try_mv(fs_src_path, fs_dst_path)
@_handle_errors()
def _try_mv(self, fs_src_path, fs_dst_path):
cmd = "mv {} {}".format(fs_src_path, fs_dst_path)
ret = 0
try:
ret, _ = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError(cmd)
except Exception as e:
if not self.is_exist(fs_src_path) and \
self.is_exist(fs_dst_path):
return
raise e
cmd = "{} -mv {} {}".format(self._base_cmd, fs_src_path, fs_dst_path)
ret, _ = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError
@_handle_errors
def _rmr(self, fs_path):
cmd = "{} -rmr {}".format(self._base_cmd, fs_path)
cmd = "rmr {}".format(fs_path)
ret, _ = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError
raise ExecuteError(cmd)
@_handle_errors
def _rm(self, fs_path):
cmd = "{} -rm {}".format(self._base_cmd, fs_path)
cmd = "rm {}".format(fs_path)
ret, _ = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError
raise ExecuteError(cmd)
@_handle_errors()
def delete(self, fs_path):
if not self.is_exist(fs_path):
return
is_dir = self.is_dir(fs_path)
is_dir = self._is_dir(fs_path)
if is_dir:
return self._rmr(fs_path)
return self._rm(fs_path)
def touch(self, fs_path, exist_ok=True):
if self.is_exist(fs_path):
if exist_ok:
return
raise FSFileExistsError
return self._touchz(fs_path)
@_handle_errors()
def _touchz(self, fs_path):
cmd = "touchz {}".format(fs_path)
ret, _ = self._run_cmd(cmd)
if ret != 0:
raise ExecuteError
def need_upload_download(self):
return True
......@@ -86,6 +86,10 @@ if(WIN32)
LIST(REMOVE_ITEM TEST_OPS test_ref_by_trainer_id_op)
endif()
LIST(REMOVE_ITEM TEST_OPS test_auto_checkpoint)
LIST(REMOVE_ITEM TEST_OPS test_auto_checkpoint2)
LIST(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
if(APPLE OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_hdfs)
LIST(REMOVE_ITEM TEST_OPS test_fs_interface)
......@@ -190,10 +194,11 @@ function(bash_test_modules TARGET_NAME)
endif()
set(options SERIAL)
set(oneValueArgs "")
set(multiValueArgs MODULES DEPS ENVS LABELS)
set(oneValueArgs TIMEOUT START_BASH)
set(multiValueArgs DEPS ENVS LABELS)
cmake_parse_arguments(bash_test_modules "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(timeout 350)
if(${bash_test_modules_TIMEOUT})
set(timeout ${bash_test_modules_TIMEOUT})
......@@ -204,13 +209,13 @@ function(bash_test_modules TARGET_NAME)
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
WITH_COVERAGE=ON COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_START_BASH}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else()
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_MODULES}
bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_START_BASH}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
......@@ -397,15 +402,16 @@ if(WITH_DISTRIBUTE)
if(NOT APPLE)
if(WITH_GPU)
# NOTE. test_launch only work in gpu collective mode
bash_test_modules(test_launch MODULES test_launch.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_launch START_BASH test_launch.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
py_test_modules(test_fleet_checkpoint MODULES test_fleet_checkpoint)
endif()
bash_test_modules(test_launch_ps MODULES test_launch_ps.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_fleet_launch MODULES test_fleet_launch.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_launch_ps START_BASH test_launch_ps.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_fleet_launch START_BASH test_fleet_launch.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
set(dist_ut_port 20001)
foreach(TEST_OP ${DIST_TEST_OPS})
bash_test_modules(${TEST_OP} MODULES dist_test.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}")
bash_test_modules(${TEST_OP} START_BASH dist_test.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}")
MATH(EXPR dist_ut_port "${dist_ut_port}+50")
endforeach(TEST_OP)
endif(NOT APPLE)
......@@ -441,6 +447,12 @@ if(NOT WIN32)
set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 450)
endif()
if(NOT APPLE AND NOT WIN32)
bash_test_modules(test_auto_checkpoint START_BASH dist_test.sh TIMEOUT 600)
bash_test_modules(test_auto_checkpoint2 START_BASH dist_test.sh TIMEOUT 600)
bash_test_modules(test_checkpoint_saver START_BASH dist_test.sh TIMEOUT 600)
endif()
add_subdirectory(sequence)
add_subdirectory(dygraph_to_static)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
import paddle.fluid.incubate.checkpoint.auto_checkpoint as acp
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel
from paddle.fluid.framework import program_guard
from paddle.fluid import unique_name
import numpy as np
from paddle.io import Dataset, BatchSampler, DataLoader
BATCH_NUM = 20
BATCH_SIZE = 16
#IMAGE_SIZE = 128
CLASS_NUM = 10
USE_GPU = False # whether use GPU to run model
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()
logger = None
def get_logger():
global logger
logger = acp._get_logger(20)
return logger
def get_random_images_and_labels(image_shape, label_shape):
image = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return image, label
def sample_list_generator_creator():
def __reader__():
for _ in range(BATCH_NUM):
sample_list = []
for _ in range(BATCH_SIZE):
image, label = get_random_images_and_labels([16, 16], [1])
sample_list.append([image, label])
yield sample_list
return __reader__
class AutoCheckpointBase(unittest.TestCase):
def _init_env(self,
exe,
main_prog,
startup_prog,
minimize=True,
iterable=True):
def simple_net():
image = fluid.data(
name='image', shape=[-1, 16, 16], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)
cross_entropy = fluid.layers.softmax_with_cross_entropy(fc_tmp,
label)
loss = fluid.layers.reduce_mean(cross_entropy)
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
if minimize:
sgd.minimize(loss)
return sgd, loss, image, label
with program_guard(main_prog, startup_prog):
sgd, loss, image, label = simple_net()
if minimize:
compiled = fluid.CompiledProgram(main_prog).with_data_parallel(
loss_name=loss.name)
else:
compiled = None
loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=64,
use_double_buffer=True,
iterable=iterable)
loader.set_sample_list_generator(sample_list_generator_creator(),
places[0])
if minimize:
exe.run(startup_prog)
return compiled, loader, sgd, loss, image, label
def _generate(self):
main_prog = fluid.Program()
startup_prog = fluid.Program()
exe = fluid.Executor(places[0])
return exe, main_prog, startup_prog
def _reset_generator(self):
unique_name.generator = fluid.unique_name.UniqueNameGenerator()
acp.generator = fluid.unique_name.UniqueNameGenerator()
acp.g_acp_type = None
acp.g_checker = acp.AutoCheckpointChecker()
acp.g_program_attr = {}
def _clear_envs(self):
os.environ.pop("PADDLE_RUNNING_ENV", None)
def _readd_envs(self):
os.environ["PADDLE_RUNNING_ENV"] = "PADDLE_EDL_AUTO_CHECKPOINT"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
import paddle.fluid.incubate.checkpoint.auto_checkpoint as acp
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel
from paddle.fluid.framework import program_guard
from paddle.fluid import unique_name
import numpy as np
from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.fluid.tests.unittests.auto_checkpoint_utils import AutoCheckpointBase, get_logger
logger = get_logger()
class AutoCheckPointACLBase(AutoCheckpointBase):
def setUp(self):
get_logger()
logger.info("enter tests")
self._old_environ = dict(os.environ)
proc_env = {
"PADDLE_RUNNING_ENV": "PADDLE_EDL_AUTO_CHECKPOINT",
"PADDLE_TRAINER_ID": "0",
"PADDLE_RUNNING_PLATFORM": "PADDLE_CLOUD",
"PADDLE_JOB_ID": "test_job_auto",
"PADDLE_EDL_HDFS_HOME": "/usr/local/hadoop-2.7.7",
"PADDLE_EDL_HDFS_NAME": "",
"PADDLE_EDL_HDFS_UGI": "",
"PADDLE_EDL_HDFS_CHECKPOINT_PATH": "auto_checkpoint",
"PADDLE_EDL_ONLY_FOR_CE_TEST": "1",
"PADDLE_EDL_FS_CACHE": ".auto_checkpoint_test",
"PADDLE_EDL_SAVE_CHECKPOINT_INTER": "0"
}
os.environ.update(proc_env)
def tearDown(self):
os.environ.clear()
os.environ.update(self._old_environ)
def _run_normal(self):
exe, main_prog, startup_prog = self._generate()
save_dir = "./run_save_model"
fs = LocalFS()
fs.delete(save_dir)
logger.info("begin _run_normal")
compiled, data_loader, optimizer, loss, image, label = self._init_env(
exe, main_prog, startup_prog)
for i in range(3):
self.assertEqual(acp._get_train_epoch_range(), None)
self.assertEqual(acp.g_acp_type, None)
for data in data_loader():
self.assertEqual(acp.g_acp_type, None)
self.assertEqual(acp._get_train_epoch_range(), None)
fetch = exe.run(compiled, feed=data, fetch_list=[loss])
self.assertEqual(acp.g_acp_type, None)
self.assertEqual(acp._get_train_epoch_range(), None)
m1 = PaddleModel(exe, compiled)
m1.serialize(save_dir)
m2 = PaddleModel(exe, compiled)
m2.deserialize(save_dir)
logger.info("end _run_normal")
fs.delete(save_dir)
def _not_use_train(self):
logger.info("begin _not_use_train")
exe, main_prog, startup_prog = self._generate()
compiled, data_loader, optimizer, loss, image, label = \
self._init_env(exe, main_prog, startup_prog)
epochs = []
for i in acp.train_epoch_range(3, 0):
epochs.append(i)
for data in data_loader():
fetch = exe.run(compiled, feed=data, fetch_list=[loss])
self.assertEqual(epochs, [0, 1, 2])
logger.info("end _not_use_train")
def _run_save_0(self, break_epoch_no=None):
logger.info("begin _run_save_0")
fs = LocalFS()
save_dir = "./run_save_0"
fs.delete(save_dir)
exe, main_prog, startup_prog = self._generate()
compiled, data_loader, optimizer, loss, image, label = \
self._init_env(exe, main_prog, startup_prog)
o = None
i = 0
name = None
for i in acp.train_epoch_range(3, 0):
o = acp._get_train_epoch_range()
name = o.name
for data in data_loader():
fetch = exe.run(compiled, feed=data, fetch_list=[loss])
self.assertEqual(len(o._exe_status), 1)
if break_epoch_no is not None:
if i == break_epoch_no:
break
o = acp._get_train_epoch_range()
assert o == None, "now train epoch must not exits now"
if break_epoch_no is None:
self.assertEqual(i, 2)
else:
self.assertEqual(i, break_epoch_no)
fs.delete(save_dir)
logger.info("end _run_save_0")
def _run_load_0(self, break_epoch_no=None):
logger.info("begin _run_load_0")
exe, main_prog, startup_prog = self._generate()
fs = LocalFS()
save_dir = "./run_load_0"
fs.delete(save_dir)
compiled, data_loader, optimizer, loss, image, label = self._init_env(
exe, main_prog, startup_prog)
o = None
i = 0
check = False
epochs = []
for i in acp.train_epoch_range(3, 0):
epochs.append(i)
for data in data_loader():
fetch = exe.run(compiled, feed=data, fetch_list=[loss])
o = acp._get_train_epoch_range()
self.assertTrue(o == None, "now train epoch must not exits now")
self.assertEqual(i, 2)
if break_epoch_no is not None:
if break_epoch_no == 0:
self.assertEqual(epochs, [0, 1, 2])
elif break_epoch_no == 1:
self.assertEqual(epochs, [1, 2])
elif break_epoch_no == 2:
self.assertEqual(epochs, [2])
else:
self.assertEqual(epochs, [2])
fs.delete(save_dir)
logger.info("begin _run_load_0")
class AutoCheckpointTest(AutoCheckPointACLBase):
def setUp(self):
get_logger()
logger.info("enter tests")
self._old_environ = dict(os.environ)
proc_env = {
"PADDLE_RUNNING_ENV": "PADDLE_EDL_AUTO_CHECKPOINT",
"PADDLE_TRAINER_ID": "0",
"PADDLE_RUNNING_PLATFORM": "PADDLE_CLOUD",
"PADDLE_JOB_ID": "test_job_auto_1",
"PADDLE_EDL_HDFS_HOME": "/usr/local/hadoop-2.7.7",
"PADDLE_EDL_HDFS_NAME": "",
"PADDLE_EDL_HDFS_UGI": "",
"PADDLE_EDL_HDFS_CHECKPOINT_PATH": "auto_checkpoint_1",
"PADDLE_EDL_ONLY_FOR_CE_TEST": "1",
"PADDLE_EDL_FS_CACHE": ".auto_checkpoint_test_1",
"PADDLE_EDL_SAVE_CHECKPOINT_INTER": "0"
}
os.environ.update(proc_env)
def test_normal(self):
logger.info("begin test_normal")
checker = acp._get_checker()
fs = HDFSClient(checker.hdfs_home, None)
fs.delete(checker.hdfs_checkpoint_path)
self._clear_envs()
self._reset_generator()
self._run_normal()
self._readd_envs()
logger.info("end test_normal")
def test_basic(self):
logger.info("begin test_basic")
checker = acp._get_checker()
self.assertEqual(checker.run_env, "PADDLE_EDL_AUTO_CHECKPOINT")
self.assertEqual(checker.platform, "PADDLE_CLOUD")
self.assertEqual(checker.save_checkpoint_inter, 0)
print(checker)
fs = HDFSClient(checker.hdfs_home, None)
fs.delete(checker.hdfs_checkpoint_path)
self._reset_generator()
self._run_save_0()
self._reset_generator()
self._run_load_0()
logger.info("end test_basic")
def test_not_use(self):
logger.info("begin test_not_use")
self._clear_envs()
self._reset_generator()
self._not_use_train()
self._readd_envs()
logger.info("end test_not_use")
def test_multiple(self):
checker = acp._get_checker()
fs = HDFSClient(checker.hdfs_home, None)
fs.delete(checker.hdfs_checkpoint_path)
self._reset_generator()
logger.info("begin test_multiple")
fs = LocalFS()
save_dir = "./run_save_0"
fs.delete(save_dir)
exe, main_prog1, startup_prog1 = self._generate()
_, main_prog2, startup_prog2 = self._generate()
compiled1, data_loader1, optimizer1, loss1, image1, label1 = \
self._init_env(exe, main_prog1, startup_prog1)
compiled2, data_loader2, optimizer2, loss2, image2, label2 = \
self._init_env(exe, main_prog2, startup_prog2)
o = None
epochs = []
for i in acp.train_epoch_range(3, 0):
for data in data_loader1():
fetch = exe.run(compiled1, feed=data, fetch_list=[loss1])
for data in data_loader2():
fetch = exe.run(compiled2, feed=data, fetch_list=[loss2])
o = acp._get_train_epoch_range()
self.assertEqual(len(o._exe_status), 2)
print(o._exe_status)
epochs.append(i)
o = acp._get_train_epoch_range()
self.assertTrue(o == None, "now train epoch must not exits now")
self.assertEqual(i, 2)
self.assertEqual(epochs, [0, 1, 2])
fs.delete(save_dir)
logger.info("end test_multiple")
def test_distributed_basic(self):
checker = acp._get_checker()
fs = HDFSClient(checker.hdfs_home, None)
fs.delete(checker.hdfs_checkpoint_path)
self._reset_generator()
logger.info("begin test_distributed_basic")
fs = LocalFS()
save_dir = "./run_save_0"
fs.delete(save_dir)
#basic
exe, main_prog, startup_prog = self._generate()
compiled, data_loader, optimizer, loss, image, label = \
self._init_env(exe, main_prog, startup_prog, minimize=False)
#fleet
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:6070"
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
with fluid.program_guard(main_prog, startup_prog):
dist_optimizer = fleet.distributed_optimizer(optimizer)
dist_optimizer.minimize(loss)
exe.run(startup_prog)
o = None
i = 0
name = None
for i in acp.train_epoch_range(3, 0):
o = acp._get_train_epoch_range()
name = o.name
logger.info("_run_save_0 name:{} epoch_no:{}".format(o.name, i))
for data in data_loader():
fetch = exe.run(fleet.main_program,
feed=data,
fetch_list=[loss])
self.assertEqual(len(o._exe_status), 1)
o = acp._get_train_epoch_range()
assert o == None, "now train epoch must not exits now"
self.assertEqual(i, 2)
fs.delete(save_dir)
logger.info("end test_distributed_basic")
def test_checker(self):
os.environ.pop("PADDLE_JOB_ID", None)
try:
checker = AutoCheckpointChecker()
self.assertFalse(True)
except Exception as e:
pass
os.environ["PADDLE_JOB_ID"] = "test_job_auto_1"
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
import paddle.fluid.incubate.checkpoint.auto_checkpoint as acp
from paddle.fluid.incubate.checkpoint.checkpoint_saver import PaddleModel
from paddle.fluid.framework import program_guard
from paddle.fluid import unique_name
import numpy as np
from paddle.io import Dataset, BatchSampler, DataLoader
from paddle.fluid.tests.unittests.auto_checkpoint_utils import AutoCheckpointBase, get_logger
from paddle.fluid.tests.unittests.test_auto_checkpoint import AutoCheckPointACLBase
logger = get_logger()
class AutoCheckpointTest2(AutoCheckPointACLBase):
def setUp(self):
get_logger()
logger.info("enter tests")
self._old_environ = dict(os.environ)
proc_env = {
"PADDLE_RUNNING_ENV": "PADDLE_EDL_AUTO_CHECKPOINT",
"PADDLE_TRAINER_ID": "0",
"PADDLE_RUNNING_PLATFORM": "PADDLE_CLOUD",
"PADDLE_JOB_ID": "test_job_auto_2",
"PADDLE_EDL_HDFS_HOME": "/usr/local/hadoop-2.7.7",
"PADDLE_EDL_HDFS_NAME": "",
"PADDLE_EDL_HDFS_UGI": "",
"PADDLE_EDL_HDFS_CHECKPOINT_PATH": "auto_checkpoint_2",
"PADDLE_EDL_ONLY_FOR_CE_TEST": "1",
"PADDLE_EDL_FS_CACHE": ".auto_checkpoint_test_2",
"PADDLE_EDL_SAVE_CHECKPOINT_INTER": "0"
}
os.environ.update(proc_env)
def test_corner_epoch_no(self):
logger.info("begin test_corener_epoch_no")
checker = acp._get_checker()
fs = HDFSClient(checker.hdfs_home, None)
for i in range(3):
fs.delete(checker.hdfs_checkpoint_path)
self._reset_generator()
self._run_save_0(break_epoch_no=i)
self._reset_generator()
self._run_load_0(break_epoch_no=i)
fs.delete(checker.hdfs_checkpoint_path)
logger.info("end test_corener_epoch_no")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
from paddle.fluid.incubate.checkpoint.auto_checkpoint import ExeTrainStatus
from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
import os
import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
class CheckpointerSaverTest(unittest.TestCase):
def test(self):
fs = HDFSClient("/usr/local/hadoop-2.7.7", None)
dir_path = "./checkpointsaver_test"
fs.delete(dir_path)
s = CheckpointSaver(fs)
fs.mkdirs("{}/exe.exe".format(dir_path))
fs.mkdirs("{}/exe.1".format(dir_path))
fs.mkdirs("{}/exe".format(dir_path))
a = s.get_checkpoint_no(dir_path)
self.assertEqual(len(a), 0)
fs.mkdirs("{}/__paddle_checkpoint__.0".format(dir_path))
fs.mkdirs("{}/__paddle_checkpoint__.exe".format(dir_path))
a = s.get_checkpoint_no(dir_path)
self.assertEqual(len(a), 1)
s.clean_redundant_checkpoints(dir_path)
s.clean_redundant_checkpoints(dir_path)
fs.delete(dir_path)
if __name__ == '__main__':
unittest.main()
......@@ -170,7 +170,8 @@ def program_equal(a, b):
k))
return False
assert (len(a.blocks) == len(b.blocks))
elif k == '_auto_checkpoint_name':
continue
elif (v != b.__dict__[k]):
raise ValueError("In program_equal not equal:{0}\n".format(k))
......
......@@ -15,12 +15,15 @@
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet, TrainStatus
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
from paddle.fluid.incubate.checkpoint.auto_checkpoint import ExeTrainStatus
from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
import os
import sys
from paddle.fluid.incubate.fleet.utils.fs import LocalFS
from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient
from paddle.fluid.incubate.checkpoint.checkpoint_saver import CheckpointSaver
class FleetTest(unittest.TestCase):
......@@ -49,24 +52,35 @@ class FleetTest(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
status = TrainStatus(2)
fleet.save_checkpoint(exe, dir_path, train_status=status, fs=fs)
n1 = fleet._get_last_checkpoint_no(dir_path, fs=fs)
status = ExeTrainStatus()
status.epoch_no = 2
_, n1 = fleet.save_checkpoint(
exe, dir_path, trainer_id=0, train_status=status, fs=fs)
status2 = fleet.load_checkpoint(exe, dir_path, trainer_id=0, fs=fs)
status2 = ExeTrainStatus()
fleet.load_checkpoint(
exe, dir_path, trainer_id=0, fs=fs, train_status=status2)
self.assertEqual(status2, status)
fleet.save_checkpoint(exe, dir_path, train_status=status, fs=fs)
n2 = fleet._get_last_checkpoint_no(dir_path, fs=fs)
_, n2 = fleet.save_checkpoint(
exe,
dir_path,
trainer_id=0,
train_status=status,
fs=fs,
remain_all_checkpoint=False)
self.assertEqual(n2, n1 + 1)
fleet.clean_redundant_checkpoints(dir_path, fs=fs)
c = CheckpointSaver(fs)
cp_nos = c.get_checkpoint_no(dir_path)
assert len(cp_nos) == 1 # cleanup all others
# unnormal
# test remain_all_checkpoint
fleet.save_checkpoint(
exe,
dir_path,
trainer_id=0,
train_status=status,
fs=fs,
remain_all_checkpoint=False)
......@@ -79,6 +93,7 @@ class FleetTest(unittest.TestCase):
fleet.save_checkpoint(
exe,
dir_path,
trainer_id=0,
train_status=status,
fs=fs,
cache_path=cache_path)
......@@ -88,8 +103,13 @@ class FleetTest(unittest.TestCase):
# can't load under a file
try:
status2 = fleet.load_checkpoint(
exe, dir_path, trainer_id=0, fs=fs, cache_path=cache_path)
fleet.load_checkpoint(
exe,
dir_path,
trainer_id=0,
train_status=status2,
fs=fs,
cache_path=cache_path)
self.assertFalse(True)
except:
pass
......
......@@ -15,7 +15,7 @@
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet, TrainStatus
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys
import inspect
......@@ -38,6 +38,8 @@ class FSTest(unittest.TestCase):
func(a)
elif len(args) == 3:
func(a, a)
elif len(args) == 5:
func(a, a, a, a)
print("args:", args, len(args), "func:", func)
self.assertFalse(True)
except NotImplementedError as e:
......
......@@ -15,7 +15,7 @@
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet, TrainStatus
from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys
......@@ -57,6 +57,12 @@ class FSTest(unittest.TestCase):
fs.delete(dir_path)
self.assertTrue(not fs.is_exist(dir_path))
fs.mkdirs(dir_path)
fs.mkdirs(new_dir_path)
fs.mv(dir_path, new_dir_path, overwrite=True)
self.assertTrue(not fs.is_exist(dir_path))
self.assertTrue(fs.is_exist(new_dir_path))
def _test_touch_file(self, fs):
file_path = os.path.abspath("./test_file")
......@@ -104,6 +110,35 @@ class FSTest(unittest.TestCase):
fs.delete(dst_file)
fs.delete(src_file)
def _test_try_download(self, fs):
src_file = os.path.abspath("./test_try_download.src")
dst_file = os.path.abspath("./test_try_download.dst")
fs.delete(dst_file)
fs.delete(src_file)
try:
fs._try_download(src_file, dst_file)
self.assertFalse(True)
except Exception as e:
pass
fs.delete(dst_file)
fs.delete(src_file)
def _test_try_upload(self, fs):
src_file = os.path.abspath("./test_try_upload.src")
dst_file = os.path.abspath("./test_try_uolpad.dst")
try:
fs._try_upload(src_file, dst_file)
self.assertFalse(True)
except Exception as e:
pass
fs.delete(dst_file)
fs.delete(src_file)
def _test_download(self, fs):
src_file = os.path.abspath("./test_download.src")
dst_file = os.path.abspath("./test_download.dst")
......@@ -138,8 +173,27 @@ class FSTest(unittest.TestCase):
fs.mkdirs(dir_name)
fs.mkdirs(dir_name)
def _test_rm(self, fs):
dir_name = "./test_rm_no_exist.flag"
fs.delete(dir_name)
try:
fs._rmr(dir_name)
self.assertFalse(True)
except Exception as e:
pass
try:
fs._rm(dir_name)
self.assertFalse(True)
except Exception as e:
pass
def test_exists(self):
fs = HDFSClient("/usr/local/hadoop-2.7.7/", None, time_out=15 * 1000)
fs = HDFSClient(
"/usr/local/hadoop-2.7.7/",
None,
time_out=15 * 1000,
sleep_inter=100)
self.assertFalse(fs.is_exist(os.path.abspath("./xxxx")))
self.assertFalse(fs.is_dir(os.path.abspath("./xxxx")))
self.assertTrue(fs.is_dir(os.path.abspath("./xxx/..")))
......@@ -149,27 +203,39 @@ class FSTest(unittest.TestCase):
dirs, files = fs.ls_dir(os.path.abspath("./xxx/.."))
def test_hdfs(self):
fs = HDFSClient("/usr/local/hadoop-2.7.7/", None, time_out=15 * 1000)
fs = HDFSClient(
"/usr/local/hadoop-2.7.7/",
None,
time_out=15 * 1000,
sleep_inter=100)
self._test_rm(fs)
self._test_touch(fs)
self._test_dirs(fs)
self._test_upload(fs)
self._test_download(fs)
self._test_mkdirs(fs)
self._test_list_dir(fs)
self._test_try_upload(fs)
self._test_try_download(fs)
def test_local(self):
fs = LocalFS()
self._test_rm(fs)
self._test_touch(fs)
self._test_dirs(fs)
self._test_touch_file(fs)
self._test_mkdirs(fs)
self._test_list_dir(fs)
self._test_try_upload(fs)
self._test_try_download(fs)
def test_timeout(self):
fs = HDFSClient(
"/usr/local/hadoop-2.7.7/",
None,
time_out=6 * 1000,
sleep_inter=2000)
sleep_inter=100)
src = "hdfs_test_timeout"
dst = "new_hdfs_test_timeout"
fs.delete(dst)
......@@ -190,7 +256,11 @@ class FSTest(unittest.TestCase):
print("second mv ret:{} output:{}".format(ret, output))
def test_is_dir(self):
fs = HDFSClient("/usr/local/hadoop-2.7.7/", None, time_out=15 * 1000)
fs = HDFSClient(
"/usr/local/hadoop-2.7.7/",
None,
time_out=15 * 1000,
sleep_inter=100)
self.assertFalse(fs.is_dir("./test_hdfs.py"))
s = """
java.io.IOException: Input/output error
......@@ -212,12 +282,38 @@ java.io.IOException: Input/output error
def test_config(self):
config = {"fs.default.name": "hdfs://xxx", "hadoop.job.ugi": "ugi"}
fs = HDFSClient("/usr/local/hadoop-2.7.7/", config, time_out=15 * 1000)
fs = HDFSClient(
"/usr/local/hadoop-2.7.7/",
config,
time_out=15 * 1000,
sleep_inter=100)
def _test_list_dir(self, fs):
fs = HDFSClient("/usr/local/hadoop-2.7.7/", None, time_out=15 * 1000)
fs = HDFSClient(
"/usr/local/hadoop-2.7.7/",
None,
time_out=15 * 1000,
sleep_inter=100)
fs.ls_dir("test_not_exists")
def _test_touch(self, fs):
path = "./touch.flag"
fs.touch(path, exist_ok=True)
try:
fs.touch("./touch.flag", exist_ok=False)
self.assertFalse(0, "can't reach here")
except FSFileExistsError as e:
pass
try:
fs._touchz("./touch.flag")
self.assertFalse(True, "can't reach here")
except Exception as e:
pass
self.assertFalse(fs.is_dir(path))
fs.delete(path)
if __name__ == '__main__':
unittest.main()
......@@ -178,6 +178,7 @@ packages=['paddle',
'paddle.fluid.incubate',
'paddle.fluid.incubate.data_generator',
'paddle.fluid.incubate.fleet',
'paddle.fluid.incubate.checkpoint',
'paddle.fluid.incubate.fleet.base',
'paddle.fluid.incubate.fleet.parameter_server',
'paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册