Commit b044724d authored by tangwei12

update fluid Train API param_path to checkpoint_config

Parent e901de66
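The commit replaces the Trainer's param_path argument with a checkpoint_config object, moving from one-shot parameter loading to periodic checkpointing. Below is a minimal before/after sketch of a call site; it assumes the usual "import paddle.fluid as fluid" alias and that the trainer module's exports are visible under the fluid namespace, and the train_network and optimizer objects are illustrative placeholders that do not appear in this diff:

import paddle.fluid as fluid

# Old API (removed by this commit): load persistables from a bare path.
# trainer = fluid.Trainer(
#     train_func=train_network, optimizer=optimizer, param_path="./params")

# New API: describe checkpointing with a CheckpointConfig instead.
config = fluid.CheckpointConfig(
    checkpoint_dir="./checkpoints",  # None falls back to os.getcwd()
    max_num_checkpoints=3,           # keep at most 3 checkpoints on disk
    save_interval_secs=600)          # save at most once every 600 seconds

trainer = fluid.Trainer(
    train_func=train_network,  # hypothetical program-building function
    optimizer=optimizer,       # hypothetical fluid optimizer instance
    checkpoint_config=config)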
@@ -27,11 +27,8 @@ import parallel_executor
 from transpiler import distribute_transpiler
 
 __all__ = [
-    'Trainer',
-    'BeginEpochEvent',
-    'EndEpochEvent',
-    'BeginStepEvent',
-    'EndStepEvent',
+    'Trainer', 'BeginEpochEvent', 'EndEpochEvent', 'BeginStepEvent',
+    'EndStepEvent', 'CheckpointConfig'
 ]
@@ -59,6 +56,17 @@ class EndStepEvent(object):
         self.metrics = metrics
 
+class CheckpointConfig(object):
+    def __init__(self,
+                 checkpoint_dir=None,
+                 max_num_checkpoints=3,
+                 save_interval_secs=600):
+        # default to the working directory; otherwise keep the
+        # caller-supplied directory
+        if checkpoint_dir is None:
+            self.checkpoint_dir = os.getcwd()
+        else:
+            self.checkpoint_dir = checkpoint_dir
+        self.max_num_checkpoints = max_num_checkpoints
+        self.save_interval_secs = save_interval_secs
+
+
 def check_and_get_place(place):
     """
     Check the type of place or get the default place
@@ -97,9 +105,9 @@ class Trainer(object):
     def __init__(self,
                  train_func,
                  optimizer,
-                 param_path=None,
                  place=None,
-                 parallel=False):
+                 parallel=False,
+                 checkpoint_config=None):
         self.__stop = False
         self.parallel = parallel
         # 1. we need to generate a framework.Program by calling
@@ -108,6 +116,16 @@
         if not isinstance(optimizer, opt_module.Optimizer):
             raise TypeError("The optimizer should be an instance of Optimizer")
 
+        # config for checkpoint
+        # only chief worker will save variables
+        self.chief = True
+        self.checkpoint = checkpoint_config
+        if self.checkpoint and not isinstance(self.checkpoint,
+                                              CheckpointConfig):
+            raise TypeError(
+                "The checkpoint_config should be an instance of CheckpointConfig"
+            )
+
         self.scope = core.Scope()
         self.startup_program = framework.Program()
@@ -136,9 +154,10 @@
             exe = executor.Executor(place)
             exe.run(self.startup_program)
 
-            if param_path:
-                # load params from param_path into scope
-                io.load_persistables(exe, dirname=param_path)
+            if self.checkpoint:
+                exe = executor.Executor(place)
+                io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
+                                   self.startup_program)
 
     def _transpile_nccl2_dist(self):
         # PADDLE_TRAINER_IPS
@@ -146,6 +165,7 @@
             self.nccl_id_var = None
         else:
            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+            self.chief = self.trainer_id == 0
             port = os.getenv("PADDLE_PSERVER_PORT")
             worker_ips = os.getenv("PADDLE_TRAINER_IPS")
             worker_endpoints = []
@@ -194,6 +214,7 @@
         # the unique trainer id, starting from 0, needed by trainer
         # only
         trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+        self.chief = trainer_id == 0
         # the role, should be either PSERVER or TRAINER
         training_role = os.getenv("PADDLE_TRAINING_ROLE")
         with self._prog_and_scope_guard():
@@ -263,6 +284,14 @@
             exe = executor.Executor(self.place)
             io.save_persistables(exe, dirname=param_path)
 
+    def _save_checkpoint(self):
+        if self.checkpoint and self.chief:
+            exe = executor.Executor(self.place)
+            io.save_checkpoint(exe, self.checkpoint.checkpoint_dir,
+                               self.checkpoint.max_num_checkpoints,
+                               self.checkpoint.save_interval_secs,
+                               self.train_program)
+
     @contextlib.contextmanager
     def _prog_and_scope_guard(self):
         with framework.program_guard(
@@ -309,6 +338,7 @@
                 else:
                     metrics = exe.run(feed=data, fetch_list=[])
                     event_handler(EndStepEvent(epoch_id, step_id, metrics))
+            self._save_checkpoint()
             event_handler(EndEpochEvent(epoch_id))
 
     def _test_by_executor(self, reader, feed_order, fetch_list):
...
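A note on the chief-worker guard added above: _save_checkpoint writes only when self.chief is true, and the chief is the trainer whose PADDLE_TRAINER_ID is 0 (a single-process run stays chief by default). A standalone sketch of that election pattern follows; only the environment variable name comes from this diff, the rest is illustrative:

import os

def is_chief():
    # Trainer 0 is elected chief; when the env var is absent we assume
    # a single-process run, which is its own chief.
    return int(os.getenv("PADDLE_TRAINER_ID", "0")) == 0

# Only the chief persists checkpoints, so concurrent workers never
# race to write the same checkpoint directory.
if is_chief():
    print("this worker would call io.save_checkpoint(...)")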