提交 efb5c10c 编写于 作者: Y Yu Yang

Merge branch 'feature/fix_swig_dense_scanner' into feature/mnist_train_api

......@@ -15,6 +15,7 @@
import paddle.trainer.PyDataProvider2 as dp2
import collections
import swig_paddle
import numpy
__all__ = ['DataProviderConverter']
......@@ -35,18 +36,18 @@ class IScanner(object):
class DenseScanner(IScanner):
def __init__(self, input_type, pos):
IScanner.__init__(self, input_type, pos)
self.__mat__ = []
self.__height__ = 0
self.__mat__ = None
def scan(self, dat):
self.__mat__.extend(dat)
self.__height__ += 1
if self.__mat__ is None:
self.__mat__ = numpy.array([dat], dtype='float32')
else:
self.__mat__ = numpy.append(self.__mat__, [dat], axis=0)
def finish_scan(self, argument):
assert isinstance(argument, swig_paddle.Arguments)
assert isinstance(self.input_type, dp2.InputType)
m = swig_paddle.Matrix.createDense(self.__mat__, self.__height__,
self.input_type.dim, False)
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
argument.setSlotValue(self.pos, m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部