未验证 提交 576d0d93 编写于 作者: H huangxu96 提交者: GitHub

add fp16 check into max and avg pool (#29479)

上级 1779e99f
...@@ -674,7 +674,8 @@ def max_pool2d(x, ...@@ -674,7 +674,8 @@ def max_pool2d(x,
return_mask=True) return_mask=True)
# out.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16], # out.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16],
""" """
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool2d') check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'max_pool2d')
kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size') kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size')
if stride is None: if stride is None:
stride = kernel_size stride = kernel_size
...@@ -911,7 +912,8 @@ def adaptive_avg_pool1d(x, output_size, name=None): ...@@ -911,7 +912,8 @@ def adaptive_avg_pool1d(x, output_size, name=None):
# pool_out shape: [1, 3, 16]) # pool_out shape: [1, 3, 16])
""" """
pool_type = 'avg' pool_type = 'avg'
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'adaptive_pool2d') check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'adaptive_pool2d')
_check_input(x, 3) _check_input(x, 3)
check_type(output_size, 'pool_size', (int), 'adaptive_pool1d') check_type(output_size, 'pool_size', (int), 'adaptive_pool1d')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册