Unverified commit b0dbf9fe, authored by chenxujun, committed by GitHub

[Hackathon No.62] Add BF16 support and unit tests for the pool3d operator, plus FP16/BF16 operator unit tests for lgamma and masked_select (#51837)

* Add pool3d lgamma masked_select tests

* Fix code
Parent f6f104d5
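For reference, here is a minimal dygraph sketch (not part of this commit) of how the float16 GPU kernels registered below for lgamma and masked_select can be exercised; it assumes a CUDA build of Paddle and uses only public APIs that appear in this diff.

# Illustrative sketch only (not part of the diff); assumes Paddle was built with CUDA.
import numpy as np
import paddle

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    # lgamma is undefined at non-positive integers, so shift samples above zero.
    x = paddle.to_tensor(np.random.rand(4, 5).astype('float16') + 1)
    y = paddle.lgamma(x)  # dispatches to the float16 GPU kernel added here
    mask = paddle.to_tensor(np.random.randint(0, 2, (4, 5)).astype(bool))
    z = paddle.masked_select(x, mask)  # float16 input with a bool mask
    print(y.dtype, z.dtype)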
@@ -993,6 +993,7 @@ template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;
 template class MaxPool2dGradFunctor<phi::GPUContext, float>;
 template class MaxPool2dGradFunctor<phi::GPUContext, double>;
 template class MaxPool2dGradFunctor<phi::GPUContext, dtype::float16>;
+template class MaxPool2dGradFunctor<phi::GPUContext, dtype::bfloat16>;
 template class Pool2dFunctor<phi::GPUContext, MaxPool<float>, float>;
 template class Pool2dFunctor<phi::GPUContext, AvgPool<float>, float>;
@@ -1015,6 +1016,18 @@ template class Pool2dGradFunctor<phi::GPUContext,
 template class Pool2dGradFunctor<phi::GPUContext,
                                  AvgPoolGrad<dtype::float16>,
                                  dtype::float16>;
+template class Pool2dFunctor<phi::GPUContext,
+                             MaxPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool2dFunctor<phi::GPUContext,
+                             AvgPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool2dGradFunctor<phi::GPUContext,
+                                 MaxPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
+template class Pool2dGradFunctor<phi::GPUContext,
+                                 AvgPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
 
 template <typename PoolProcess, typename T>
 __global__ void KernelPool3D(const int nthreads,
@@ -1863,6 +1876,7 @@ template class Pool3dDirectCUDAFunctor<AvgPool<float>, float>;
 template class MaxPool3dGradFunctor<phi::GPUContext, float>;
 template class MaxPool3dGradFunctor<phi::GPUContext, double>;
 template class MaxPool3dGradFunctor<phi::GPUContext, dtype::float16>;
+template class MaxPool3dGradFunctor<phi::GPUContext, dtype::bfloat16>;
 template class Pool3dFunctor<phi::GPUContext, MaxPool<float>, float>;
 template class Pool3dFunctor<phi::GPUContext, AvgPool<float>, float>;
@@ -1879,12 +1893,24 @@ template class Pool3dFunctor<phi::GPUContext,
 template class Pool3dFunctor<phi::GPUContext,
                              AvgPool<dtype::float16>,
                              dtype::float16>;
+template class Pool3dFunctor<phi::GPUContext,
+                             MaxPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool3dFunctor<phi::GPUContext,
+                             AvgPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
 template class Pool3dGradFunctor<phi::GPUContext,
                                  MaxPoolGrad<dtype::float16>,
                                  dtype::float16>;
 template class Pool3dGradFunctor<phi::GPUContext,
                                  AvgPoolGrad<dtype::float16>,
                                  dtype::float16>;
+template class Pool3dGradFunctor<phi::GPUContext,
+                                 MaxPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
+template class Pool3dGradFunctor<phi::GPUContext,
+                                 AvgPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
 
 template <typename T1, typename T2>
 __global__ void KernelMaxPool2dWithIdx(const int nthreads,
......
@@ -268,7 +268,7 @@ __device__ void SelectKernelImpl(OutT *out,
   using IdT = int64_t;
   // Set index data type
   using Add = kps::AddFunctor<IdT>;  // for cumsum
-  using Cast = NonZeroFunctor<InT>;  // for mask
+  using Cast = NonZeroFunctor<MT>;   // for mask
 
   IdT init_idx = static_cast<IdT>(0.0f);
   MT init_mask = static_cast<MT>(0.0f);
......
@@ -15,7 +15,14 @@
 #include "paddle/phi/kernels/lgamma_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    lgamma_grad, GPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {}
+PD_REGISTER_KERNEL(lgamma_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LgammaGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/lgamma_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -22,7 +23,9 @@ namespace phi {
 template <typename T>
 struct CudaLgammaFunctor {
   __device__ __forceinline__ T operator()(const T x) const {
-    return Eigen::numext::lgamma(x);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(x);
+    return static_cast<T>(Eigen::numext::lgamma(mp_x));
   }
 };
 
 template <typename T, typename Context>
@@ -38,4 +41,11 @@ void LgammaKernel(const Context& dev_ctx,
 }
 }  // namespace phi
 
-PD_REGISTER_KERNEL(lgamma, GPU, ALL_LAYOUT, phi::LgammaKernel, float, double) {}
+PD_REGISTER_KERNEL(lgamma,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LgammaKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -19,6 +19,7 @@
 #include <thrust/reverse.h>
 #include <thrust/scan.h>
 
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
@@ -66,4 +67,6 @@ PD_REGISTER_KERNEL(masked_select_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -20,6 +20,7 @@
 #include <thrust/scan.h>
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
@@ -76,6 +77,8 @@ PD_REGISTER_KERNEL(masked_select,
                    float,
                    double,
                    int,
-                   int64_t) {
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
 }
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/pool_grad_kernel.h"
 
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/pool_grad_kernel_impl.h"
@@ -46,7 +47,8 @@ PD_REGISTER_KERNEL(pool3d_grad,
                    phi::Pool3dGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(max_pool3d_with_index_grad,
                    GPU,
                    ALL_LAYOUT,
......
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/pool_kernel.h"
 
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/pool_kernel_impl.h"
@@ -40,7 +41,8 @@ PD_REGISTER_KERNEL(pool3d,
                    phi::Pool3dKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(max_pool3d_with_index,
                    GPU,
                    ALL_LAYOUT,
......
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <unsupported/Eigen/SpecialFunctions>
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 
 namespace phi {
 template <typename T>
@@ -23,7 +24,10 @@ struct LgammaGradFunctor {
       : dout_(dout), x_(x), output_(output), numel_(numel) {}
 
   HOSTDEVICE void operator()(int64_t idx) const {
-    output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_dout = static_cast<MT>(dout_[idx]);
+    const MT mp_x = static_cast<MT>(x_[idx]);
+    output_[idx] = static_cast<T>(mp_dout * Eigen::numext::digamma(mp_x));
   }
 
  private:
......
@@ -16,10 +16,11 @@ import math
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 from scipy import special
 
 import paddle
+from paddle.fluid import core
 
 paddle.enable_static()
@@ -56,6 +57,41 @@ class TestLgammaOpFp32(TestLgammaOp):
         self.check_grad(['X'], 'Out', numeric_grad_delta=0.005)
 
 
+class TestLgammaFP16Op(TestLgammaOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestLgammaBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = 'lgamma'
+        self.python_api = paddle.lgamma
+        self.dtype = np.uint16
+        shape = (5, 20)
+        data = np.random.random(shape).astype("float32") + 1
+        self.inputs = {'X': convert_float_to_uint16(data)}
+        result = np.ones(shape).astype("float32")
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                result[i][j] = math.lgamma(data[i][j])
+        self.outputs = {'Out': convert_float_to_uint16(result)}
+
+    def test_check_output(self):
+        # After testing, bfloat16 needs to set the parameter place
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')
+
+
 class TestLgammaOpApi(unittest.TestCase):
     def test_lgamma(self):
         paddle.disable_static()
......
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 
 def np_masked_select(x, mask):
@@ -59,6 +60,75 @@ class TestMaskedSelectOp2(TestMaskedSelectOp):
         self.shape = (168,)
 
 
+class TestMaskedSelectFP16Op(OpTest):
+    def setUp(self):
+        self.init()
+        self.op_type = "masked_select"
+        self.dtype = np.float16
+        self.python_api = paddle.masked_select
+        x = np.random.random(self.shape).astype("float16")
+        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        out = np_masked_select(x, mask)
+        self.inputs = {'X': x, 'Mask': mask}
+        self.outputs = {'Y': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y')
+
+    def init(self):
+        self.shape = (50, 3)
+
+
+class TestMaskedSelectFP16Op1(TestMaskedSelectFP16Op):
+    def init(self):
+        self.shape = (6, 8, 9, 18)
+
+
+class TestMaskedSelectFP16Op2(TestMaskedSelectFP16Op):
+    def init(self):
+        self.shape = (168,)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestMaskedSelectBF16Op(OpTest):
+    def setUp(self):
+        self.init()
+        self.op_type = "masked_select"
+        self.dtype = np.uint16
+        self.python_api = paddle.masked_select
+        x = np.random.random(self.shape).astype("float32")
+        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        out = np_masked_select(x, mask)
+        self.inputs = {'X': convert_float_to_uint16(x), 'Mask': mask}
+        self.outputs = {'Y': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Y')
+
+    def init(self):
+        self.shape = (50, 3)
+
+
+class TestMaskedSelectBF16Op1(TestMaskedSelectBF16Op):
+    def init(self):
+        self.shape = (6, 8, 9, 2)
+
+
+class TestMaskedSelectBF16Op2(TestMaskedSelectBF16Op):
+    def init(self):
+        self.shape = (168,)
+
+
 class TestMaskedSelectAPI(unittest.TestCase):
     def test_imperative_mode(self):
         paddle.disable_static()
......
@@ -354,6 +354,7 @@ class TestPool3D_API(unittest.TestCase):
             np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05)
 
     def test_pool3d(self):
+        paddle.enable_static()
         for place in self.places:
             self.check_max_dygraph_results(place)
@@ -366,7 +367,8 @@
             self.check_max_dygraph_ndhwc_results(place)
             self.check_max_dygraph_ceilmode_results(place)
 
-    def test_static_pf16_gpu(self):
+    def test_static_fp16_gpu(self):
+        paddle.enable_static()
        if paddle.fluid.core.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
            with paddle.static.program_guard(
@@ -392,6 +394,36 @@
                 assert np.array_equal(res[0].shape, [1, 2, 1, 16, 16])
 
+    def test_static_bf16_gpu(self):
+        paddle.enable_static()
+        if (
+            paddle.fluid.core.is_compiled_with_cuda()
+            and paddle.fluid.core.is_bfloat16_supported(core.CUDAPlace(0))
+        ):
+            place = paddle.CUDAPlace(0)
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
+                input = np.random.random([1, 2, 3, 32, 32]).astype(np.uint16)
+                x = paddle.static.data(
+                    name="x", shape=[1, 2, 3, 32, 32], dtype="bfloat16"
+                )
+                m = paddle.nn.AvgPool3D(kernel_size=2, stride=2, padding=0)
+                y = m(x)
+
+                exe = paddle.static.Executor(place)
+                res = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={
+                        "x": input,
+                    },
+                    fetch_list=[y],
+                )
+
+                assert np.array_equal(res[0].shape, [1, 2, 1, 16, 16])
+
 
 class TestPool3DError_API(unittest.TestCase):
     def test_error_api(self):
......
@@ -399,9 +399,9 @@ class TestPool3D_Op(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        if self.dtype == np.float16:
-            return
-        if self.has_cudnn() and self.pool_type != "max":
+        if (
+            self.has_cudnn() or self.dtype == np.uint16
+        ) and self.pool_type != "max":
             place = core.CUDAPlace(0)
             if core.is_compiled_with_rocm():
                 self.check_grad_with_place(
@@ -566,6 +566,46 @@ def create_test_fp16_class(parent):
     globals()[cls_name] = TestFp16Case
 
 
+def create_test_cudnn_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not complied with CUDA and not support the bfloat16",
+    )
+    class TestCUDNNBf16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+            self.dtype = np.uint16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place)
+
+    cls_name = "{}_{}".format(parent.__name__, "CUDNNBf16Op")
+    TestCUDNNBf16Case.__name__ = cls_name
+    globals()[cls_name] = TestCUDNNBf16Case
+
+
+def create_test_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not complied with CUDA and not support the bfloat16",
+    )
+    class TestBf16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = False
+            self.dtype = np.uint16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place)
+
+    cls_name = "{}_{}".format(parent.__name__, "Bf16Op")
+    TestBf16Case.__name__ = cls_name
+    globals()[cls_name] = TestBf16Case
+
+
 create_test_cudnn_fp16_class(TestPool3D_Op)
 create_test_cudnn_fp16_class(TestCase1)
 create_test_cudnn_fp16_class(TestCase2)
@@ -580,6 +620,20 @@ create_test_fp16_class(TestCase3)
 create_test_fp16_class(TestCase4)
 create_test_fp16_class(TestCase5)
 
+create_test_cudnn_bf16_class(TestPool3D_Op)
+create_test_cudnn_bf16_class(TestCase1)
+create_test_cudnn_bf16_class(TestCase2)
+create_test_cudnn_bf16_class(TestCase3)
+create_test_cudnn_bf16_class(TestCase4)
+create_test_cudnn_bf16_class(TestCase5)
+
+create_test_bf16_class(TestPool3D_Op)
+create_test_bf16_class(TestCase1)
+create_test_bf16_class(TestCase2)
+create_test_bf16_class(TestCase3)
+create_test_bf16_class(TestCase4)
+create_test_bf16_class(TestCase5)
+
 # ---- test ceil mode ------
 def create_test_cudnn_use_ceil_class(parent):
@@ -736,6 +790,13 @@ create_test_cudnn_fp16_class(TestCase3_AsyPadding)
 create_test_cudnn_fp16_class(TestCase4_AsyPadding)
 create_test_cudnn_fp16_class(TestCase5_AsyPadding)
 
+create_test_cudnn_bf16_class(TestPool3D_Op_AsyPadding)
+create_test_cudnn_bf16_class(TestCase1_AsyPadding)
+create_test_cudnn_bf16_class(TestCase2_AsyPadding)
+create_test_cudnn_bf16_class(TestCase3_AsyPadding)
+create_test_cudnn_bf16_class(TestCase4_AsyPadding)
+create_test_cudnn_bf16_class(TestCase5_AsyPadding)
+
 create_test_cudnn_use_ceil_class(TestPool3D_Op_AsyPadding)
 create_test_cudnn_use_ceil_class(TestCase1_AsyPadding)
......
@@ -520,7 +520,7 @@ def avg_pool3d(
         op_type = "pool3d"
         helper = LayerHelper(op_type, **locals())
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'avg_pool3d'
+            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'avg_pool3d'
         )
         dtype = helper.input_dtype(input_param_name='x')
         pool_out = helper.create_variable_for_type_inference(dtype)
......
@@ -4027,7 +4027,7 @@ def lgamma(x, name=None):
 
     Args:
-        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
+        x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, uint16.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -4046,7 +4046,9 @@
     if in_dygraph_mode():
         return _C_ops.lgamma(x)
     else:
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lgamma')
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'lgamma'
+        )
         helper = LayerHelper('lgamma', **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
         helper.append_op(type='lgamma', inputs={'X': x}, outputs={'Out': out})
......
@@ -807,7 +807,7 @@ def masked_select(x, mask, name=None):
     which is a tensor with data type of bool.
 
     Args:
-        x (Tensor): The input Tensor, the data type can be int32, int64, float32, float64.
+        x (Tensor): The input Tensor, the data type can be int32, int64, uint16, float16, float32, float64.
         mask (Tensor): The Tensor containing the binary mask to index with, it's data type is bool.
         name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
@@ -838,7 +838,7 @@
     check_variable_and_dtype(
         x,
         'x',
-        ['float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         'paddle.tensor.search.mask_select',
     )
     check_variable_and_dtype(
......