diff --git a/python/paddle/v2/data_feeder.py b/python/paddle/v2/data_feeder.py index 2a0b6bbeb563f7b077706f5fd49306eae292c19a..83a4efef9e7a7bab492279105ef0c425e1d355c7 100644 --- a/python/paddle/v2/data_feeder.py +++ b/python/paddle/v2/data_feeder.py @@ -15,5 +15,46 @@ from py_paddle import DataProviderConverter __all__ = ['DataFeeder'] +""" +DataFeeder converts the data returned by paddle.reader into a data structure +of Arguments which is defined in the API. The paddle.reader usually returns +a list of mini-batch data. Each item in the list is a tuple or list, which is +one sample with multiple features. DataFeeder converts this mini-batch data +into Arguments in order to feed it to the C++ interface. + +Example usage: + + data_types = [paddle.data_type.dense_vector(784), + paddle.data_type.integer_value(10)] + feeder = DataFeeder(input_types=data_types) + minibatch_data = [ + ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ), # first sample + ( [1.0,2.0,3.0,4.0], 5, [6,7,8] ) # second sample + ] + + # or + # minibatch_data = [ + # [ [1.0,2.0,3.0,4.0], 5, [6,7,8] ], # first sample + # [ [1.0,2.0,3.0,4.0], 5, [6,7,8] ] # second sample + # ] + arg = feeder(minibatch_data) + + +Args: + input_types: A list of input data types. Its length is equal to the length + of data returned by paddle.reader. Each item specifies the type + of each feature. + minibatch_data: A list of mini-batch data. Each item is a list or tuple, + for example: + [ + (feature_0, feature_1, feature_2, ...), # first sample + (feature_0, feature_1, feature_2, ...), # second sample + ... + ] + +Returns: + An Arguments object that contains this mini-batch data with multiple features. + The Arguments definition is in the API.
+""" DataFeeder = DataProviderConverter diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 7480a3fb84bbd2abe9d7ff4cbed743bb0470e0e8..5709c7e886c323b9159f18a52133770ea675fa5b 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -89,7 +89,6 @@ class SGD(ITrainer): event_handler = default_event_handler topology = v2_layer.parse_network(topology) - print topology __check_train_args__(**locals())