未验证 提交 6b85eb59 编写于 作者: I Infinity_lee 提交者: GitHub

fix fp16 dtype checking for conj op (#50868)

上级 3e9ffaef
......@@ -133,5 +133,20 @@ class TestComplexConjOp(unittest.TestCase):
np.testing.assert_array_equal(result, target)
class Testfp16ConjOp(unittest.TestCase):
    """Static-graph smoke test: paddle.conj must accept float16 input."""

    def testfp16(self):
        # NOTE: float16 cannot represent complex values, so this astype
        # silently drops the imaginary part — the effective input is a
        # real-valued fp16 array of shape (12, 14).
        input_x = (
            np.random.random((12, 14)) + 1j * np.random.random((12, 14))
        ).astype('float16')
        with static.program_guard(static.Program()):
            x = static.data(name="x", shape=[12, 14], dtype='float16')
            out = paddle.conj(x)
            # fp16 kernels require a CUDA device; skip silently on CPU-only
            # builds (matches the original best-effort behavior).
            if paddle.is_compiled_with_cuda():
                place = paddle.CUDAPlace(0)
                exe = paddle.static.Executor(place)
                exe.run(paddle.static.default_startup_program())
                (result,) = exe.run(feed={'x': input_x}, fetch_list=[out])
                # Bug fix: the original ran the program but never checked the
                # output. conj of a real tensor is the identity, so the
                # result must equal the (real-valued) fp16 input.
                np.testing.assert_array_equal(result, np.conj(input_x))
# Standard unittest entry point: run every TestCase in this module when the
# file is executed directly (e.g. `python test_conj_op.py`).
if __name__ == "__main__":
    unittest.main()
......@@ -3861,7 +3861,7 @@ def conj(x, name=None):
Args:
x (Tensor): The input Tensor which hold the complex numbers.
Optional data types are: complex64, complex128, float32, float64, int32 or int64.
Optional data types are: float16, complex64, complex128, float32, float64, int32 or int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -3889,7 +3889,15 @@ def conj(x, name=None):
check_variable_and_dtype(
x,
"x",
['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'],
[
'complex64',
'complex128',
'float16',
'float32',
'float64',
'int32',
'int64',
],
'conj',
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册