From 74da070e5ca0e9fb82298320f88a437cbc06ade8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 14 Mar 2017 17:30:55 +0800 Subject: [PATCH] Speed up dense converter. --- paddle/py_paddle/dataprovider_converter.py | 31 ++++++++++++++++++---- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/paddle/py_paddle/dataprovider_converter.py b/paddle/py_paddle/dataprovider_converter.py index c009b05cdee..f1ed57f13fc 100644 --- a/paddle/py_paddle/dataprovider_converter.py +++ b/paddle/py_paddle/dataprovider_converter.py @@ -16,6 +16,7 @@ import paddle.trainer.PyDataProvider2 as dp2 import collections import swig_paddle import numpy +import itertools __all__ = ['DataProviderConverter'] @@ -36,6 +37,12 @@ class IScanner(object): self.data_in_gpu = swig_paddle.isUsingGpu( ) and swig_paddle.getTrainerCount() == 1 + def pre_scan(self, dat): + pass + + def finish_pre_scan(self, argument): + pass + def scan(self, dat): pass @@ -51,12 +58,19 @@ class DenseScanner(IScanner): def __init__(self, input_type, pos): IScanner.__init__(self, input_type, pos) self.__mat__ = None + self.__height__ = 0 + + def pre_scan(self, dat): + self.__height__ += 1 + + def finish_pre_scan(self, argument): + self.__mat__ = numpy.ndarray( + shape=(self.__height__, self.input_type.dim), dtype=numpy.float32) + self.__height__ = 0 def scan(self, dat): - if self.__mat__ is None: - self.__mat__ = numpy.array([dat], dtype='float32') - else: - self.__mat__ = numpy.append(self.__mat__, [dat], axis=0) + self.__mat__[self.__height__] = dat + self.__height__ += 1 def finish_scan(self, argument): assert isinstance(argument, swig_paddle.Arguments) @@ -163,7 +177,14 @@ class DataProviderConverter(object): ] for each_sample in dat: - for each_step, scanner in zip(each_sample, scanners): + for each_step, scanner in itertools.izip(each_sample, scanners): + scanner.pre_scan(each_step) + + for scanner in scanners: + scanner.finish_pre_scan(argument) + + for each_sample in dat: + for each_step, scanner in itertools.izip(each_sample, scanners): scanner.scan(each_step) for scanner in scanners: -- GitLab