From 40fea722cd243bb75afb91859c9bf17731ae93e7 Mon Sep 17 00:00:00 2001 From: yeliang2258 <30516196+yeliang2258@users.noreply.github.com> Date: Fri, 24 Mar 2023 10:34:24 +0800 Subject: [PATCH] [AMP] Add uint16 dtype check for compare ops (#52016) * add uint16 dtype check for compare ops * update doc --- python/paddle/tensor/logic.py | 120 ++++++++++++++++++++++++++++++---- 1 file changed, 108 insertions(+), 12 deletions(-) mode change 100644 => 100755 python/paddle/tensor/logic.py diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py old mode 100644 new mode 100755 index f214bf0c861..06f76c052b7 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -478,13 +478,29 @@ def equal(x, y, name=None): check_variable_and_dtype( x, "x", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "equal", ) check_variable_and_dtype( y, "y", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "equal", ) helper = LayerHelper("equal", **locals()) @@ -531,13 +547,29 @@ def greater_equal(x, y, name=None): check_variable_and_dtype( x, "x", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "greater_equal", ) check_variable_and_dtype( y, "y", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "greater_equal", ) helper = LayerHelper("greater_equal", **locals()) @@ -584,13 +616,29 @@ def greater_than(x, y, name=None): check_variable_and_dtype( x, "x", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "greater_than", ) check_variable_and_dtype( y, "y", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "greater_than", ) helper = LayerHelper("greater_than", **locals()) @@ -638,13 +686,29 @@ def less_equal(x, y, name=None): check_variable_and_dtype( x, "x", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "less_equal", ) check_variable_and_dtype( y, "y", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "less_equal", ) helper = LayerHelper("less_equal", **locals()) @@ -692,13 +756,29 @@ def less_than(x, y, name=None): check_variable_and_dtype( x, "x", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "less_than", ) check_variable_and_dtype( y, "y", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "less_than", ) helper = LayerHelper("less_than", **locals()) @@ -746,13 +826,29 @@ def not_equal(x, y, name=None): check_variable_and_dtype( x, "x", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "not_equal", ) check_variable_and_dtype( y, "y", - ["bool", "float16", "float32", "float64", "int32", "int64"], + [ + "bool", + "float16", + "float32", + "float64", + "int32", + "int64", + "uint16", + ], "not_equal", ) helper = LayerHelper("not_equal", **locals()) -- GitLab