diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu
index 39cd26a455c2660a638f586fdbbdd43541af91d2..d0a0416994169e50cd3bec4d4a0590f686e6330c 100644
--- a/paddle/phi/kernels/funcs/pooling.cu
+++ b/paddle/phi/kernels/funcs/pooling.cu
@@ -993,6 +993,7 @@ template class Pool2dDirectCUDAFunctor<phi::funcs::AvgPool<float>, float>;
 template class MaxPool2dGradFunctor<phi::GPUContext, float>;
 template class MaxPool2dGradFunctor<phi::GPUContext, double>;
 template class MaxPool2dGradFunctor<phi::GPUContext, dtype::float16>;
+template class MaxPool2dGradFunctor<phi::GPUContext, dtype::bfloat16>;
 
 template class Pool2dFunctor<phi::GPUContext, phi::funcs::MaxPool<float>, float>;
 template class Pool2dFunctor<phi::GPUContext, phi::funcs::AvgPool<float>, float>;
@@ -1015,6 +1016,18 @@ template class Pool2dGradFunctor<phi::GPUContext,
                                  phi::funcs::AvgPoolGrad<dtype::float16>,
                                  dtype::float16>;
+template class Pool2dFunctor<phi::GPUContext,
+                             phi::funcs::MaxPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool2dFunctor<phi::GPUContext,
+                             phi::funcs::AvgPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool2dGradFunctor<phi::GPUContext,
+                                 phi::funcs::MaxPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
+template class Pool2dGradFunctor<phi::GPUContext,
+                                 phi::funcs::AvgPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
 
 template <typename PoolProcess, typename T>
 __global__ void KernelPool3D(const int nthreads,
@@ -1863,6 +1876,7 @@ template class Pool3dDirectCUDAFunctor<phi::funcs::AvgPool<float>, float>;
 template class MaxPool3dGradFunctor<phi::GPUContext, float>;
 template class MaxPool3dGradFunctor<phi::GPUContext, double>;
 template class MaxPool3dGradFunctor<phi::GPUContext, dtype::float16>;
+template class MaxPool3dGradFunctor<phi::GPUContext, dtype::bfloat16>;
 
 template class Pool3dFunctor<phi::GPUContext, phi::funcs::MaxPool<float>, float>;
 template class Pool3dFunctor<phi::GPUContext, phi::funcs::AvgPool<float>, float>;
@@ -1879,12 +1893,24 @@ template class Pool3dFunctor<phi::GPUContext,
                              phi::funcs::AvgPool<dtype::float16>,
                              dtype::float16>;
+template class Pool3dFunctor<phi::GPUContext,
+                             phi::funcs::MaxPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool3dFunctor<phi::GPUContext,
+                             phi::funcs::AvgPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
 template class Pool3dGradFunctor<phi::GPUContext,
                                  phi::funcs::MaxPoolGrad<dtype::float16>,
                                  dtype::float16>;
 template class Pool3dGradFunctor<phi::GPUContext,
                                  phi::funcs::AvgPoolGrad<dtype::float16>,
                                  dtype::float16>;
+template class Pool3dGradFunctor<phi::GPUContext,
+                                 phi::funcs::MaxPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
+template class Pool3dGradFunctor<phi::GPUContext,
+                                 phi::funcs::AvgPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
 
 template <typename PoolProcess, typename T>
 __global__ void KernelMaxPool2dWithIdx(const int nthreads,
diff --git a/paddle/phi/kernels/funcs/select_impl.cu.h b/paddle/phi/kernels/funcs/select_impl.cu.h
index d48598decaa8ba3f1d39965fdd880aa015925fff..96b7942cf27094378c59db7b479ce1b948779837 100644
--- a/paddle/phi/kernels/funcs/select_impl.cu.h
+++ b/paddle/phi/kernels/funcs/select_impl.cu.h
@@ -268,7 +268,7 @@ __device__ void SelectKernelImpl(OutT *out,
   using IdT = int64_t;  // Set index data type
   using Add = kps::AddFunctor;  // for cumsum
-  using Cast = NonZeroFunctor;      // for mask
+  using Cast = NonZeroFunctor;  // for mask
 
   IdT init_idx = static_cast<IdT>(0.0f);
   MT init_mask = static_cast<MT>(0.0f);
diff --git a/paddle/phi/kernels/gpu/lgamma_grad_kernel.cu b/paddle/phi/kernels/gpu/lgamma_grad_kernel.cu
index 3e4cd21a658f103aca9bc611a2d42518245e4401..f21d4642e28a6ee801dc08bfc0d401cf55510e9d 100644
--- a/paddle/phi/kernels/gpu/lgamma_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lgamma_grad_kernel.cu
@@ -15,7 +15,14 @@
 #include "paddle/phi/kernels/lgamma_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    lgamma_grad, GPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {}
+PD_REGISTER_KERNEL(lgamma_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LgammaGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/lgamma_kernel.cu b/paddle/phi/kernels/gpu/lgamma_kernel.cu
index e94d67f4ce324ad9d8237a377d70a920cdbd30af..899031fb5c1cc3b930b64c2c4bf1464813d950ff 100644
--- a/paddle/phi/kernels/gpu/lgamma_kernel.cu
+++ b/paddle/phi/kernels/gpu/lgamma_kernel.cu
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/lgamma_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h" @@ -22,7 +23,9 @@ namespace phi { template struct CudaLgammaFunctor { __device__ __forceinline__ T operator()(const T x) const { - return Eigen::numext::lgamma(x); + using MT = typename phi::dtype::MPTypeTrait::Type; + const MT mp_x = static_cast(x); + return static_cast(Eigen::numext::lgamma(mp_x)); } }; template @@ -38,4 +41,11 @@ void LgammaKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL(lgamma, GPU, ALL_LAYOUT, phi::LgammaKernel, float, double) {} +PD_REGISTER_KERNEL(lgamma, + GPU, + ALL_LAYOUT, + phi::LgammaKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu index 05f85ac1e613d770827b5f3053b92ebab48fb028..1121ff361f8ac7b56f098d0c410b2ce554c89148 100644 --- a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu @@ -19,6 +19,7 @@ #include #include +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/select_impl.cu.h" @@ -66,4 +67,6 @@ PD_REGISTER_KERNEL(masked_select_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/masked_select_kernel.cu b/paddle/phi/kernels/gpu/masked_select_kernel.cu index 632d79929cc3d0c707bbe6557f2257613c13daa3..208bdd853cc30f086a5eb93efd931aab7d995581 100644 --- a/paddle/phi/kernels/gpu/masked_select_kernel.cu +++ b/paddle/phi/kernels/gpu/masked_select_kernel.cu @@ -20,6 +20,7 @@ #include #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/select_impl.cu.h" @@ -76,6 +77,8 @@ PD_REGISTER_KERNEL(masked_select, float, double, int, - int64_t) { + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(1).SetDataType(phi::DataType::BOOL); } diff --git a/paddle/phi/kernels/gpu/pool_grad_kernel.cu b/paddle/phi/kernels/gpu/pool_grad_kernel.cu index 61bd705ee5bfe86089be707f1848c381c1de354a..598a48f802891eabd38d1db6ea17cf3b31346f83 100644 --- a/paddle/phi/kernels/gpu/pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/pool_grad_kernel.cu @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/pool_grad_kernel.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/pool_grad_kernel_impl.h" @@ -46,7 +47,8 @@ PD_REGISTER_KERNEL(pool3d_grad, phi::Pool3dGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(max_pool3d_with_index_grad, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/pool_kernel.cu b/paddle/phi/kernels/gpu/pool_kernel.cu index 04c4b6d27b45475d62948b5098fafea18aabf030..6323909c9d0dca5c0e7371f6296acf2f5ae488c8 100644 --- a/paddle/phi/kernels/gpu/pool_kernel.cu +++ b/paddle/phi/kernels/gpu/pool_kernel.cu @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/pool_kernel.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/pool_kernel_impl.h" @@ -40,7 +41,8 @@ PD_REGISTER_KERNEL(pool3d, phi::Pool3dKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} 
diff --git a/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h b/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h
index fd1c1dbc8d666dfc8659e9cecf81550553f56e71..b90c17df92b54134522121846b08fb81db8036f2 100644
--- a/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h
@@ -15,6 +15,7 @@
 #pragma once
 #include <unsupported/Eigen/SpecialFunctions>
 
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 namespace phi {
 template <typename T>
@@ -23,7 +24,10 @@ struct LgammaGradFunctor {
       : dout_(dout), x_(x), output_(output), numel_(numel) {}
 
   HOSTDEVICE void operator()(int64_t idx) const {
-    output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_dout = static_cast<MT>(dout_[idx]);
+    const MT mp_x = static_cast<MT>(x_[idx]);
+    output_[idx] = static_cast<T>(mp_dout * Eigen::numext::digamma(mp_x));
   }
 
  private:
diff --git a/python/paddle/fluid/tests/unittests/test_lgamma_op.py b/python/paddle/fluid/tests/unittests/test_lgamma_op.py
index 69e1a008b05765adffe4eec26dd5f5eae9aef201..52c079f8a62a15e49ad2e56ca905f7ff3b1559f2 100644
--- a/python/paddle/fluid/tests/unittests/test_lgamma_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lgamma_op.py
@@ -16,10 +16,11 @@ import math
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 from scipy import special
 
 import paddle
+from paddle.fluid import core
 
 paddle.enable_static()
 
@@ -56,6 +57,41 @@ class TestLgammaOpFp32(TestLgammaOp):
         self.check_grad(['X'], 'Out', numeric_grad_delta=0.005)
 
 
+class TestLgammaFP16Op(TestLgammaOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestLgammaBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = 'lgamma'
+        self.python_api = paddle.lgamma
+        self.dtype = np.uint16
+        shape = (5, 20)
+        data = np.random.random(shape).astype("float32") + 1
+        self.inputs = {'X': convert_float_to_uint16(data)}
+        result = np.ones(shape).astype("float32")
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                result[i][j] = math.lgamma(data[i][j])
+        self.outputs = {'Out': convert_float_to_uint16(result)}
+
+    def test_check_output(self):
+        # bfloat16 checks need an explicit CUDA place.
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')
+
+
 class TestLgammaOpApi(unittest.TestCase):
     def test_lgamma(self):
         paddle.disable_static()
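The bfloat16 lgamma test builds its reference output with an element-wise math.lgamma loop. An equivalent, vectorized reference using the scipy import the test file already pulls in could look like this; it is illustrative only and not part of the patch.

# Not part of the patch: vectorized equivalent of the reference loop above.
import numpy as np
from scipy import special

data = np.random.random((5, 20)).astype("float32") + 1
result = special.gammaln(data)  # matches math.lgamma element-wise for positive inputs
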
diff --git a/python/paddle/fluid/tests/unittests/test_masked_select_op.py b/python/paddle/fluid/tests/unittests/test_masked_select_op.py
index 4775d815075abf0dd3c1ecc1f40a401ea094f4e5..fb02653a632bac0309d26e648b084c6ca31d13ac 100644
--- a/python/paddle/fluid/tests/unittests/test_masked_select_op.py
+++ b/python/paddle/fluid/tests/unittests/test_masked_select_op.py
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 
 def np_masked_select(x, mask):
@@ -59,6 +60,75 @@ class TestMaskedSelectOp2(TestMaskedSelectOp):
         self.shape = (168,)
 
 
+class TestMaskedSelectFP16Op(OpTest):
+    def setUp(self):
+        self.init()
+        self.op_type = "masked_select"
+        self.dtype = np.float16
+        self.python_api = paddle.masked_select
+        x = np.random.random(self.shape).astype("float16")
+        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        out = np_masked_select(x, mask)
+        self.inputs = {'X': x, 'Mask': mask}
+        self.outputs = {'Y': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y')
+
+    def init(self):
+        self.shape = (50, 3)
+
+
+class TestMaskedSelectFP16Op1(TestMaskedSelectFP16Op):
+    def init(self):
+        self.shape = (6, 8, 9, 18)
+
+
+class TestMaskedSelectFP16Op2(TestMaskedSelectFP16Op):
+    def init(self):
+        self.shape = (168,)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestMaskedSelectBF16Op(OpTest):
+    def setUp(self):
+        self.init()
+        self.op_type = "masked_select"
+        self.dtype = np.uint16
+        self.python_api = paddle.masked_select
+        x = np.random.random(self.shape).astype("float32")
+        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        out = np_masked_select(x, mask)
+        self.inputs = {'X': convert_float_to_uint16(x), 'Mask': mask}
+        self.outputs = {'Y': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Y')
+
+    def init(self):
+        self.shape = (50, 3)
+
+
+class TestMaskedSelectBF16Op1(TestMaskedSelectBF16Op):
+    def init(self):
+        self.shape = (6, 8, 9, 2)
+
+
+class TestMaskedSelectBF16Op2(TestMaskedSelectBF16Op):
+    def init(self):
+        self.shape = (168,)
+
+
 class TestMaskedSelectAPI(unittest.TestCase):
     def test_imperative_mode(self):
         paddle.disable_static()
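Both bfloat16 tests feed their data through convert_float_to_uint16 from eager_op_test, which stores bfloat16 values in a uint16 container (hence self.dtype = np.uint16 above). A rough numpy equivalent, with a hypothetical name and shown only to illustrate the representation, might be:

# Not part of the patch: illustrative only; the helper in eager_op_test is the source of truth.
import numpy as np

def float32_to_bf16_bits(arr):
    # Keep the upper 16 bits of each float32 value, i.e. the bfloat16 bit
    # pattern, stored as uint16 (truncation, no rounding).
    return (np.asarray(arr, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)
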
"x": input, + }, + fetch_list=[y], + ) + + assert np.array_equal(res[0].shape, [1, 2, 1, 16, 16]) + class TestPool3DError_API(unittest.TestCase): def test_error_api(self): diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_op.py b/python/paddle/fluid/tests/unittests/test_pool3d_op.py index e1f4d048a4cb49fd6cb4004737865b85b08dcbe6..893adbb2560914d4ce57de5e6579d98879be19c2 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_op.py @@ -399,9 +399,9 @@ class TestPool3D_Op(OpTest): self.check_output() def test_check_grad(self): - if self.dtype == np.float16: - return - if self.has_cudnn() and self.pool_type != "max": + if ( + self.has_cudnn() or self.dtype == np.uint16 + ) and self.pool_type != "max": place = core.CUDAPlace(0) if core.is_compiled_with_rocm(): self.check_grad_with_place( @@ -566,6 +566,46 @@ def create_test_fp16_class(parent): globals()[cls_name] = TestFp16Case +def create_test_cudnn_bf16_class(parent): + @unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", + ) + class TestCUDNNBf16Case(parent): + def init_kernel_type(self): + self.use_cudnn = True + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place) + + cls_name = "{}_{}".format(parent.__name__, "CUDNNBf16Op") + TestCUDNNBf16Case.__name__ = cls_name + globals()[cls_name] = TestCUDNNBf16Case + + +def create_test_bf16_class(parent): + @unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", + ) + class TestBf16Case(parent): + def init_kernel_type(self): + self.use_cudnn = False + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place) + + cls_name = "{}_{}".format(parent.__name__, "Bf16Op") + TestBf16Case.__name__ = cls_name + globals()[cls_name] = TestBf16Case + + create_test_cudnn_fp16_class(TestPool3D_Op) create_test_cudnn_fp16_class(TestCase1) create_test_cudnn_fp16_class(TestCase2) @@ -580,6 +620,20 @@ create_test_fp16_class(TestCase3) create_test_fp16_class(TestCase4) create_test_fp16_class(TestCase5) +create_test_cudnn_bf16_class(TestPool3D_Op) +create_test_cudnn_bf16_class(TestCase1) +create_test_cudnn_bf16_class(TestCase2) +create_test_cudnn_bf16_class(TestCase3) +create_test_cudnn_bf16_class(TestCase4) +create_test_cudnn_bf16_class(TestCase5) + +create_test_bf16_class(TestPool3D_Op) +create_test_bf16_class(TestCase1) +create_test_bf16_class(TestCase2) +create_test_bf16_class(TestCase3) +create_test_bf16_class(TestCase4) +create_test_bf16_class(TestCase5) + # ---- test ceil mode ------ def create_test_cudnn_use_ceil_class(parent): @@ -736,6 +790,13 @@ create_test_cudnn_fp16_class(TestCase3_AsyPadding) create_test_cudnn_fp16_class(TestCase4_AsyPadding) create_test_cudnn_fp16_class(TestCase5_AsyPadding) +create_test_cudnn_bf16_class(TestPool3D_Op_AsyPadding) +create_test_cudnn_bf16_class(TestCase1_AsyPadding) +create_test_cudnn_bf16_class(TestCase2_AsyPadding) +create_test_cudnn_bf16_class(TestCase3_AsyPadding) +create_test_cudnn_bf16_class(TestCase4_AsyPadding) +create_test_cudnn_bf16_class(TestCase5_AsyPadding) + create_test_cudnn_use_ceil_class(TestPool3D_Op_AsyPadding) create_test_cudnn_use_ceil_class(TestCase1_AsyPadding) diff --git a/python/paddle/nn/functional/pooling.py 
diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py
index f58ad0f0238a69581da60d89a212250f6b1750c7..5483de2b70edbc0b1b406d2f3ac29f516233d34a 100755
--- a/python/paddle/nn/functional/pooling.py
+++ b/python/paddle/nn/functional/pooling.py
@@ -520,7 +520,7 @@ def avg_pool3d(
         op_type = "pool3d"
         helper = LayerHelper(op_type, **locals())
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'avg_pool3d'
+            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'avg_pool3d'
         )
         dtype = helper.input_dtype(input_param_name='x')
         pool_out = helper.create_variable_for_type_inference(dtype)
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index cc662d83457fa268e12dab7c06ce0d4513bc1499..7c01c3109179837b4213c5e388bd7048a1a95fdf 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -4027,7 +4027,7 @@ def lgamma(x, name=None):
 
 
     Args:
-        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
+        x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, uint16.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -4046,7 +4046,9 @@ def lgamma(x, name=None):
     if in_dygraph_mode():
         return _C_ops.lgamma(x)
     else:
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lgamma')
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'lgamma'
+        )
         helper = LayerHelper('lgamma', **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
         helper.append_op(type='lgamma', inputs={'X': x}, outputs={'Out': out})
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index f93b0ea12feacfce868ee3bd5c2ee2983fdba0cb..ffcb9458a0618f25f28a7bed06e634ab266f45ff 100755
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -807,7 +807,7 @@ def masked_select(x, mask, name=None):
     which is a tensor with data type of bool.
 
     Args:
-        x (Tensor): The input Tensor, the data type can be int32, int64, float32, float64.
+        x (Tensor): The input Tensor, the data type can be int32, int64, uint16, float16, float32, float64.
         mask (Tensor): The Tensor containing the binary mask to index with, it's data type is bool.
         name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
 
@@ -838,7 +838,7 @@ def masked_select(x, mask, name=None):
     check_variable_and_dtype(
         x,
         'x',
-        ['float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         'paddle.tensor.search.mask_select',
     )
     check_variable_and_dtype(