diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index 8ae2d9221862c86afdd821e241218a905c2b98a8..8a4608a5450868e0726c5fb640dbf07249d3f320 100755 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -480,6 +480,22 @@ class API_TestElementwise_Equal(unittest.TestCase): self.assertEqual((res == np.array([True, False])).all(), True) +class API_TestElementwise_Greater_Than(unittest.TestCase): + def test_api_fp16(self): + paddle.enable_static() + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + label = paddle.to_tensor([3, 3], dtype="float16") + limit = paddle.to_tensor([3, 2], dtype="float16") + out = paddle.greater_than(x=label, y=limit) + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + (res,) = exe.run(fetch_list=[out]) + self.assertEqual((res == np.array([False, True])).all(), True) + + class TestCompareOpPlace(unittest.TestCase): def test_place_1(self): paddle.enable_static() diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index cb02050002c83234e346ce61e9b7baee6daeabfb..0ca0935c88a560dfea19017056e10a92e684a805 100644 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -548,8 +548,8 @@ def greater_than(x, y, name=None): The output has no gradient. Args: - x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. - y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. + x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -571,13 +571,13 @@ def greater_than(x, y, name=None): check_variable_and_dtype( x, "x", - ["bool", "float32", "float64", "int32", "int64"], + ["bool", "float16", "float32", "float64", "int32", "int64"], "greater_than", ) check_variable_and_dtype( y, "y", - ["bool", "float32", "float64", "int32", "int64"], + ["bool", "float16", "float32", "float64", "int32", "int64"], "greater_than", ) helper = LayerHelper("greater_than", **locals())