From 958ab99ef86d329b9ecfef58cffa13791435c3c8 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Sun, 17 Jun 2018 14:32:34 +0800 Subject: [PATCH] Polish Non-Layer API --- python/paddle/fluid/trainer.py | 38 ++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index efc28d89930..2373cff2225 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -38,6 +38,13 @@ class BeginEpochEvent(object): class EndEpochEvent(object): + """ + The end of a training epoch. + + Args: + epoch_id(int): The current epoch ID. + """ + def __init__(self, epoch_id): self.epoch = epoch_id @@ -50,6 +57,16 @@ class BeginStepEvent(object): class EndStepEvent(object): + """ + The end of a training step. + + Args: + epoch_id(int): The current epoch ID. + step_id(int): The current step ID. + metrics(list): A list of fetched tensor. The order of this list is same + as the :code:`train_func` returns. + """ + def __init__(self, epoch_id, step_id, metrics): self.epoch = epoch_id self.step = step_id @@ -57,6 +74,27 @@ class EndStepEvent(object): class CheckpointConfig(object): + """ + Parameter object for :code:`fluid.io.save_checkpoint` and + :code:`fluid.Trainer`. Used to configuration how to save checkpoint. + + Args: + checkpoint_dir(str): Directory path to save check point. Default is the + current directory. + + max_num_checkpoints(int): The max number of local check points. + epoch_interval(int): Every number of epoch to save check point. + step_interval(int): Every number of step to save check point. + + Examples: + >>> config = fluid.CheckpointConfig("./checkpoints") + >>> trainer = fluid.Trainer(train_func=train_program, + >>> place=place, + >>> optimizer_func=optimizer_func, + >>> checkpoint_config=config) + >>> trainer.train(...) + """ + def __init__(self, checkpoint_dir=None, max_num_checkpoints=3, -- GitLab