提交 dc7f2031 编写于 作者: Y Yu Yang

Add comments for functions

上级 ce49124d
......@@ -16,6 +16,10 @@ class BaseEvent(object):
class CompleteTrainOneBatch(BaseEvent):
"""
Event On One Batch Training Complete.
"""
def __init__(self, pass_id, batch_id, cost):
self.pass_id = pass_id
self.batch_id = batch_id
......@@ -38,6 +42,11 @@ class ITrainer(object):
class SGDTrainer(ITrainer):
def __init__(self, update_equation):
"""
Simple SGD Trainer.
:param update_equation: Maybe we should give a DSL for update equation?
"""
if not isinstance(update_equation, paddle.v2.optimizer.Optimizer):
raise ValueError()
......@@ -52,6 +61,21 @@ class SGDTrainer(ITrainer):
event_handler=None,
batch_size=32,
data_types=None):
"""
Training method. Will train num_passes of input data.
:param train_data_reader:
:param topology: Network Topology, a protobuf ModelConfig message.
:param parameters: The parameter pools.
:param num_passes: The total train passes.
:param test_data_reader:
:param event_handler: Event handler. A method will be invoked when event
occurred.
:type event_handler: (BaseEvent) => None
:param batch_size: Not important, will be removed after data refactor.
:param data_types: Not important, will be removed after data refactor.
:return:
"""
if event_handler is None:
event_handler = default_event_handler
......@@ -66,6 +90,9 @@ class SGDTrainer(ITrainer):
assert isinstance(updater, api.ParameterUpdater)
updater.init(gm)
gm.start()
out_args = api.Arguments.createArguments(0)
data_types_lists = []
for each in topology.input_layer_names:
if each not in data_types:
......@@ -74,22 +101,11 @@ class SGDTrainer(ITrainer):
converter = DataProviderConverter(input_types=data_types_lists)
def input_reorder(func):
for item in func():
retv = []
for __layer_name__ in topology.input_layer_names:
retv.append(item[__layer_name__])
yield retv
gm.start()
out_args = api.Arguments.createArguments(0)
for pass_id in xrange(num_passes):
updater.startPass()
for batch_id, data_batch in enumerate(
__generator_to_batch__(
input_reorder(train_data_reader),
batch_size=batch_size)):
__data_reader_to_batch__(train_data_reader, batch_size,
topology)):
pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(converter(data_batch), out_args, pass_type)
for each_param in gm.getParameters():
......@@ -108,7 +124,25 @@ class SGDTrainer(ITrainer):
gm.finish()
def __data_reader_to_batch__(reader, batch_size, topology):
"""
This function is not important, and will be removed when data refactored.
"""
def input_reorder(func):
for item in func():
retv = []
for __layer_name__ in topology.input_layer_names:
retv.append(item[__layer_name__])
yield retv
return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)
def __generator_to_batch__(generator, batch_size):
"""
This function is not important, and will be removed when data refactored.
"""
ret_val = list()
for each_item in generator:
ret_val.append(each_item)
......@@ -139,6 +173,9 @@ def __copy_parameter_from_pool__(gm, pool):
def __check_train_args__(train_data_reader, topology, parameters,
test_data_reader, event_handler, **kwargs):
"""
Check train function's argument types
"""
if not callable(train_data_reader) or not isinstance(train_data_reader(),
collections.Iterator):
raise ValueError('train_data_reader should be a function, '
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册