Commit b044724d, authored by tangwei12

update fluid Trainer API: replace param_path with checkpoint_config

Parent: e901de66
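This changes how a Trainer is constructed. A minimal migration sketch, assuming the fluid package re-exports the trainer module's __all__ (so CheckpointConfig is reachable as fluid.CheckpointConfig); train_program, optimizer, and the directory paths are placeholders, not part of this diff:

    import paddle.fluid as fluid

    # before: point the Trainer at a directory of saved parameters
    # trainer = fluid.Trainer(train_func=train_program,
    #                         optimizer=optimizer,
    #                         param_path="/tmp/model_params")

    # after: describe a checkpointing policy instead
    config = fluid.CheckpointConfig(checkpoint_dir="/tmp/model_ckpt",
                                    max_num_checkpoints=3,
                                    save_interval_secs=600)
    trainer = fluid.Trainer(train_func=train_program,
                            optimizer=optimizer,
                            checkpoint_config=config)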
@@ -27,11 +27,8 @@ import parallel_executor
 from transpiler import distribute_transpiler
 
 __all__ = [
-    'Trainer',
-    'BeginEpochEvent',
-    'EndEpochEvent',
-    'BeginStepEvent',
-    'EndStepEvent',
+    'Trainer', 'BeginEpochEvent', 'EndEpochEvent', 'BeginStepEvent',
+    'EndStepEvent', 'CheckpointConfig'
 ]
@@ -59,6 +56,20 @@ class EndStepEvent(object):
         self.metrics = metrics
 
 
+class CheckpointConfig(object):
+    def __init__(self,
+                 checkpoint_dir=None,
+                 max_num_checkpoints=3,
+                 save_interval_secs=600):
+        # default to the current working directory when no directory is given
+        if checkpoint_dir is None:
+            self.checkpoint_dir = os.getcwd()
+        else:
+            self.checkpoint_dir = checkpoint_dir
+        self.max_num_checkpoints = max_num_checkpoints
+        self.save_interval_secs = save_interval_secs
+
+
 def check_and_get_place(place):
     """
     Check the type of place or get the default place
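A quick sketch of the defaults this class provides, constructed exactly as in the hunk above; the asserts simply restate the default arguments:

    config = CheckpointConfig()
    # checkpoint_dir falls back to os.getcwd() when none is given
    assert config.max_num_checkpoints == 3
    assert config.save_interval_secs == 600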
@@ -97,9 +105,9 @@ class Trainer(object):
     def __init__(self,
                  train_func,
                  optimizer,
-                 param_path=None,
                  place=None,
-                 parallel=False):
+                 parallel=False,
+                 checkpoint_config=None):
         self.__stop = False
         self.parallel = parallel
         # 1. we need to generate a framework.Program by calling
@@ -108,6 +116,16 @@
         if not isinstance(optimizer, opt_module.Optimizer):
             raise TypeError("The optimizer should be an instance of Optimizer")
 
+        # config for checkpoint
+        # only the chief worker will save variables
+        self.chief = True
+        self.checkpoint = checkpoint_config
+        if self.checkpoint and not isinstance(self.checkpoint,
+                                              CheckpointConfig):
+            raise TypeError(
+                "The checkpoint_config should be an instance of CheckpointConfig"
+            )
+
         self.scope = core.Scope()
         self.startup_program = framework.Program()
@@ -136,9 +154,10 @@
             exe = executor.Executor(place)
             exe.run(self.startup_program)
 
-            if param_path:
-                # load params from param_path into scope
-                io.load_persistables(exe, dirname=param_path)
+            if self.checkpoint:
+                exe = executor.Executor(place)
+                io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
+                                   self.startup_program)
 
     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
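The effect of the new branch: on construction the startup program always runs first, and when a checkpoint_config is present io.load_checkpoint then restores the most recently saved checkpoint from checkpoint_dir over the freshly initialized variables. A hedged resume sketch (the directory name is an example only):

    # restarting after a crash: same checkpoint_dir, no separate restore step
    trainer = fluid.Trainer(train_func=train_program,
                            optimizer=optimizer,
                            checkpoint_config=fluid.CheckpointConfig("/tmp/model_ckpt"))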
@@ -146,6 +165,7 @@
             self.nccl_id_var = None
         else:
             self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+            self.chief = self.trainer_id == 0
             port = os.getenv("PADDLE_PSERVER_PORT")
             worker_ips = os.getenv("PADDLE_TRAINER_IPS")
             worker_endpoints = []
@@ -194,6 +214,7 @@
         # the unique trainer id, starting from 0, needed by trainer
         # only
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+        self.chief = trainer_id == 0
         # the role, should be either PSERVER or TRAINER
         training_role = os.getenv("PADDLE_TRAINING_ROLE")
         with self._prog_and_scope_guard():
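Both transpile paths derive chief status from the PADDLE_TRAINER_ID environment variable, so in a distributed job only worker 0 ever writes checkpoints, while every worker can still load them. An illustration with example values (the ids are placeholders, not part of this diff):

    import os

    os.environ["PADDLE_TRAINER_ID"] = "0"   # this process becomes the chief
    # a process started with PADDLE_TRAINER_ID=3 keeps self.chief == False
    # and therefore never writes checkpoints in _save_checkpoint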
@@ -263,6 +284,14 @@
             exe = executor.Executor(self.place)
             io.save_persistables(exe, dirname=param_path)
 
+    def _save_checkpoint(self):
+        if self.checkpoint and self.chief:
+            exe = executor.Executor(self.place)
+            io.save_checkpoint(exe, self.checkpoint.checkpoint_dir,
+                               self.checkpoint.max_num_checkpoints,
+                               self.checkpoint.save_interval_secs,
+                               self.train_program)
+
     @contextlib.contextmanager
     def _prog_and_scope_guard(self):
         with framework.program_guard(
@@ -309,6 +338,7 @@
                 else:
                     metrics = exe.run(feed=data, fetch_list=[])
                 event_handler(EndStepEvent(epoch_id, step_id, metrics))
+                self._save_checkpoint()
             event_handler(EndEpochEvent(epoch_id))
 
     def _test_by_executor(self, reader, feed_order, fetch_list):
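With the call above, _save_checkpoint runs after every training step; the fact that save_interval_secs and max_num_checkpoints are handed straight to io.save_checkpoint suggests the io layer decides when a write is actually due and prunes old checkpoints beyond the limit. Treating that as the intended contract (this diff only implies it), the cadence under the defaults works out as:

    # save_interval_secs=600, max_num_checkpoints=3:
    # a 1-hour run writes at most 3600 / 600 = 6 checkpoints,
    # of which only the 3 newest checkpoint directories survive on disk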