diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 933700c615527773710bdde3c6e2beca8bb37139..3794ab07720e217e1b9ee41d3188675a713b62b2 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -16,6 +16,10 @@ class BaseEvent(object): class CompleteTrainOneBatch(BaseEvent): + """ + Event On One Batch Training Complete. + """ + def __init__(self, pass_id, batch_id, cost): self.pass_id = pass_id self.batch_id = batch_id @@ -38,6 +42,11 @@ class ITrainer(object): class SGDTrainer(ITrainer): def __init__(self, update_equation): + """ + Simple SGD Trainer. + + :param update_equation: Maybe we should give a DSL for update equation? + """ if not isinstance(update_equation, paddle.v2.optimizer.Optimizer): raise ValueError() @@ -52,6 +61,21 @@ class SGDTrainer(ITrainer): event_handler=None, batch_size=32, data_types=None): + """ + Training method. Will train num_passes of input data. + + :param train_data_reader: + :param topology: Network Topology, a protobuf ModelConfig message. + :param parameters: The parameter pools. + :param num_passes: The total train passes. + :param test_data_reader: + :param event_handler: Event handler. A method will be invoked when event + occurred. + :type event_handler: (BaseEvent) => None + :param batch_size: Not important, will be removed after data refactor. + :param data_types: Not important, will be removed after data refactor. + :return: + """ if event_handler is None: event_handler = default_event_handler @@ -66,6 +90,9 @@ class SGDTrainer(ITrainer): assert isinstance(updater, api.ParameterUpdater) updater.init(gm) + gm.start() + out_args = api.Arguments.createArguments(0) + data_types_lists = [] for each in topology.input_layer_names: if each not in data_types: @@ -74,22 +101,11 @@ class SGDTrainer(ITrainer): converter = DataProviderConverter(input_types=data_types_lists) - def input_reorder(func): - for item in func(): - retv = [] - for __layer_name__ in topology.input_layer_names: - retv.append(item[__layer_name__]) - yield retv - - gm.start() - - out_args = api.Arguments.createArguments(0) for pass_id in xrange(num_passes): updater.startPass() for batch_id, data_batch in enumerate( - __generator_to_batch__( - input_reorder(train_data_reader), - batch_size=batch_size)): + __data_reader_to_batch__(train_data_reader, batch_size, + topology)): pass_type = updater.startBatch(len(data_batch)) gm.forwardBackward(converter(data_batch), out_args, pass_type) for each_param in gm.getParameters(): @@ -108,7 +124,25 @@ class SGDTrainer(ITrainer): gm.finish() +def __data_reader_to_batch__(reader, batch_size, topology): + """ + This function is not important, and will be removed when data refactored. + """ + + def input_reorder(func): + for item in func(): + retv = [] + for __layer_name__ in topology.input_layer_names: + retv.append(item[__layer_name__]) + yield retv + + return __generator_to_batch__(input_reorder(reader), batch_size=batch_size) + + def __generator_to_batch__(generator, batch_size): + """ + This function is not important, and will be removed when data refactored. + """ ret_val = list() for each_item in generator: ret_val.append(each_item) @@ -139,6 +173,9 @@ def __copy_parameter_from_pool__(gm, pool): def __check_train_args__(train_data_reader, topology, parameters, test_data_reader, event_handler, **kwargs): + """ + Check train function's argument types + """ if not callable(train_data_reader) or not isinstance(train_data_reader(), collections.Iterator): raise ValueError('train_data_reader should be a function, '