From 8e9ac0cc5541654aa44bf9b07a679c9b335f1f95 Mon Sep 17 00:00:00 2001 From: zhanghaichao Date: Fri, 2 Dec 2016 04:36:56 -0800 Subject: [PATCH] adding input type check for data provider --- python/paddle/trainer/PyDataProvider2.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py index 0c577ec657b..081e5408482 100644 --- a/python/paddle/trainer/PyDataProvider2.py +++ b/python/paddle/trainer/PyDataProvider2.py @@ -202,6 +202,24 @@ class CheckWrapper(object): for each in item: callback(each) +class CheckInputTypeWrapper(object): + def __init__(self, generator, input_types, logger): + self.generator = generator + self.input_types = input_types + self.logger = logger + + def __call__(self, obj, filename): + for items in self.generator(obj, filename): + try: + # dict type is required for input_types when item is dict type + assert (isinstance(items, dict) and \ + not isinstance(self.input_types, dict))==False + yield items + except AssertionError as e: + self.logger.error( + "%s type is required for input type but got %s" % + (repr(type(items)), repr(type(self.input_types)))) + raise def provider(input_types=None, should_shuffle=None, @@ -355,6 +373,9 @@ def provider(input_types=None, if use_dynamic_order: self.generator = InputOrderWrapper(self.generator, self.input_order) + else: + self.generator = CheckInputTypeWrapper(self.generator, self.slots, + self.logger) if self.check: self.generator = CheckWrapper(self.generator, self.slots, check_fail_continue, -- GitLab