未验证 提交 40fea722 编写于 作者: Y yeliang2258 提交者: GitHub

[AMP] Add uint16 dtype check for compare ops (#52016)

* add uint16 dtype check for compare ops

* update doc
上级 b110085f
...@@ -478,13 +478,29 @@ def equal(x, y, name=None): ...@@ -478,13 +478,29 @@ def equal(x, y, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
"x", "x",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"equal", "equal",
) )
check_variable_and_dtype( check_variable_and_dtype(
y, y,
"y", "y",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"equal", "equal",
) )
helper = LayerHelper("equal", **locals()) helper = LayerHelper("equal", **locals())
...@@ -531,13 +547,29 @@ def greater_equal(x, y, name=None): ...@@ -531,13 +547,29 @@ def greater_equal(x, y, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
"x", "x",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"greater_equal", "greater_equal",
) )
check_variable_and_dtype( check_variable_and_dtype(
y, y,
"y", "y",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"greater_equal", "greater_equal",
) )
helper = LayerHelper("greater_equal", **locals()) helper = LayerHelper("greater_equal", **locals())
...@@ -584,13 +616,29 @@ def greater_than(x, y, name=None): ...@@ -584,13 +616,29 @@ def greater_than(x, y, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
"x", "x",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"greater_than", "greater_than",
) )
check_variable_and_dtype( check_variable_and_dtype(
y, y,
"y", "y",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"greater_than", "greater_than",
) )
helper = LayerHelper("greater_than", **locals()) helper = LayerHelper("greater_than", **locals())
...@@ -638,13 +686,29 @@ def less_equal(x, y, name=None): ...@@ -638,13 +686,29 @@ def less_equal(x, y, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
"x", "x",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"less_equal", "less_equal",
) )
check_variable_and_dtype( check_variable_and_dtype(
y, y,
"y", "y",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"less_equal", "less_equal",
) )
helper = LayerHelper("less_equal", **locals()) helper = LayerHelper("less_equal", **locals())
...@@ -692,13 +756,29 @@ def less_than(x, y, name=None): ...@@ -692,13 +756,29 @@ def less_than(x, y, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
"x", "x",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"less_than", "less_than",
) )
check_variable_and_dtype( check_variable_and_dtype(
y, y,
"y", "y",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"less_than", "less_than",
) )
helper = LayerHelper("less_than", **locals()) helper = LayerHelper("less_than", **locals())
...@@ -746,13 +826,29 @@ def not_equal(x, y, name=None): ...@@ -746,13 +826,29 @@ def not_equal(x, y, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
"x", "x",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"not_equal", "not_equal",
) )
check_variable_and_dtype( check_variable_and_dtype(
y, y,
"y", "y",
["bool", "float16", "float32", "float64", "int32", "int64"], [
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"not_equal", "not_equal",
) )
helper = LayerHelper("not_equal", **locals()) helper = LayerHelper("not_equal", **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册