From 7e99af556fca2a7179dc77193754c4b8831b5ac4 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 25 May 2017 20:07:30 +0800 Subject: [PATCH] follow comments --- paddle/py_paddle/dataprovider_converter.py | 42 ++++++++++++---------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/paddle/py_paddle/dataprovider_converter.py b/paddle/py_paddle/dataprovider_converter.py index cfb82e92d51..6234dc65dcd 100644 --- a/paddle/py_paddle/dataprovider_converter.py +++ b/paddle/py_paddle/dataprovider_converter.py @@ -60,7 +60,7 @@ class IScanner(object): """ pass - def finish_pre_scan(self, argument, dat=None): + def finish_pre_scan(self, argument): """ Finish first scan pass. Allocate the memory. @@ -103,23 +103,29 @@ class DenseScanner(IScanner): def pre_scan(self, dat): self.__height__ += 1 + if self.__shape__ is None: + self.__shape__ = numpy.array(dat).shape + if len(self.__shape__) > 3: + raise ValueError( + "The dimension of input cannot be greater than 3.") + else: + if self.__shape__ != numpy.array(dat).shape: + raise ValueError( + "The data shape must be same in one mini-batch.") - def finish_pre_scan(self, argument, dat=None): - self.__shape__ = numpy.array(dat).shape - if len(self.__shape__) > 3: - raise ValueError("The dimension of input is greater than 3.") + def finish_pre_scan(self, argument): dim = reduce(lambda x, y: x * y, self.__shape__) - if len(self.__shape__) == 1: - assert dim == self.input_type.dim + if len(self.__shape__) == 1 and dim != self.input_type.dim: + raise ValueError("The data size must be equal to it in data layer.") self.__mat__ = numpy.ndarray( shape=(self.__height__, dim), dtype=numpy.float32) self.__height__ = 0 def scan(self, dat): - if isinstance(dat, numpy.ndarray): - assert self.__shape__ == dat.shape - dat = dat.flatten() - self.__mat__[self.__height__] = dat + # It's better to use NumPy array for speed. + d = numpy.array(dat) + d = d.flatten() + self.__mat__[self.__height__] = d self.__height__ += 1 def finish_scan(self, argument): @@ -136,6 +142,7 @@ class DenseScanner(IScanner): h, w = self.__shape__[-2:] argument.setSlotFrameHeight(self.pos, h) argument.setSlotFrameWidth(self.pos, w) + self.__shape__ = None class SparseBinaryScanner(IScanner): @@ -186,7 +193,7 @@ class IndexScanner(IScanner): def pre_scan(self, dat): self.__idx__ += 1 - def finish_pre_scan(self, argument, dat=None): + def finish_pre_scan(self, argument): self.__ids__ = [0] * self.__idx__ self.__idx__ = 0 @@ -211,8 +218,8 @@ class SequenceScanner(IScanner): for each in dat: self.__inner_scanner__.pre_scan(each) - def finish_pre_scan(self, argument, dat=None): - self.__inner_scanner__.finish_pre_scan(argument, dat) + def finish_pre_scan(self, argument): + self.__inner_scanner__.finish_pre_scan(argument) def scan(self, dat): self.__seq__.append(self.__seq__[-1] + self.get_size(dat)) @@ -253,11 +260,8 @@ class DataProviderConverter(object): for each_step, scanner in itertools.izip(each_sample, scanners): scanner.pre_scan(each_step) - # Some scanners, like dense scanner, pre-allocate memory for mini-batch - # in finish_pre_scan function. The dat[0] is used to calculate the size - # of input data. - for scanner, each_feature in itertools.izip(scanners, dat[0]): - scanner.finish_pre_scan(argument, each_feature) + for scanner in scanners: + scanner.finish_pre_scan(argument) for each_sample in dat: for each_step, scanner in itertools.izip(each_sample, scanners): -- GitLab