未验证 提交 0d4f4960 编写于 作者: Z Zhang Zheng 提交者: GitHub

[AMP OP&Test] Add check_dtype for some API (part 1) (#53077)

* [AMP OP&Test] Add check_dtype for some API (part 1)

* fix
上级 3206fa80
...@@ -184,7 +184,7 @@ def gelu(x, approximate=False, name=None): ...@@ -184,7 +184,7 @@ def gelu(x, approximate=False, name=None):
return _C_ops.gelu(x, approximate) return _C_ops.gelu(x, approximate)
else: else:
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'gelu' x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'gelu'
) )
helper = LayerHelper("gelu", **locals()) helper = LayerHelper("gelu", **locals())
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
......
...@@ -1150,7 +1150,7 @@ def dropout( ...@@ -1150,7 +1150,7 @@ def dropout(
else: else:
helper = LayerHelper('dropout', **locals()) helper = LayerHelper('dropout', **locals())
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'dropout' x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'dropout'
) )
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
...@@ -1191,7 +1191,7 @@ def dropout( ...@@ -1191,7 +1191,7 @@ def dropout(
else: # sometimes called dropout_nd #TODO: optimize with c++ else: # sometimes called dropout_nd #TODO: optimize with c++
if not in_dynamic_mode(): if not in_dynamic_mode():
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'dropout' x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'dropout'
) )
dtype = x.dtype dtype = x.dtype
keep_prob = 1 - p keep_prob = 1 - p
...@@ -1407,7 +1407,7 @@ def alpha_dropout(x, p=0.5, training=True, name=None): ...@@ -1407,7 +1407,7 @@ def alpha_dropout(x, p=0.5, training=True, name=None):
if not in_dynamic_mode(): if not in_dynamic_mode():
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'alpha_dropout' x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'alpha_dropout'
) )
if training: if training:
......
...@@ -409,7 +409,10 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"): ...@@ -409,7 +409,10 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
else: else:
helper = LayerHelper("temporal_shift", **locals()) helper = LayerHelper("temporal_shift", **locals())
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'temporal_shift' x,
'x',
['float16', 'uint16', 'float32', 'float64'],
'temporal_shift',
) )
check_type(seg_num, 'seg_num', int, 'temporal_shift') check_type(seg_num, 'seg_num', int, 'temporal_shift')
check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift') check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift')
......
...@@ -397,7 +397,9 @@ def avg_pool2d( ...@@ -397,7 +397,9 @@ def avg_pool2d(
else: else:
op_type = 'pool2d' op_type = 'pool2d'
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'avg_pool2d') check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'avg_pool2d'
)
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
...@@ -1259,7 +1261,7 @@ def max_pool2d( ...@@ -1259,7 +1261,7 @@ def max_pool2d(
op_type = 'max_pool2d_with_index' if return_mask else "pool2d" op_type = 'max_pool2d_with_index' if return_mask else "pool2d"
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'max_pool2d' x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'max_pool2d'
) )
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
...@@ -1419,7 +1421,9 @@ def max_pool3d( ...@@ -1419,7 +1421,9 @@ def max_pool3d(
else: else:
op_type = "max_pool3d_with_index" if return_mask else "pool3d" op_type = "max_pool3d_with_index" if return_mask else "pool3d"
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool3d') check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'max_pool3d'
)
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
mask = helper.create_variable_for_type_inference('int32') mask = helper.create_variable_for_type_inference('int32')
......
...@@ -333,7 +333,7 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -333,7 +333,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype( check_dtype(
start.dtype, start.dtype,
'start', 'start',
['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'], ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'linspace', 'linspace',
) )
else: else:
...@@ -343,7 +343,7 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -343,7 +343,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype( check_dtype(
stop.dtype, stop.dtype,
'stop', 'stop',
['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'], ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'linspace', 'linspace',
) )
else: else:
...@@ -353,7 +353,7 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -353,7 +353,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype( check_dtype(
dtype, dtype,
'dtype', 'dtype',
['int32', 'int64', 'float32', 'float64', 'float16', 'bfloat16'], ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'linspace', 'linspace',
) )
if ( if (
...@@ -1787,7 +1787,7 @@ def diag(x, offset=0, padding_value=0, name=None): ...@@ -1787,7 +1787,7 @@ def diag(x, offset=0, padding_value=0, name=None):
check_dtype( check_dtype(
x.dtype, x.dtype,
'x', 'x',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'diag_v2', 'diag_v2',
) )
check_type(offset, 'offset', (int), 'diag_v2') check_type(offset, 'offset', (int), 'diag_v2')
......
...@@ -1112,10 +1112,16 @@ def dot(x, y, name=None): ...@@ -1112,10 +1112,16 @@ def dot(x, y, name=None):
assert y is not None, f'y cannot be None in {op_type}' assert y is not None, f'y cannot be None in {op_type}'
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], op_type x,
'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
op_type,
) )
check_variable_and_dtype( check_variable_and_dtype(
y, 'y', ['float32', 'float64', 'int32', 'int64'], op_type y,
'y',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
op_type,
) )
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
...@@ -1374,13 +1380,13 @@ def cross(x, y, axis=9, name=None): ...@@ -1374,13 +1380,13 @@ def cross(x, y, axis=9, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
'x', 'x',
['float16', 'float32', 'float64', "int32", "int64"], ['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
'cross', 'cross',
) )
check_variable_and_dtype( check_variable_and_dtype(
y, y,
'y', 'y',
['float16', 'float32', 'float64', "int32", "int64"], ['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
'cross', 'cross',
) )
helper = LayerHelper("cross", **locals()) helper = LayerHelper("cross", **locals())
......
...@@ -1971,12 +1971,12 @@ def split(x, num_or_sections, axis=0, name=None): ...@@ -1971,12 +1971,12 @@ def split(x, num_or_sections, axis=0, name=None):
[ [
'bool', 'bool',
'float16', 'float16',
'uint16',
'float32', 'float32',
'float64', 'float64',
'int32', 'int32',
'int64', 'int64',
'uint8', 'uint8',
'uint16',
'int8', 'int8',
], ],
'split', 'split',
...@@ -2804,7 +2804,15 @@ def unbind(input, axis=0): ...@@ -2804,7 +2804,15 @@ def unbind(input, axis=0):
check_dtype( check_dtype(
dtype, dtype,
'unbind', 'unbind',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], [
'bool',
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
],
'unbind', 'unbind',
) )
outs = [ outs = [
......
...@@ -513,7 +513,15 @@ def _elementwise_op(helper): ...@@ -513,7 +513,15 @@ def _elementwise_op(helper):
'complex128', 'complex128',
] ]
else: else:
data_type = ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'] data_type = [
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'bool',
]
check_variable_and_dtype( check_variable_and_dtype(
x, x,
'x', 'x',
...@@ -3118,7 +3126,15 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None): ...@@ -3118,7 +3126,15 @@ def diagonal(x, offset=0, axis1=0, axis2=1, name=None):
check_dtype( check_dtype(
x.dtype, x.dtype,
'Input', 'Input',
['bool', 'int32', 'int64', 'float16', 'float32', 'float64'], [
'bool',
'int32',
'int64',
'float16',
'uint16',
'float32',
'float64',
],
'diagonal', 'diagonal',
) )
...@@ -3277,7 +3293,7 @@ def cumsum(x, axis=None, dtype=None, name=None): ...@@ -3277,7 +3293,7 @@ def cumsum(x, axis=None, dtype=None, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
'x', 'x',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
'cumsum', 'cumsum',
) )
check_type(x, 'x', (Variable), 'cumsum') check_type(x, 'x', (Variable), 'cumsum')
...@@ -3425,7 +3441,16 @@ def cumprod(x, dim=None, dtype=None, name=None): ...@@ -3425,7 +3441,16 @@ def cumprod(x, dim=None, dtype=None, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
"x", "x",
['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'], [
'complex64',
'complex128',
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
],
'cumprod', 'cumprod',
) )
check_type(dim, 'dim', int, 'cumprod') check_type(dim, 'dim', int, 'cumprod')
...@@ -3995,6 +4020,7 @@ def conj(x, name=None): ...@@ -3995,6 +4020,7 @@ def conj(x, name=None):
'complex64', 'complex64',
'complex128', 'complex128',
'float16', 'float16',
'uint16',
'float32', 'float32',
'float64', 'float64',
'int32', 'int32',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册