diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu
index d0a0416994169e50cd3bec4d4a0590f686e6330c..2f89b51815e64f887ca97165287aa59c65b3cdff 100644
--- a/paddle/phi/kernels/funcs/pooling.cu
+++ b/paddle/phi/kernels/funcs/pooling.cu
@@ -1963,7 +1963,7 @@ __global__ void KernelMaxPool2dWithIdx(const int nthreads,
       wstart = max(wstart, 0);
     }
 
-    T1 ele = -FLT_MAX;
+    T1 ele = static_cast<T1>(-FLT_MAX);
     int max_index = -1;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
@@ -2015,7 +2015,7 @@ __global__ void AdaptiveKernelMaxPool2dWithIdx(const int nthreads,
     wstart = AdaptStartIndex(w_offset, input_width, output_width);
     wend = AdaptEndIndex(w_offset, input_width, output_width);
 
-    T1 ele = -FLT_MAX;
+    T1 ele = static_cast<T1>(-FLT_MAX);
     int max_index = -1;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
@@ -2089,7 +2089,7 @@ __global__ void KernelMaxPool2DWithIdxGrad(const int nthreads,
       pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
     }
 
-    T1 input_grad_data = 0;
+    T1 input_grad_data = static_cast<T1>(0);
     int input_current_featuremap_idx = h_offset * input_width + w_offset;
     for (int ph = phstart; ph < phend; ++ph) {
       for (int pw = pwstart; pw < pwend; ++pw) {
@@ -2259,6 +2259,14 @@ template class MaxPool2dWithIndexFunctor<phi::GPUContext, float, int>;
 template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, float, int>;
 template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
 template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, double, int>;
+template class MaxPool2dWithIndexFunctor<phi::GPUContext, dtype::float16, int>;
+template class MaxPool2dWithIndexGradFunctor<phi::GPUContext,
+                                             dtype::float16,
+                                             int>;
+template class MaxPool2dWithIndexFunctor<phi::GPUContext, dtype::bfloat16, int>;
+template class MaxPool2dWithIndexGradFunctor<phi::GPUContext,
+                                             dtype::bfloat16,
+                                             int>;
 
 template <typename T1, typename T2>
 __global__ void KernelMaxPool3DWithIdx(const int ncd,
@@ -2324,7 +2332,7 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd,
       wstart = max(wstart, 0);
     }
 
-    T1 ele = -FLT_MAX;
+    T1 ele = static_cast<T1>(-FLT_MAX);
     int max_index = -1;
     for (int d = dstart; d < dend; ++d) {
       for (int h = hstart; h < hend; ++h) {
@@ -2560,6 +2568,14 @@ template class MaxPool3dWithIndexFunctor<phi::GPUContext, float, int>;
 template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, float, int>;
 template class MaxPool3dWithIndexFunctor<phi::GPUContext, double, int>;
 template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, double, int>;
+template class MaxPool3dWithIndexFunctor<phi::GPUContext, dtype::float16, int>;
+template class MaxPool3dWithIndexGradFunctor<phi::GPUContext,
+                                             dtype::float16,
+                                             int>;
+template class MaxPool3dWithIndexFunctor<phi::GPUContext, dtype::bfloat16, int>;
+template class MaxPool3dWithIndexGradFunctor<phi::GPUContext,
+                                             dtype::bfloat16,
+                                             int>;
 
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/pool_grad_kernel.cu b/paddle/phi/kernels/gpu/pool_grad_kernel.cu
index e4cfcb23b730e7a9179ab0dc1c4c84dba1b09ded..c625977543558ba5066c8a210dbd9b645f750a93 100644
--- a/paddle/phi/kernels/gpu/pool_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/pool_grad_kernel.cu
@@ -38,7 +38,9 @@ PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
                    ALL_LAYOUT,
                    phi::MaxPool2dWithIndexGradKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
 }
 
@@ -55,6 +57,8 @@ PD_REGISTER_KERNEL(max_pool3d_with_index_grad,
                    ALL_LAYOUT,
                    phi::MaxPool3dWithIndexGradKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
 }
diff --git a/paddle/phi/kernels/gpu/pool_kernel.cu b/paddle/phi/kernels/gpu/pool_kernel.cu
index 65d0ef4bdc916882fba8a7a430f828cdafb51354..511cc263bc7602dd2d23e0c96f5413fef87e7716 100644
--- a/paddle/phi/kernels/gpu/pool_kernel.cu
+++ b/paddle/phi/kernels/gpu/pool_kernel.cu
@@ -32,7 +32,9 @@ PD_REGISTER_KERNEL(max_pool2d_with_index,
                    ALL_LAYOUT,
                    phi::MaxPool2dWithIndexKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
 }
 
@@ -49,6 +51,8 @@ PD_REGISTER_KERNEL(max_pool3d_with_index,
                    ALL_LAYOUT,
                    phi::MaxPool3dWithIndexKernel,
                    float,
-                   double) {
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->OutputAt(1).SetDataType(phi::CppTypeToDataType<int>::Type());
 }
diff --git a/python/paddle/fluid/tests/unittests/test_pool_max_op.py b/python/paddle/fluid/tests/unittests/test_pool_max_op.py
index d8d61f4fb2904e40832f078889dae19322afb3e2..16d1f356537bc9fe8470f1b19258cbc4b352c732 100644
--- a/python/paddle/fluid/tests/unittests/test_pool_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool_max_op.py
@@ -15,9 +15,16 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+    get_numeric_gradient,
+)
 
 import paddle
+from paddle.fluid import core
+from paddle.fluid.tests.unittests.testsuite import create_op
 
 
 def adaptive_start_index(index, input_size, output_size):
@@ -149,9 +156,18 @@ class TestMaxPoolWithIndex_Op(OpTest):
         self.init_test_case()
         self.init_global()
         self.init_adaptive()
+        self.init_dtype()
+
+        if self.is_bfloat16_op():
+            input = np.random.random(self.shape).astype(np.float32)
+            input = convert_uint16_to_float(
+                convert_float_to_uint16(np.round(input * 100.0, 2))
+            )
+
+        else:
+            input = np.random.random(self.shape).astype(self.dtype)
+            input = np.round(input * 100.0, 2)
 
-        input = np.random.random(self.shape).astype("float64")
-        input = np.round(input * 100.0, 2)
         output, mask = self.pool_forward_naive(
             input,
             self.ksize,
@@ -160,8 +176,11 @@
             self.global_pool,
             self.adaptive,
         )
-        output = output.astype("float64")
         mask = mask.astype("int32")
+        if self.is_bfloat16_op():
+            output = output.astype(np.float32)
+        else:
+            output = output.astype(self.dtype)
 
         self.attrs = {
             'strides': self.strides,
@@ -171,8 +190,20 @@
             'adaptive': self.adaptive,
         }
 
-        self.inputs = {'X': input}
-        self.outputs = {'Out': output, "Mask": mask}
+        if self.is_bfloat16_op():
+            self.inputs = {'X': convert_float_to_uint16(input)}
+            self.outputs = {
+                'Out': convert_float_to_uint16(output),
+                "Mask": mask,
+            }
+            self.inputs_fp32 = {'X': input}
+
+        else:
+            self.inputs = {'X': input}
+            self.outputs = {'Out': output, "Mask": mask}
+
+    def init_dtype(self):
+        self.dtype = np.float64
 
     def test_check_output(self):
         self.check_output()
@@ -220,9 +251,90 @@ class TestCase3(TestCase2):
         self.global_pool = False
 
 
-# ----------------max_pool2d_with_index----------------
+class TestCastAdaptive3d(TestMaxPoolWithIndex_Op):
+    def init_adaptive(self):
+        self.adaptive = True
+# ----------------max_pool3d_with_index_fp16----------------
+def create_test_fp16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    )
+    class TestMaxPool3dFP16(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(place, {'X'}, ['Out'])
+
+    cls_name = "{}_{}".format(parent.__name__, "FP16OP")
+    TestMaxPool3dFP16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool3dFP16
+
+
+create_test_fp16_class(TestMaxPoolWithIndex_Op)
+create_test_fp16_class(TestCase1)
+create_test_fp16_class(TestCase2)
+create_test_fp16_class(TestCase3)
+create_test_fp16_class(TestCastAdaptive3d)
+
+
+# ----------------max_pool3d_with_index_bf16----------------
+def create_test_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA and do not support bfloat16",
+    )
+    class TestMaxPool3dBF16(parent):
+        def init_dtype(self):
+            self.dtype = np.uint16
+
+        def get_numeric_grad(self, place, check_name):
+            scope = core.Scope()
+            self._check_grad_helper()
+            op = create_op(
+                scope, self.op_type, self.inputs, self.outputs, self.attrs
+            )
+            return get_numeric_gradient(
+                place, scope, op, self.inputs_fp32, check_name, ['Out']
+            )
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            numeric_grads = self.get_numeric_grad(place, 'X')
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place, {'X'}, ['Out'], user_defined_grads=[numeric_grads]
+                )
+
+    cls_name = "{}_{}".format(parent.__name__, "BF16OP")
+    TestMaxPool3dBF16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool3dBF16
+
+
+create_test_bf16_class(TestMaxPoolWithIndex_Op)
+create_test_bf16_class(TestCase1)
+create_test_bf16_class(TestCase2)
+create_test_bf16_class(TestCase3)
+create_test_bf16_class(TestCastAdaptive3d)
+
+
+# ----------------max_pool2d_with_index----------------
 def max_pool2d_with_index_wapper(
     x,
     kernel_size=[],
@@ -279,9 +391,82 @@ class TestCastAdaptive2d(TestCase6):
         self.adaptive = True
 
 
-class TestCastAdaptive3d(TestMaxPoolWithIndex_Op):
-    def init_adaptive(self):
-        self.adaptive = True
+# ----------------max_pool2d_with_index_fp16----------------
+def create_test_fp16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    )
+    class TestMaxPool2dFP16(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(place, {'X'}, ['Out'])
+
+    cls_name = "{}_{}".format(parent.__name__, "FP16OP")
+    TestMaxPool2dFP16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool2dFP16
+
+
+create_test_fp16_class(TestCase4)
+create_test_fp16_class(TestCase5)
+create_test_fp16_class(TestCase6)
+create_test_fp16_class(TestCase7)
+create_test_fp16_class(TestCastAdaptive2d)
+
+
+# ----------------max_pool2d_with_index_bf16----------------
+def create_test_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA and do not support bfloat16",
+    )
+    class TestMaxPool2dBF16(parent):
+        def init_dtype(self):
+            self.dtype = np.uint16
+
+        def get_numeric_grad(self, place, check_name):
+            scope = core.Scope()
+            self._check_grad_helper()
+            op = create_op(
+                scope, self.op_type, self.inputs, self.outputs, self.attrs
+            )
+            return get_numeric_gradient(
+                place, scope, op, self.inputs_fp32, check_name, ['Out']
+            )
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            numeric_grads = self.get_numeric_grad(place, 'X')
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(
+                    place, {'X'}, ['Out'], user_defined_grads=[numeric_grads]
+                )
+
+    cls_name = "{}_{}".format(parent.__name__, "BF16OP")
+    TestMaxPool2dBF16.__name__ = cls_name
+    globals()[cls_name] = TestMaxPool2dBF16
+
+
+create_test_bf16_class(TestCase4)
+create_test_bf16_class(TestCase5)
+create_test_bf16_class(TestCase6)
+create_test_bf16_class(TestCase7)
+create_test_bf16_class(TestCastAdaptive2d)
 
 
 if __name__ == '__main__':
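
Note (not part of the patch): the snippet below is a minimal sketch of how the newly registered fp16 GPU kernel can be exercised through the public API, assuming a CUDA build of Paddle with an fp16-capable GPU; `paddle.nn.functional.max_pool2d` with `return_mask=True` dispatches to `max_pool2d_with_index`.

```python
# Minimal sketch, assuming a CUDA build of Paddle and an fp16-capable GPU.
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.set_device('gpu')
x = paddle.to_tensor(np.random.rand(2, 3, 8, 8).astype('float16'))
# return_mask=True routes to the max_pool2d_with_index kernel, which this
# patch registers for float16/bfloat16 on GPU.
out, mask = F.max_pool2d(x, kernel_size=2, stride=2, return_mask=True)
print(out.dtype, mask.dtype)  # expected: paddle.float16 paddle.int32
```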