提交 f97c5d4c 编写于 作者: Y yuyang18

Trainer documentation

上级 08995ac9
...@@ -151,11 +151,62 @@ def check_and_get_place(place): ...@@ -151,11 +151,62 @@ def check_and_get_place(place):
class Trainer(object): class Trainer(object):
""" """
A trainer wraps MultiGPU/MultiNode training loops and can be used to train a
simple neural network easily.
This API takes a :code:`train_func`. A :code:`train_func` is a function that
return loss as it first return value. The reset value can be fetched by
EndStepEvent.metrics
This API also takes a :code:`optimizer_func` that will return an optimizer
instance.
For example, to train a MLP for MNIST dataset, the sample program is
>>> import paddle.fluid as fluid
>>>
>>> def mlp(image, layer_sizes=[200, 100], activation="relu", num_classes=10):
>>> hidden = image
>>> for layer_size in layer_sizes:
>>> hidden = fluid.layers.fc(input=hidden, size=layer_size, act=activation)
>>> return fluid.layers.fc(input=hidden, size=num_classes, act="softmax")
>>>
>>> def train_mnist_mlp():
>>> img = fluid.layers.data(name='image', shape=[784])
>>> label = fluid.layers.data(name='label', shape=[1], dtype='int64')
>>> prediction = mlp(img)
>>> return fluid.layers.mean(fluid.layers.cross_entropy(prediction, label))
>>>
>>> def optimizer():
>>> return fluid.optimizer.Adam()
>>>
>>> trainer = Trainer(train_func=train_mnist_mlp,
>>> optimizer_func=optimizer,
>>> place=fluid.CUDAPlace(0),
>>> parallel=True)
>>>
>>> def train_callback(event):
>>> if isinstance(event, fluid.EndStepEvent):
>>> print "Epoch ID", event.epoch, "Step ID",\
>>> event.step, "AvgLoss", event.metrics[0]
>>> elif isinstance(event, fluid.EndEpochEvent):
>>> trainer.save_params("./model_{0}".format(event.epoch))
>>>
>>> trainer.train(num_epochs=100, event_handler=train_callback)
For more example, please see :ref:`api_guide_high_level_api`.
Args: Args:
train_func(callable): A function which will return loss. The loss must be a scalar. train_func(callable): A function which will return loss. The loss must be
a scalar tensor.
optimizer_func(callable): A function that returns an Optimizer object. optimizer_func(callable): A function that returns an Optimizer object.
place: The device place of this trainer. place(CUDAPlace|CPUPlace): The device place of this trainer. If
:code:`parallel=True,` all CUDA Places will be used if :code:`place`
is a :code:`CUDAPlace`.
parallel(bool): True if use multiple devices.
checkpoint_config(CheckpointConfig): Configuration about how to save
checkpoints.
""" """
def __init__(self, def __init__(self,
...@@ -167,9 +218,6 @@ class Trainer(object): ...@@ -167,9 +218,6 @@ class Trainer(object):
checkpoint_config=None): checkpoint_config=None):
self.__stop = False self.__stop = False
self.parallel = parallel self.parallel = parallel
# 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in
# test_word2vec.py
# config for checkpoint # config for checkpoint
# only chief worker will save variables # only chief worker will save variables
...@@ -183,6 +231,10 @@ class Trainer(object): ...@@ -183,6 +231,10 @@ class Trainer(object):
self.scope = core.Scope() self.scope = core.Scope()
# 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in
# test_word2vec.py
self.startup_program = framework.Program() self.startup_program = framework.Program()
self.train_program = framework.Program() self.train_program = framework.Program()
...@@ -315,17 +367,18 @@ class Trainer(object): ...@@ -315,17 +367,18 @@ class Trainer(object):
def train(self, num_epochs, event_handler, reader=None, feed_order=None): def train(self, num_epochs, event_handler, reader=None, feed_order=None):
""" """
Train the model. Start the train loop to train the model.
Args: Args:
num_epochs: The number of epoch. An epoch will process all data in reader num_epochs: The number of epoch. An epoch will process all data in reader
event_handler: The event handler. A function with type (ev:Event)->void event_handler: The event handler. A function with type (ev:Event)->void
reader: reader: A reader creator object. See also
:ref:`api_guide_python_reader` .
feed_order: Feeding order of reader. None will following the defining feed_order: Feeding order of reader. None will following the defining
order in program order in program
Returns: Returns:
None
""" """
training_role = os.getenv("PADDLE_TRAINING_ROLE", "") training_role = os.getenv("PADDLE_TRAINING_ROLE", "")
if training_role == "PSERVER": if training_role == "PSERVER":
...@@ -354,7 +407,14 @@ class Trainer(object): ...@@ -354,7 +407,14 @@ class Trainer(object):
self.train_func_outputs) self.train_func_outputs)
def save_params(self, param_path): def save_params(self, param_path):
# reference: save_persistables in io.py """
Save all parameters into :code:`param_path`
Args:
param_path(str): The path to save parameters
Returns:
None
"""
with self._prog_and_scope_guard(): with self._prog_and_scope_guard():
exe = executor.Executor(self.place) exe = executor.Executor(self.place)
io.save_persistables(exe, dirname=param_path) io.save_persistables(exe, dirname=param_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册