未验证 提交 f8ec430e 编写于 作者: M Maple Xie 提交者: GitHub

[Bug fix] Fix fp16 dtype checking for AvgPool1D op (#50929)

* Fix fp16 dtype checking for AvgPool1D op

* Update code style for PR-CI-Static-Check
上级 38dad3b9
......@@ -139,6 +139,32 @@ class TestPool1D_API(unittest.TestCase):
)
np.testing.assert_allclose(fetches[0], result_np, rtol=1e-05)
def check_avg_static_results_fp16(self, place):
    """Static-graph check: avg_pool1d on a float16 input matches the naive
    NumPy reference within rtol=1e-03.

    Args:
        place: default execution place; replaced by CUDAPlace(0) when the
            build has CUDA support (fp16 pooling targets the GPU kernel).
    """
    with paddle.static.program_guard(paddle.static.Program()):
        # Build the static graph: fp16 data node -> avg_pool1d output.
        x = paddle.static.data(
            name="input", shape=[2, 3, 32], dtype="float16"
        )
        out = F.avg_pool1d(x, kernel_size=2, stride=2, padding=0)

        # Reference result computed on the host with the naive implementation.
        x_np = np.random.random([2, 3, 32]).astype("float16")
        expected = avg_pool1D_forward_naive(
            x_np,
            ksize=[2],
            strides=[2],
            paddings=[0],
            ceil_mode=False,
        )

        # Prefer the GPU when available; fp16 pooling is a CUDA kernel.
        if core.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
        executor = paddle.static.Executor(place)
        fetches = executor.run(
            paddle.static.default_main_program(),
            feed={"input": x_np},
            fetch_list=[out],
        )
        # fp16 has limited precision, hence the looser 1e-03 tolerance.
        np.testing.assert_allclose(fetches[0], expected, rtol=1e-03)
def check_avg_dygraph_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32]).astype("float32")
......@@ -272,6 +298,7 @@ class TestPool1D_API(unittest.TestCase):
self.check_max_dygraph_padding_same(place)
self.check_avg_dygraph_padding_same(place)
self.check_max_dygraph_return_index_results(place)
self.check_avg_static_results_fp16(place)
class TestPool1DError_API(unittest.TestCase):
......
......@@ -187,7 +187,7 @@ def avg_pool1d(
Args:
x (Tensor): The input tensor of pooling operator which is a 3-D tensor with
shape [N, C, L]. where `N` is batch size, `C` is the number of channels,
`L` is the length of the feature. The data type is float32 or float64.
`L` is the length of the feature. The data type is float16, float32 or float64.
kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain an integer.
stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list,
......@@ -223,7 +223,9 @@ def avg_pool1d(
"""NCL to NCHW"""
data_format = "NCHW"
if not in_dynamic_mode():
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'avg_pool1d')
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'avg_pool1d'
)
_check_input(x, 3)
x = unsqueeze(x, [2])
kernel_size = utils.convert_to_list(kernel_size, 1, 'kernel_size')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册