未验证 提交 40fea722 编写于 作者: Y yeliang2258 提交者: GitHub

[AMP] Add uint16 dtype check for compare ops (#52016)

* add uint16 dtype check for compare ops

* update doc
上级 b110085f
......@@ -478,13 +478,29 @@ def equal(x, y, name=None):
check_variable_and_dtype(
x,
"x",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"equal",
)
check_variable_and_dtype(
y,
"y",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"equal",
)
helper = LayerHelper("equal", **locals())
......@@ -531,13 +547,29 @@ def greater_equal(x, y, name=None):
check_variable_and_dtype(
x,
"x",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"greater_equal",
)
check_variable_and_dtype(
y,
"y",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"greater_equal",
)
helper = LayerHelper("greater_equal", **locals())
......@@ -584,13 +616,29 @@ def greater_than(x, y, name=None):
check_variable_and_dtype(
x,
"x",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"greater_than",
)
check_variable_and_dtype(
y,
"y",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"greater_than",
)
helper = LayerHelper("greater_than", **locals())
......@@ -638,13 +686,29 @@ def less_equal(x, y, name=None):
check_variable_and_dtype(
x,
"x",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"less_equal",
)
check_variable_and_dtype(
y,
"y",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"less_equal",
)
helper = LayerHelper("less_equal", **locals())
......@@ -692,13 +756,29 @@ def less_than(x, y, name=None):
check_variable_and_dtype(
x,
"x",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"less_than",
)
check_variable_and_dtype(
y,
"y",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"less_than",
)
helper = LayerHelper("less_than", **locals())
......@@ -746,13 +826,29 @@ def not_equal(x, y, name=None):
check_variable_and_dtype(
x,
"x",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"not_equal",
)
check_variable_and_dtype(
y,
"y",
["bool", "float16", "float32", "float64", "int32", "int64"],
[
"bool",
"float16",
"float32",
"float64",
"int32",
"int64",
"uint16",
],
"not_equal",
)
helper = LayerHelper("not_equal", **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册