未验证 提交 0dc485e6 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the value parameter's Tensor support of fill_constant Op test=… (#25986)

上级 168ea223
...@@ -685,8 +685,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -685,8 +685,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
""" """
attrs = {'force_cpu': force_cpu} attrs = {'force_cpu': force_cpu}
dtype = convert_dtype(dtype)
if not isinstance(value, Variable): if not isinstance(value, Variable):
if convert_dtype(dtype) in ['int64', 'int32']: if dtype in ['int64', 'int32']:
attrs['str_value'] = str(int(value)) attrs['str_value'] = str(int(value))
else: else:
attrs['str_value'] = str(float(value)) attrs['str_value'] = str(float(value))
...@@ -697,7 +698,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -697,7 +698,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
out = _varbase_creator(dtype=dtype) out = _varbase_creator(dtype=dtype)
if isinstance(value, Variable): if isinstance(value, Variable):
if convert_dtype(dtype) in ['int64', 'int32']: if dtype in ['int64', 'int32']:
attrs['str_value'] = str(int(value.numpy())) attrs['str_value'] = str(int(value.numpy()))
else: else:
attrs['str_value'] = str(float(value.numpy())) attrs['str_value'] = str(float(value.numpy()))
...@@ -712,6 +713,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -712,6 +713,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
helper = LayerHelper("fill_constant", **locals()) helper = LayerHelper("fill_constant", **locals())
inputs = {} inputs = {}
if isinstance(value, Variable): if isinstance(value, Variable):
if convert_dtype(value.dtype) != dtype:
value = cast(value, dtype)
inputs['ValueTensor'] = value inputs['ValueTensor'] = value
check_dtype(dtype, 'dtype', check_dtype(dtype, 'dtype',
......
...@@ -269,18 +269,26 @@ class TestFillConstantAPI(unittest.TestCase): ...@@ -269,18 +269,26 @@ class TestFillConstantAPI(unittest.TestCase):
out_6 = fluid.layers.fill_constant( out_6 = fluid.layers.fill_constant(
shape=shape_tensor_int64, dtype=np.float32, value=1.1) shape=shape_tensor_int64, dtype=np.float32, value=1.1)
val = fluid.layers.fill_constant(shape=[1], dtype=np.float32, value=1.1) val1 = fluid.layers.fill_constant(
shape=[1], dtype=np.float32, value=1.1)
val2 = fluid.layers.fill_constant(
shape=[1], dtype=np.float64, value=1.1)
out_7 = fluid.layers.fill_constant( out_7 = fluid.layers.fill_constant(
shape=shape_tensor_int64, dtype=np.float32, value=val) shape=shape_tensor_int64, dtype=np.float32, value=val1)
out_8 = fluid.layers.fill_constant(
shape=shape_tensor_int64, dtype=np.float32, value=val2)
exe = fluid.Executor(place=fluid.CPUPlace()) exe = fluid.Executor(place=fluid.CPUPlace())
res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( res_1, res_2, res_3, res_4, res_5, res_6, res_7, res_8 = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed={ feed={
"shape_tensor_int32": np.array([1, 2]).astype("int32"), "shape_tensor_int32": np.array([1, 2]).astype("int32"),
"shape_tensor_int64": np.array([1, 2]).astype("int64"), "shape_tensor_int64": np.array([1, 2]).astype("int64"),
}, },
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7]) fetch_list=[
out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8
])
assert np.array_equal(res_1, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_1, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_2, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_2, np.full([1, 2], 1.1, dtype="float32"))
...@@ -289,6 +297,31 @@ class TestFillConstantAPI(unittest.TestCase): ...@@ -289,6 +297,31 @@ class TestFillConstantAPI(unittest.TestCase):
assert np.array_equal(res_5, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_5, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_6, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_6, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_7, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_7, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_8, np.full([1, 2], 1.1, dtype="float32"))
class TestFillConstantImperative(unittest.TestCase):
def test_api(self):
with fluid.dygraph.guard():
data1 = np.array([1, 2]).astype('int32')
data2 = np.array([1.1]).astype('float32')
shape = fluid.dygraph.to_variable(data1)
val = fluid.dygraph.to_variable(data2)
res1 = fluid.layers.fill_constant(
shape=[1, 2], dtype='float32', value=1.1)
res2 = fluid.layers.fill_constant(
shape=shape, dtype='float32', value=1.1)
res3 = fluid.layers.fill_constant(
shape=shape, dtype='float32', value=val)
assert np.array_equal(
res1.numpy(), np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
res2.numpy(), np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
res3.numpy(), np.full(
[1, 2], 1.1, dtype="float32"))
class TestFillConstantOpError(unittest.TestCase): class TestFillConstantOpError(unittest.TestCase):
......
...@@ -248,7 +248,7 @@ def zeros(shape, dtype=None, name=None): ...@@ -248,7 +248,7 @@ def zeros(shape, dtype=None, name=None):
# shape is a Tensor # shape is a Tensor
shape = paddle.fill_constant(shape=[2], dtype='int32', value=2) shape = paddle.fill_constant(shape=[2], dtype='int32', value=2)
data3 = paddle.ones(shape=shape, dtype='int32') data3 = paddle.zeros(shape=shape, dtype='int32')
# [[0 0] # [[0 0]
# [0 0]] # [0 0]]
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册