提交 7e99af55 编写于 作者: D dangqingqing

follow comments

上级 7430d305
...@@ -60,7 +60,7 @@ class IScanner(object): ...@@ -60,7 +60,7 @@ class IScanner(object):
""" """
pass pass
def finish_pre_scan(self, argument, dat=None): def finish_pre_scan(self, argument):
""" """
Finish first scan pass. Allocate the memory. Finish first scan pass. Allocate the memory.
...@@ -103,23 +103,29 @@ class DenseScanner(IScanner): ...@@ -103,23 +103,29 @@ class DenseScanner(IScanner):
def pre_scan(self, dat): def pre_scan(self, dat):
self.__height__ += 1 self.__height__ += 1
if self.__shape__ is None:
self.__shape__ = numpy.array(dat).shape
if len(self.__shape__) > 3:
raise ValueError(
"The dimension of input cannot be greater than 3.")
else:
if self.__shape__ != numpy.array(dat).shape:
raise ValueError(
"The data shape must be same in one mini-batch.")
def finish_pre_scan(self, argument, dat=None): def finish_pre_scan(self, argument):
self.__shape__ = numpy.array(dat).shape
if len(self.__shape__) > 3:
raise ValueError("The dimension of input is greater than 3.")
dim = reduce(lambda x, y: x * y, self.__shape__) dim = reduce(lambda x, y: x * y, self.__shape__)
if len(self.__shape__) == 1: if len(self.__shape__) == 1 and dim != self.input_type.dim:
assert dim == self.input_type.dim raise ValueError("The data size must be equal to it in data layer.")
self.__mat__ = numpy.ndarray( self.__mat__ = numpy.ndarray(
shape=(self.__height__, dim), dtype=numpy.float32) shape=(self.__height__, dim), dtype=numpy.float32)
self.__height__ = 0 self.__height__ = 0
def scan(self, dat): def scan(self, dat):
if isinstance(dat, numpy.ndarray): # It's better to use NumPy array for speed.
assert self.__shape__ == dat.shape d = numpy.array(dat)
dat = dat.flatten() d = d.flatten()
self.__mat__[self.__height__] = dat self.__mat__[self.__height__] = d
self.__height__ += 1 self.__height__ += 1
def finish_scan(self, argument): def finish_scan(self, argument):
...@@ -136,6 +142,7 @@ class DenseScanner(IScanner): ...@@ -136,6 +142,7 @@ class DenseScanner(IScanner):
h, w = self.__shape__[-2:] h, w = self.__shape__[-2:]
argument.setSlotFrameHeight(self.pos, h) argument.setSlotFrameHeight(self.pos, h)
argument.setSlotFrameWidth(self.pos, w) argument.setSlotFrameWidth(self.pos, w)
self.__shape__ = None
class SparseBinaryScanner(IScanner): class SparseBinaryScanner(IScanner):
...@@ -186,7 +193,7 @@ class IndexScanner(IScanner): ...@@ -186,7 +193,7 @@ class IndexScanner(IScanner):
def pre_scan(self, dat): def pre_scan(self, dat):
self.__idx__ += 1 self.__idx__ += 1
def finish_pre_scan(self, argument, dat=None): def finish_pre_scan(self, argument):
self.__ids__ = [0] * self.__idx__ self.__ids__ = [0] * self.__idx__
self.__idx__ = 0 self.__idx__ = 0
...@@ -211,8 +218,8 @@ class SequenceScanner(IScanner): ...@@ -211,8 +218,8 @@ class SequenceScanner(IScanner):
for each in dat: for each in dat:
self.__inner_scanner__.pre_scan(each) self.__inner_scanner__.pre_scan(each)
def finish_pre_scan(self, argument, dat=None): def finish_pre_scan(self, argument):
self.__inner_scanner__.finish_pre_scan(argument, dat) self.__inner_scanner__.finish_pre_scan(argument)
def scan(self, dat): def scan(self, dat):
self.__seq__.append(self.__seq__[-1] + self.get_size(dat)) self.__seq__.append(self.__seq__[-1] + self.get_size(dat))
...@@ -253,11 +260,8 @@ class DataProviderConverter(object): ...@@ -253,11 +260,8 @@ class DataProviderConverter(object):
for each_step, scanner in itertools.izip(each_sample, scanners): for each_step, scanner in itertools.izip(each_sample, scanners):
scanner.pre_scan(each_step) scanner.pre_scan(each_step)
# Some scanners, like dense scanner, pre-allocate memory for mini-batch for scanner in scanners:
# in finish_pre_scan function. The dat[0] is used to calculate the size scanner.finish_pre_scan(argument)
# of input data.
for scanner, each_feature in itertools.izip(scanners, dat[0]):
scanner.finish_pre_scan(argument, each_feature)
for each_sample in dat: for each_sample in dat:
for each_step, scanner in itertools.izip(each_sample, scanners): for each_step, scanner in itertools.izip(each_sample, scanners):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册