提交 4d81b361 编写于 作者: Y Yu Yang

A tiny fix in PyDataProvider2

* hidden decorator kwargs in DataProvider.__init__
* also add unit test for this.
上级 2965df51
...@@ -17,7 +17,7 @@ import random ...@@ -17,7 +17,7 @@ import random
from paddle.trainer.PyDataProvider2 import * from paddle.trainer.PyDataProvider2 import *
@provider(input_types=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)]) @provider(slots=[dense_vector(200, seq_type=SequenceType.NO_SEQUENCE)])
def test_dense_no_seq(setting, filename): def test_dense_no_seq(setting, filename):
for i in xrange(200): for i in xrange(200):
yield [(float(j - 100) * float(i + 1)) / 200.0 for j in xrange(200)] yield [(float(j - 100) * float(i + 1)) / 200.0 for j in xrange(200)]
......
...@@ -232,7 +232,7 @@ def provider(input_types=None, ...@@ -232,7 +232,7 @@ def provider(input_types=None,
check=False, check=False,
check_fail_continue=False, check_fail_continue=False,
init_hook=None, init_hook=None,
**kwargs): **outter_kwargs):
""" """
Provider decorator. Use it to make a function into PyDataProvider2 object. Provider decorator. Use it to make a function into PyDataProvider2 object.
In this function, user only need to get each sample for some train/test In this function, user only need to get each sample for some train/test
...@@ -318,11 +318,6 @@ def provider(input_types=None, ...@@ -318,11 +318,6 @@ def provider(input_types=None,
self.logger = logging.getLogger("") self.logger = logging.getLogger("")
self.logger.setLevel(logging.INFO) self.logger.setLevel(logging.INFO)
self.input_types = None self.input_types = None
if 'slots' in kwargs:
self.logger.warning('setting slots value is deprecated, '
'please use input_types instead.')
self.slots = kwargs['slots']
self.slots = input_types
self.should_shuffle = should_shuffle self.should_shuffle = should_shuffle
true_table = [1, 't', 'true', 'on'] true_table = [1, 't', 'true', 'on']
...@@ -358,9 +353,19 @@ def provider(input_types=None, ...@@ -358,9 +353,19 @@ def provider(input_types=None,
self.check = check self.check = check
if init_hook is not None: if init_hook is not None:
init_hook(self, file_list=file_list, **kwargs) init_hook(self, file_list=file_list, **kwargs)
if 'slots' in outter_kwargs:
self.logger.warning('setting slots value is deprecated, '
'please use input_types instead.')
self.slots = outter_kwargs['slots']
if input_types is not None:
self.slots = input_types
if self.input_types is not None: if self.input_types is not None:
self.slots = self.input_types self.slots = self.input_types
assert self.slots is not None
assert self.slots is not None, \
"Data Provider's input_types must be set"
assert self.generator is not None assert self.generator is not None
use_dynamic_order = False use_dynamic_order = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册