提交 74da070e 编写于 作者: Y Yu Yang

Speed up dense converter.

上级 85189e8d
......@@ -16,6 +16,7 @@ import paddle.trainer.PyDataProvider2 as dp2
import collections
import swig_paddle
import numpy
import itertools
__all__ = ['DataProviderConverter']
......@@ -36,6 +37,12 @@ class IScanner(object):
self.data_in_gpu = swig_paddle.isUsingGpu(
) and swig_paddle.getTrainerCount() == 1
def pre_scan(self, dat):
pass
def finish_pre_scan(self, argument):
pass
def scan(self, dat):
pass
......@@ -51,12 +58,19 @@ class DenseScanner(IScanner):
def __init__(self, input_type, pos):
IScanner.__init__(self, input_type, pos)
self.__mat__ = None
self.__height__ = 0
def pre_scan(self, dat):
self.__height__ += 1
def finish_pre_scan(self, argument):
self.__mat__ = numpy.ndarray(
shape=(self.__height__, self.input_type.dim), dtype=numpy.float32)
self.__height__ = 0
def scan(self, dat):
if self.__mat__ is None:
self.__mat__ = numpy.array([dat], dtype='float32')
else:
self.__mat__ = numpy.append(self.__mat__, [dat], axis=0)
self.__mat__[self.__height__] = dat
self.__height__ += 1
def finish_scan(self, argument):
assert isinstance(argument, swig_paddle.Arguments)
......@@ -163,7 +177,14 @@ class DataProviderConverter(object):
]
for each_sample in dat:
for each_step, scanner in zip(each_sample, scanners):
for each_step, scanner in itertools.izip(each_sample, scanners):
scanner.pre_scan(each_step)
for scanner in scanners:
scanner.finish_pre_scan(argument)
for each_sample in dat:
for each_step, scanner in itertools.izip(each_sample, scanners):
scanner.scan(each_step)
for scanner in scanners:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册