未验证 提交 e6ca512a 编写于 作者: W wangchaochaohu 提交者: GitHub

refine convert type for numpy type (#22386)

上级 20f30dd6
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from __future__ import print_function from __future__ import print_function
from . import core from . import core
import numpy import numpy as np
import os import os
import six import six
from six.moves import zip, range, xrange from six.moves import zip, range, xrange
...@@ -47,6 +47,12 @@ def convert_dtype(dtype): ...@@ -47,6 +47,12 @@ def convert_dtype(dtype):
return 'int64' return 'int64'
elif dtype == core.VarDesc.VarType.UINT8: elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8' return 'uint8'
elif isinstance(dtype, type):
if dtype in [
np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
np.int32, np.int64, np.uint8
]:
return dtype.__name__
else: else:
if dtype in [ if dtype in [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
...@@ -136,7 +142,7 @@ class DataToLoDTensorConverter(object): ...@@ -136,7 +142,7 @@ class DataToLoDTensorConverter(object):
format(self.shape, shape)) format(self.shape, shape))
def done(self): def done(self):
arr = numpy.array(self.data, dtype=self.dtype) arr = np.array(self.data, dtype=self.dtype)
if self.shape: if self.shape:
if len(arr.shape) != len(self.shape): if len(arr.shape) != len(self.shape):
try: try:
......
...@@ -216,20 +216,24 @@ class TestFillConstantAPI(unittest.TestCase): ...@@ -216,20 +216,24 @@ class TestFillConstantAPI(unittest.TestCase):
out_5 = fluid.layers.fill_constant( out_5 = fluid.layers.fill_constant(
shape=shape_tensor_int64, dtype="float32", value=1.1) shape=shape_tensor_int64, dtype="float32", value=1.1)
out_6 = fluid.layers.fill_constant(
shape=shape_tensor_int64, dtype=np.float32, value=1.1)
exe = fluid.Executor(place=fluid.CPUPlace()) exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4, res_5 = exe.run( res_1, res_2, res_3, res_4, res_5, res_6 = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed={ feed={
"shape_tensor_int32": np.array([1, 2]).astype("int32"), "shape_tensor_int32": np.array([1, 2]).astype("int32"),
"shape_tensor_int64": np.array([1, 2]).astype("int64"), "shape_tensor_int64": np.array([1, 2]).astype("int64"),
}, },
fetch_list=[out_1, out_2, out_3, out_4, out_5]) fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
assert np.array_equal(res_1, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_1, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_2, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_2, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_3, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_3, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_4, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_4, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_5, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_5, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_6, np.full([1, 2], 1.1, dtype="float32"))
class TestFillConstantOpError(unittest.TestCase): class TestFillConstantOpError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册