提交 7e99af55 编写于 作者: D dangqingqing

follow comments

上级 7430d305
......@@ -60,7 +60,7 @@ class IScanner(object):
"""
pass
def finish_pre_scan(self, argument, dat=None):
def finish_pre_scan(self, argument):
"""
Finish first scan pass. Allocate the memory.
......@@ -103,23 +103,29 @@ class DenseScanner(IScanner):
def pre_scan(self, dat):
self.__height__ += 1
def finish_pre_scan(self, argument, dat=None):
if self.__shape__ is None:
self.__shape__ = numpy.array(dat).shape
if len(self.__shape__) > 3:
raise ValueError("The dimension of input is greater than 3.")
raise ValueError(
"The dimension of input cannot be greater than 3.")
else:
if self.__shape__ != numpy.array(dat).shape:
raise ValueError(
"The data shape must be same in one mini-batch.")
def finish_pre_scan(self, argument):
dim = reduce(lambda x, y: x * y, self.__shape__)
if len(self.__shape__) == 1:
assert dim == self.input_type.dim
if len(self.__shape__) == 1 and dim != self.input_type.dim:
raise ValueError("The data size must be equal to it in data layer.")
self.__mat__ = numpy.ndarray(
shape=(self.__height__, dim), dtype=numpy.float32)
self.__height__ = 0
def scan(self, dat):
if isinstance(dat, numpy.ndarray):
assert self.__shape__ == dat.shape
dat = dat.flatten()
self.__mat__[self.__height__] = dat
# It's better to use NumPy array for speed.
d = numpy.array(dat)
d = d.flatten()
self.__mat__[self.__height__] = d
self.__height__ += 1
def finish_scan(self, argument):
......@@ -136,6 +142,7 @@ class DenseScanner(IScanner):
h, w = self.__shape__[-2:]
argument.setSlotFrameHeight(self.pos, h)
argument.setSlotFrameWidth(self.pos, w)
self.__shape__ = None
class SparseBinaryScanner(IScanner):
......@@ -186,7 +193,7 @@ class IndexScanner(IScanner):
def pre_scan(self, dat):
self.__idx__ += 1
def finish_pre_scan(self, argument, dat=None):
def finish_pre_scan(self, argument):
self.__ids__ = [0] * self.__idx__
self.__idx__ = 0
......@@ -211,8 +218,8 @@ class SequenceScanner(IScanner):
for each in dat:
self.__inner_scanner__.pre_scan(each)
def finish_pre_scan(self, argument, dat=None):
self.__inner_scanner__.finish_pre_scan(argument, dat)
def finish_pre_scan(self, argument):
self.__inner_scanner__.finish_pre_scan(argument)
def scan(self, dat):
self.__seq__.append(self.__seq__[-1] + self.get_size(dat))
......@@ -253,11 +260,8 @@ class DataProviderConverter(object):
for each_step, scanner in itertools.izip(each_sample, scanners):
scanner.pre_scan(each_step)
# Some scanners, like dense scanner, pre-allocate memory for mini-batch
# in finish_pre_scan function. The dat[0] is used to calculate the size
# of input data.
for scanner, each_feature in itertools.izip(scanners, dat[0]):
scanner.finish_pre_scan(argument, each_feature)
for scanner in scanners:
scanner.finish_pre_scan(argument)
for each_sample in dat:
for each_step, scanner in itertools.izip(each_sample, scanners):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册