提交 dc7f2031 编写于 作者: Y Yu Yang

Add comments for functions

上级 ce49124d
...@@ -16,6 +16,10 @@ class BaseEvent(object): ...@@ -16,6 +16,10 @@ class BaseEvent(object):
class CompleteTrainOneBatch(BaseEvent): class CompleteTrainOneBatch(BaseEvent):
"""
Event On One Batch Training Complete.
"""
def __init__(self, pass_id, batch_id, cost): def __init__(self, pass_id, batch_id, cost):
self.pass_id = pass_id self.pass_id = pass_id
self.batch_id = batch_id self.batch_id = batch_id
...@@ -38,6 +42,11 @@ class ITrainer(object): ...@@ -38,6 +42,11 @@ class ITrainer(object):
class SGDTrainer(ITrainer): class SGDTrainer(ITrainer):
def __init__(self, update_equation): def __init__(self, update_equation):
"""
Simple SGD Trainer.
:param update_equation: Maybe we should give a DSL for update equation?
"""
if not isinstance(update_equation, paddle.v2.optimizer.Optimizer): if not isinstance(update_equation, paddle.v2.optimizer.Optimizer):
raise ValueError() raise ValueError()
...@@ -52,6 +61,21 @@ class SGDTrainer(ITrainer): ...@@ -52,6 +61,21 @@ class SGDTrainer(ITrainer):
event_handler=None, event_handler=None,
batch_size=32, batch_size=32,
data_types=None): data_types=None):
"""
Training method. Will train num_passes of input data.
:param train_data_reader:
:param topology: Network Topology, a protobuf ModelConfig message.
:param parameters: The parameter pools.
:param num_passes: The total train passes.
:param test_data_reader:
:param event_handler: Event handler. A method will be invoked when event
occurred.
:type event_handler: (BaseEvent) => None
:param batch_size: Not important, will be removed after data refactor.
:param data_types: Not important, will be removed after data refactor.
:return:
"""
if event_handler is None: if event_handler is None:
event_handler = default_event_handler event_handler = default_event_handler
...@@ -66,6 +90,9 @@ class SGDTrainer(ITrainer): ...@@ -66,6 +90,9 @@ class SGDTrainer(ITrainer):
assert isinstance(updater, api.ParameterUpdater) assert isinstance(updater, api.ParameterUpdater)
updater.init(gm) updater.init(gm)
gm.start()
out_args = api.Arguments.createArguments(0)
data_types_lists = [] data_types_lists = []
for each in topology.input_layer_names: for each in topology.input_layer_names:
if each not in data_types: if each not in data_types:
...@@ -74,22 +101,11 @@ class SGDTrainer(ITrainer): ...@@ -74,22 +101,11 @@ class SGDTrainer(ITrainer):
converter = DataProviderConverter(input_types=data_types_lists) converter = DataProviderConverter(input_types=data_types_lists)
def input_reorder(func):
for item in func():
retv = []
for __layer_name__ in topology.input_layer_names:
retv.append(item[__layer_name__])
yield retv
gm.start()
out_args = api.Arguments.createArguments(0)
for pass_id in xrange(num_passes): for pass_id in xrange(num_passes):
updater.startPass() updater.startPass()
for batch_id, data_batch in enumerate( for batch_id, data_batch in enumerate(
__generator_to_batch__( __data_reader_to_batch__(train_data_reader, batch_size,
input_reorder(train_data_reader), topology)):
batch_size=batch_size)):
pass_type = updater.startBatch(len(data_batch)) pass_type = updater.startBatch(len(data_batch))
gm.forwardBackward(converter(data_batch), out_args, pass_type) gm.forwardBackward(converter(data_batch), out_args, pass_type)
for each_param in gm.getParameters(): for each_param in gm.getParameters():
...@@ -108,7 +124,25 @@ class SGDTrainer(ITrainer): ...@@ -108,7 +124,25 @@ class SGDTrainer(ITrainer):
gm.finish() gm.finish()
def __data_reader_to_batch__(reader, batch_size, topology):
"""
This function is not important, and will be removed when data refactored.
"""
def input_reorder(func):
for item in func():
retv = []
for __layer_name__ in topology.input_layer_names:
retv.append(item[__layer_name__])
yield retv
return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)
def __generator_to_batch__(generator, batch_size): def __generator_to_batch__(generator, batch_size):
"""
This function is not important, and will be removed when data refactored.
"""
ret_val = list() ret_val = list()
for each_item in generator: for each_item in generator:
ret_val.append(each_item) ret_val.append(each_item)
...@@ -139,6 +173,9 @@ def __copy_parameter_from_pool__(gm, pool): ...@@ -139,6 +173,9 @@ def __copy_parameter_from_pool__(gm, pool):
def __check_train_args__(train_data_reader, topology, parameters, def __check_train_args__(train_data_reader, topology, parameters,
test_data_reader, event_handler, **kwargs): test_data_reader, event_handler, **kwargs):
"""
Check train function's argument types
"""
if not callable(train_data_reader) or not isinstance(train_data_reader(), if not callable(train_data_reader) or not isinstance(train_data_reader(),
collections.Iterator): collections.Iterator):
raise ValueError('train_data_reader should be a function, ' raise ValueError('train_data_reader should be a function, '
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册