提交 c397e136 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1121 from reyoung/feature/fix_ndarray_dtypes

Fix bug in DenseScanner of DataProviderConverter.
...@@ -34,6 +34,10 @@ class IScanner(object): ...@@ -34,6 +34,10 @@ class IScanner(object):
class DenseScanner(IScanner): class DenseScanner(IScanner):
"""
:type __mat__: numpy.ndarray
"""
def __init__(self, input_type, pos): def __init__(self, input_type, pos):
IScanner.__init__(self, input_type, pos) IScanner.__init__(self, input_type, pos)
self.__mat__ = None self.__mat__ = None
...@@ -47,6 +51,8 @@ class DenseScanner(IScanner): ...@@ -47,6 +51,8 @@ class DenseScanner(IScanner):
def finish_scan(self, argument): def finish_scan(self, argument):
assert isinstance(argument, swig_paddle.Arguments) assert isinstance(argument, swig_paddle.Arguments)
assert isinstance(self.input_type, dp2.InputType) assert isinstance(self.input_type, dp2.InputType)
if self.__mat__.dtype != numpy.float32:
self.__mat__ = self.__mat__.astype(numpy.float32)
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False) m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
argument.setSlotValue(self.pos, m) argument.setSlotValue(self.pos, m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册