未验证 提交 af149c0c 编写于 作者: L LoneRanger 提交者: GitHub

[fp16] suppot fp16 in diagflat (#50945)

上级 48060b2e
...@@ -104,6 +104,29 @@ class TestDiagFlatAPI(unittest.TestCase): ...@@ -104,6 +104,29 @@ class TestDiagFlatAPI(unittest.TestCase):
with paddle.static.program_guard(Program()): with paddle.static.program_guard(Program()):
self.run_static(use_gpu=True) self.run_static(use_gpu=True)
def test_fp16_with_gpu(self, use_gpu=False):
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.random.random([10, 10]).astype("float16")
x = paddle.static.data(
name="x", shape=[10, 10], dtype="float16"
)
y = paddle.diagflat(x)
expected = np.diagflat(input)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={
"x": input,
},
fetch_list=[y],
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -1498,7 +1498,7 @@ def diagflat(x, offset=0, name=None): ...@@ -1498,7 +1498,7 @@ def diagflat(x, offset=0, name=None):
If ``offset`` < 0, it is subdiagonal. If ``offset`` < 0, it is subdiagonal.
Args: Args:
x (Tensor): The input tensor. It can be any shape. Its data type should be float32, float64, int32, int64. x (Tensor): The input tensor. It can be any shape. Its data type should be float16, float32, float64, int32, int64.
offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal. Default: 0 (main diagonal). offset (int, optional): The diagonal offset. A positive value represents superdiagonal, 0 represents the main diagonal, and a negative value represents subdiagonal. Default: 0 (main diagonal).
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
...@@ -1577,7 +1577,10 @@ def diagflat(x, offset=0, name=None): ...@@ -1577,7 +1577,10 @@ def diagflat(x, offset=0, name=None):
padding_value = 0 padding_value = 0
check_type(x, 'x', (Variable), 'diagflat') check_type(x, 'x', (Variable), 'diagflat')
check_dtype( check_dtype(
x.dtype, 'x', ['float32', 'float64', 'int32', 'int64'], 'diagflat' x.dtype,
'x',
['float16', 'float32', 'float64', 'int32', 'int64'],
'diagflat',
) )
check_type(offset, 'offset', (int), 'diagflat') check_type(offset, 'offset', (int), 'diagflat')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册