未验证 提交 ebea0885 编写于 作者: 陈沧夜 提交者: GitHub

fix fp16 dtype checking for paddle.diag API (#50848)

上级 9951b86f
...@@ -18,6 +18,7 @@ import numpy as np ...@@ -18,6 +18,7 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle import paddle
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
...@@ -42,14 +43,26 @@ class TestDiagOpCase1(TestDiagOp): ...@@ -42,14 +43,26 @@ class TestDiagOpCase1(TestDiagOp):
self.case = np.array([3], dtype='int32') self.case = np.array([3], dtype='int32')
class TestDiagOpFp16(unittest.TestCase):
def test_fp16(self):
x_np = np.array([3], dtype='float16')
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(shape=[1, 0], name='x', dtype='float16')
out = paddle.diag(x)
if core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np}, fetch_list=[out])
class TestDiagError(unittest.TestCase): class TestDiagError(unittest.TestCase):
def test_errors(self): def test_errors(self):
paddle.enable_static() paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
def test_diag_type(): def test_diag_type():
x = [1, 2, 3] return paddle.diag(x=[1, 2, 3])
output = paddle.diag(x=x)
self.assertRaises(TypeError, test_diag_type) self.assertRaises(TypeError, test_diag_type)
......
...@@ -1627,7 +1627,7 @@ def diag(x, offset=0, padding_value=0, name=None): ...@@ -1627,7 +1627,7 @@ def diag(x, offset=0, padding_value=0, name=None):
If ``offset`` < 0, it is subdiagonal. If ``offset`` < 0, it is subdiagonal.
Args: Args:
x (Tensor): The input tensor. Its shape is either 1-D or 2-D. Its data type should be float32, float64, int32, int64. x (Tensor): The input tensor. Its shape is either 1-D or 2-D. Its data type should be float16, float32, float64, int32, int64.
offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal. offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal.
padding_value (int|float, optional): Use this value to fill the area outside the specified diagonal band. Only takes effect when the input is a 1-D Tensor. The default value is 0. padding_value (int|float, optional): Use this value to fill the area outside the specified diagonal band. Only takes effect when the input is a 1-D Tensor. The default value is 0.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
...@@ -1694,7 +1694,7 @@ def diag(x, offset=0, padding_value=0, name=None): ...@@ -1694,7 +1694,7 @@ def diag(x, offset=0, padding_value=0, name=None):
check_dtype( check_dtype(
x.dtype, x.dtype,
'x', 'x',
['float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64'],
'diag_v2', 'diag_v2',
) )
check_type(offset, 'offset', (int), 'diag_v2') check_type(offset, 'offset', (int), 'diag_v2')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册