未验证 提交 b69af7ad 编写于 作者: 张春乔 提交者: GitHub

[fp16] support fp16 on LocalResponseNorm (#50918)

* support fp16 on LocalResponseNorm

* add docs in avgpool3d
上级 d841062b
......@@ -328,6 +328,32 @@ class TestLocalResponseNormCAPI(unittest.TestCase):
res2_tran = np.transpose(res2.numpy(), (0, 3, 1, 2))
np.testing.assert_allclose(res1.numpy(), res2_tran, rtol=1e-05)
def test_static_fp16_gpu(self):
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input = np.random.random([3, 3, 112, 112]).astype("float16")
x = paddle.static.data(
name="x", shape=[3, 3, 112, 112], dtype="float16"
)
m = paddle.nn.LocalResponseNorm(size=5)
y = m(x)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={
"x": input,
},
fetch_list=[y],
)
assert np.array_equal(res[0].shape, input.shape)
if __name__ == "__main__":
unittest.main()
......@@ -478,7 +478,7 @@ def local_response_norm(
Args:
x (Tensor): The input 3-D/4-D/5-D tensor. The data type is float32.
x (Tensor): The input 3-D/4-D/5-D tensor. The data type is float16 or float32.
size (int): The number of channels to sum over.
alpha (float, optional): The scaling parameter, positive. Default:1e-4
beta (float, optional): The exponent, positive. Default:0.75
......@@ -509,7 +509,9 @@ def local_response_norm(
print(y.shape) # [3, 3, 112, 112]
"""
if not in_dynamic_mode():
check_variable_and_dtype(x, 'x', ['float32'], 'local_response_norm')
check_variable_and_dtype(
x, 'x', ['float16', 'float32'], 'local_response_norm'
)
if data_format not in ['NCL', 'NLC', 'NCHW', 'NHWC', 'NCDHW', 'NDHWC']:
raise ValueError(
"data_format should be in one of [NCL, NCHW, NCDHW, NLC, NHWC, NDHWC], "
......
......@@ -261,7 +261,7 @@ class AvgPool3D(Layer):
Shape:
- x(Tensor): The input tensor of avg pool3d operator, which is a 5-D tensor.
The data type can be float32, float64.
The data type can be float16, float32, float64.
- output(Tensor): The output tensor of avg pool3d operator, which is a 5-D tensor.
The data type is same as input x.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册