未验证 提交 e6ca512a 编写于 作者: W wangchaochaohu 提交者: GitHub

refine convert type for numpy type (#22386)

上级 20f30dd6
......@@ -15,7 +15,7 @@
from __future__ import print_function
from . import core
import numpy
import numpy as np
import os
import six
from six.moves import zip, range, xrange
......@@ -47,6 +47,12 @@ def convert_dtype(dtype):
return 'int64'
elif dtype == core.VarDesc.VarType.UINT8:
return 'uint8'
elif isinstance(dtype, type):
if dtype in [
np.bool, np.float16, np.float32, np.float64, np.int8, np.int16,
np.int32, np.int64, np.uint8
]:
return dtype.__name__
else:
if dtype in [
'bool', 'float16', 'float32', 'float64', 'int8', 'int16',
......@@ -136,7 +142,7 @@ class DataToLoDTensorConverter(object):
format(self.shape, shape))
def done(self):
arr = numpy.array(self.data, dtype=self.dtype)
arr = np.array(self.data, dtype=self.dtype)
if self.shape:
if len(arr.shape) != len(self.shape):
try:
......
......@@ -216,20 +216,24 @@ class TestFillConstantAPI(unittest.TestCase):
out_5 = fluid.layers.fill_constant(
shape=shape_tensor_int64, dtype="float32", value=1.1)
out_6 = fluid.layers.fill_constant(
shape=shape_tensor_int64, dtype=np.float32, value=1.1)
exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4, res_5 = exe.run(
res_1, res_2, res_3, res_4, res_5, res_6 = exe.run(
fluid.default_main_program(),
feed={
"shape_tensor_int32": np.array([1, 2]).astype("int32"),
"shape_tensor_int64": np.array([1, 2]).astype("int64"),
},
fetch_list=[out_1, out_2, out_3, out_4, out_5])
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
assert np.array_equal(res_1, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_2, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_3, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_4, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_5, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_6, np.full([1, 2], 1.1, dtype="float32"))
class TestFillConstantOpError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册