From 4d81b361230fa157ee35ed1de185f2177fcce095 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Dec 2016 12:57:19 +0800 Subject: [PATCH] A tiny fix in PyDataProvider2 * hidden decorator kwargs in DataProvider.__init__ * also add unit test for this. --- paddle/gserver/tests/test_PyDataProvider2.py | 2 +- python/paddle/trainer/PyDataProvider2.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/paddle/gserver/tests/test_PyDataProvider2.py b/paddle/gserver/tests/test_PyDataProvider2.py index f7b540013e7..2e6225519f4 100644 --- a/paddle/gserver/tests/test_PyDataProvider2.py +++ b/paddle/gserver/tests/test_PyDataProvider2.py @@ -17,7 +17,7 @@ import random from paddle.trainer.PyDataProvider2 import * -@provider(input_types=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)]) +@provider(slots=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)]) def test_dense_no_seq(setting, filename): for i in xrange(200): yield [(float(j - 100) * float(i + 1)) / 200.0 for j in xrange(200)] diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py index de266bb5d3d..2da592bd9d4 100644 --- a/python/paddle/trainer/PyDataProvider2.py +++ b/python/paddle/trainer/PyDataProvider2.py @@ -232,7 +232,7 @@ def provider(input_types=None, check=False, check_fail_continue=False, init_hook=None, - **kwargs): + **outter_kwargs): """ Provider decorator. Use it to make a function into PyDataProvider2 object. In this function, user only need to get each sample for some train/test @@ -318,11 +318,6 @@ def provider(input_types=None, self.logger = logging.getLogger("") self.logger.setLevel(logging.INFO) self.input_types = None - if 'slots' in kwargs: - self.logger.warning('setting slots value is deprecated, ' - 'please use input_types instead.') - self.slots = kwargs['slots'] - self.slots = input_types self.should_shuffle = should_shuffle true_table = [1, 't', 'true', 'on'] @@ -358,9 +353,19 @@ def provider(input_types=None, self.check = check if init_hook is not None: init_hook(self, file_list=file_list, **kwargs) + + if 'slots' in outter_kwargs: + self.logger.warning('setting slots value is deprecated, ' + 'please use input_types instead.') + self.slots = outter_kwargs['slots'] + if input_types is not None: + self.slots = input_types + if self.input_types is not None: self.slots = self.input_types - assert self.slots is not None + + assert self.slots is not None, \ + "Data Provider's input_types must be set" assert self.generator is not None use_dynamic_order = False -- GitLab