提交 958ab99e 编写于 作者: Y yuyang18

Polish Non-Layer API

上级 16a0f746
...@@ -38,6 +38,13 @@ class BeginEpochEvent(object): ...@@ -38,6 +38,13 @@ class BeginEpochEvent(object):
class EndEpochEvent(object): class EndEpochEvent(object):
"""
The end of a training epoch.
Args:
epoch_id(int): The current epoch ID.
"""
def __init__(self, epoch_id): def __init__(self, epoch_id):
self.epoch = epoch_id self.epoch = epoch_id
...@@ -50,6 +57,16 @@ class BeginStepEvent(object): ...@@ -50,6 +57,16 @@ class BeginStepEvent(object):
class EndStepEvent(object): class EndStepEvent(object):
"""
The end of a training step.
Args:
epoch_id(int): The current epoch ID.
step_id(int): The current step ID.
metrics(list): A list of fetched tensor. The order of this list is same
as the :code:`train_func` returns.
"""
def __init__(self, epoch_id, step_id, metrics): def __init__(self, epoch_id, step_id, metrics):
self.epoch = epoch_id self.epoch = epoch_id
self.step = step_id self.step = step_id
...@@ -57,6 +74,27 @@ class EndStepEvent(object): ...@@ -57,6 +74,27 @@ class EndStepEvent(object):
class CheckpointConfig(object): class CheckpointConfig(object):
"""
Parameter object for :code:`fluid.io.save_checkpoint` and
:code:`fluid.Trainer`. Used to configuration how to save checkpoint.
Args:
checkpoint_dir(str): Directory path to save check point. Default is the
current directory.
max_num_checkpoints(int): The max number of local check points.
epoch_interval(int): Every number of epoch to save check point.
step_interval(int): Every number of step to save check point.
Examples:
>>> config = fluid.CheckpointConfig("./checkpoints")
>>> trainer = fluid.Trainer(train_func=train_program,
>>> place=place,
>>> optimizer_func=optimizer_func,
>>> checkpoint_config=config)
>>> trainer.train(...)
"""
def __init__(self, def __init__(self,
checkpoint_dir=None, checkpoint_dir=None,
max_num_checkpoints=3, max_num_checkpoints=3,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册