未验证 提交 e7848c31 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Cherry-Pick] Add check_dtype for some API (part 1) (#53079)

【Part I】补充API静态图中的check_dtype支持对float16和bfloat16的检查
上级 00b7c819
......@@ -184,7 +184,7 @@ def gelu(x, approximate=False, name=None):
return _C_ops.gelu(x, approximate)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'gelu'
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'gelu'
)
helper = LayerHelper("gelu", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......
......@@ -1150,7 +1150,7 @@ def dropout(
else:
helper = LayerHelper('dropout', **locals())
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'dropout'
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'dropout'
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -1191,7 +1191,7 @@ def dropout(
else: # sometimes called dropout_nd #TODO: optimize with c++
if not in_dynamic_mode():
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'dropout'
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'dropout'
)
dtype = x.dtype
keep_prob = 1 - p
......@@ -1407,7 +1407,7 @@ def alpha_dropout(x, p=0.5, training=True, name=None):
if not in_dynamic_mode():
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'alpha_dropout'
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'alpha_dropout'
)
if training:
......
......@@ -409,7 +409,10 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
else:
helper = LayerHelper("temporal_shift", **locals())
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'temporal_shift'
x,
'x',
['float16', 'uint16', 'float32', 'float64'],
'temporal_shift',
)
check_type(seg_num, 'seg_num', int, 'temporal_shift')
check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift')
......
......@@ -397,7 +397,9 @@ def avg_pool2d(
else:
op_type = 'pool2d'
helper = LayerHelper(op_type, **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'avg_pool2d')
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'avg_pool2d'
)
dtype = helper.input_dtype(input_param_name='x')
pool_out = helper.create_variable_for_type_inference(dtype)
......@@ -1259,7 +1261,7 @@ def max_pool2d(
op_type = 'max_pool2d_with_index' if return_mask else "pool2d"
helper = LayerHelper(op_type, **locals())
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'max_pool2d'
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'max_pool2d'
)
dtype = helper.input_dtype(input_param_name='x')
pool_out = helper.create_variable_for_type_inference(dtype)
......@@ -1419,7 +1421,9 @@ def max_pool3d(
else:
op_type = "max_pool3d_with_index" if return_mask else "pool3d"
helper = LayerHelper(op_type, **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool3d')
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'max_pool3d'
)
dtype = helper.input_dtype(input_param_name='x')
pool_out = helper.create_variable_for_type_inference(dtype)
mask = helper.create_variable_for_type_inference('int32')
......
......@@ -333,7 +333,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype(
start.dtype,
'start',
['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'],
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'linspace',
)
else:
......@@ -343,7 +343,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype(
stop.dtype,
'stop',
['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'],
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'linspace',
)
else:
......@@ -353,7 +353,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype(
dtype,
'dtype',
['int32', 'int64', 'float32', 'float64', 'float16', 'bfloat16'],
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'linspace',
)
if (
......@@ -1787,7 +1787,7 @@ def diag(x, offset=0, padding_value=0, name=None):
check_dtype(
x.dtype,
'x',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'diag_v2',
)
check_type(offset, 'offset', (int), 'diag_v2')
......
......@@ -1112,10 +1112,16 @@ def dot(x, y, name=None):
assert y is not None, f'y cannot be None in {op_type}'
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], op_type
x,
'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
op_type,
)
check_variable_and_dtype(
y, 'y', ['float32', 'float64', 'int32', 'int64'], op_type
y,
'y',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
op_type,
)
helper = LayerHelper(op_type, **locals())
......@@ -1374,13 +1380,13 @@ def cross(x, y, axis=9, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'float32', 'float64', "int32", "int64"],
['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
'cross',
)
check_variable_and_dtype(
y,
'y',
['float16', 'float32', 'float64', "int32", "int64"],
['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
'cross',
)
helper = LayerHelper("cross", **locals())
......
......@@ -1971,6 +1971,7 @@ def split(x, num_or_sections, axis=0, name=None):
[
'bool',
'float16',
'uint16',
'float32',
'float64',
'int32',
......@@ -2803,7 +2804,15 @@ def unbind(input, axis=0):
check_dtype(
dtype,
'unbind',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
[
'bool',
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
],
'unbind',
)
outs = [
......
......@@ -513,7 +513,15 @@ def _elementwise_op(helper):
'complex128',
]
else:
data_type = ['float16', 'float32', 'float64', 'int32', 'int64', 'bool']
data_type = [
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'bool',
]
check_variable_and_dtype(
x,
'x',
......@@ -3118,7 +3126,15 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None):
check_dtype(
x.dtype,
'Input',
['bool', 'int32', 'int64', 'float16', 'float32', 'float64'],
[
'bool',
'int32',
'int64',
'float16',
'uint16',
'float32',
'float64',
],
'diagonal',
)
......@@ -3277,7 +3293,7 @@ def cumsum(x, axis=None, dtype=None, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'cumsum',
)
check_type(x, 'x', (Variable), 'cumsum')
......@@ -3425,7 +3441,16 @@ def cumprod(x, dim=None, dtype=None, name=None):
check_variable_and_dtype(
x,
"x",
['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'],
[
'complex64',
'complex128',
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
],
'cumprod',
)
check_type(dim, 'dim', int, 'cumprod')
......@@ -3995,6 +4020,7 @@ def conj(x, name=None):
'complex64',
'complex128',
'float16',
'uint16',
'float32',
'float64',
'int32',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册