提交 79bcd058 编写于 作者: W wangyanfei01

Merge branch 'develop' of https://github.com/baidu/Paddle into fix_data_sources

...@@ -200,24 +200,26 @@ class CheckWrapper(object): ...@@ -200,24 +200,26 @@ class CheckWrapper(object):
for each in item: for each in item:
callback(each) callback(each)
class CheckInputTypeWrapper(object): class CheckInputTypeWrapper(object):
def __init__(self, generator, input_types, logger): def __init__(self, generator, input_types, logger):
self.generator = generator self.generator = generator
self.input_types = input_types self.input_types = input_types
self.logger = logger self.logger = logger
def __call__(self, obj, filename): def __call__(self, obj, filename):
for items in self.generator(obj, filename): for items in self.generator(obj, filename):
try: try:
# dict type is required for input_types when item is dict type # dict type is required for input_types when item is dict type
assert (isinstance(items, dict) and \ assert (isinstance(items, dict) and \
not isinstance(self.input_types, dict))==False not isinstance(self.input_types, dict))==False
yield items yield items
except AssertionError as e: except AssertionError as e:
self.logger.error( self.logger.error(
"%s type is required for input type but got %s" % "%s type is required for input type but got %s" %
(repr(type(items)), repr(type(self.input_types)))) (repr(type(items)), repr(type(self.input_types))))
raise raise
def provider(input_types=None, def provider(input_types=None,
should_shuffle=None, should_shuffle=None,
...@@ -372,8 +374,8 @@ def provider(input_types=None, ...@@ -372,8 +374,8 @@ def provider(input_types=None,
self.generator = InputOrderWrapper(self.generator, self.generator = InputOrderWrapper(self.generator,
self.input_order) self.input_order)
else: else:
self.generator = CheckInputTypeWrapper(self.generator, self.slots, self.generator = CheckInputTypeWrapper(
self.logger) self.generator, self.slots, self.logger)
if self.check: if self.check:
self.generator = CheckWrapper(self.generator, self.slots, self.generator = CheckWrapper(self.generator, self.slots,
check_fail_continue, check_fail_continue,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册