提交 efb5c10c 编写于 作者: Y Yu Yang

Merge branch 'feature/fix_swig_dense_scanner' into feature/mnist_train_api

...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle.trainer.PyDataProvider2 as dp2 import paddle.trainer.PyDataProvider2 as dp2
import collections import collections
import swig_paddle import swig_paddle
import numpy
__all__ = ['DataProviderConverter'] __all__ = ['DataProviderConverter']
...@@ -35,18 +36,18 @@ class IScanner(object): ...@@ -35,18 +36,18 @@ class IScanner(object):
class DenseScanner(IScanner): class DenseScanner(IScanner):
def __init__(self, input_type, pos): def __init__(self, input_type, pos):
IScanner.__init__(self, input_type, pos) IScanner.__init__(self, input_type, pos)
self.__mat__ = [] self.__mat__ = None
self.__height__ = 0
def scan(self, dat): def scan(self, dat):
self.__mat__.extend(dat) if self.__mat__ is None:
self.__height__ += 1 self.__mat__ = numpy.array([dat], dtype='float32')
else:
self.__mat__ = numpy.append(self.__mat__, [dat], axis=0)
def finish_scan(self, argument): def finish_scan(self, argument):
assert isinstance(argument, swig_paddle.Arguments) assert isinstance(argument, swig_paddle.Arguments)
assert isinstance(self.input_type, dp2.InputType) assert isinstance(self.input_type, dp2.InputType)
m = swig_paddle.Matrix.createDense(self.__mat__, self.__height__, m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
self.input_type.dim, False)
argument.setSlotValue(self.pos, m) argument.setSlotValue(self.pos, m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册