diff --git a/paddle/fluid/operators/math/pooling.cu b/paddle/fluid/operators/math/pooling.cu
index 22164131468a46bffc239509a9213a21f1611ed5..b64dbb771a3398d0a48ee5bb4e11607fcef99b73 100644
--- a/paddle/fluid/operators/math/pooling.cu
+++ b/paddle/fluid/operators/math/pooling.cu
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <algorithm>
 #include <vector>
+
 #include "paddle/fluid/operators/math/pooling.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -126,7 +127,7 @@ __global__ void KernelPool2DGrad(
       phend = min(h_offset / stride_height + 1, output_height);
       pwend = min(w_offset / stride_width + 1, output_width);
     }
-    T gradient = 0;
+    T gradient = static_cast<T>(0.0);
     T input = input_data[index];
     int output_stride;
@@ -264,12 +265,12 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
 }
 
 /*
-* Tensors are in NCHW or NHWC format.
-* Ksize, strides are two elements. These two elements represent height
-* and width, respectively.
-* Paddings are four elements. These four elements represent height_up,
-* height_down, width_left and width_right, respectively.
-*/
+ * Tensors are in NCHW or NHWC format.
+ * Ksize, strides are two elements. These two elements represent height
+ * and width, respectively.
+ * Paddings are four elements. These four elements represent height_up,
+ * height_down, width_left and width_right, respectively.
+ */
 template <typename PoolProcess, typename T>
 class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
@@ -351,12 +352,12 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
   }
 };
 /*
-* Tensors are in NCHW or NHWC format.
-* Ksize, strides are two elements. These two elements represent height
-* and width, respectively.
-* Paddings are four elements. These four elements represent height_up,
-* height_down, width_left and width_right, respectively.
-*/
+ * Tensors are in NCHW or NHWC format.
+ * Ksize, strides are two elements. These two elements represent height
+ * and width, respectively.
+ * Paddings are four elements. These four elements represent height_up,
+ * height_down, width_left and width_right, respectively.
+ */
 template <typename PoolProcess, typename T>
 class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
@@ -448,12 +449,12 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 };
 
 /*
-* Tensors are in NCHW or NHWC format.
-* Ksize, strides are two elements. These two elements represent height
-* and width, respectively.
-* Paddings are four elements. These four elements represent height_up,
-* height_down, width_left and width_right, respectively.
-*/
+ * Tensors are in NCHW or NHWC format.
+ * Ksize, strides are two elements. These two elements represent height
+ * and width, respectively.
+ * Paddings are four elements. These four elements represent height_up,
+ * height_down, width_left and width_right, respectively.
+ */
 template <typename T>
 class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
  public:
@@ -549,6 +550,8 @@ template class Pool2dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
 
 template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
 template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
+template class MaxPool2dGradFunctor<platform::CUDADeviceContext,
+                                    paddle::platform::float16>;
 
 template class Pool2dFunctor<platform::CUDADeviceContext,
                              paddle::operators::math::MaxPool<float>, float>;
@@ -571,6 +574,23 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                  paddle::operators::math::AvgPoolGrad<double>,
                                  double>;
 
+template class Pool2dFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::MaxPool<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool2dFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::AvgPool<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool2dGradFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool2dGradFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
+    paddle::platform::float16>;
+
 template <typename PoolProcess, typename T>
 __global__ void KernelPool3D(
     const int nthreads, const T* input_data, const int channels,
@@ -712,7 +732,7 @@ __global__ void KernelPool3DGrad(
       pwend = min((w_offset) / stride_width + 1, output_width);
     }
 
-    T gradient = 0;
+    T gradient = static_cast<T>(0.0);
     T input = input_data[index];
     int output_stride;
@@ -848,13 +868,13 @@ __global__ void KernelMaxPool3DGrad(
 }
 
 /*
-* Tensors are in NCDHW or NDHWC format.
-* Ksize, strides, paddings are three elements. These three elements represent
-* depth, height and width, respectively.
-* Paddings are six elements. These six elements represent depth_forth,
-* depth_back,
-* height_up, height_down, width_left and width_right, respectively.
-*/
+ * Tensors are in NCDHW or NDHWC format.
+ * Ksize, strides, paddings are three elements. These three elements represent
+ * depth, height and width, respectively.
+ * Paddings are six elements. These six elements represent depth_forth,
+ * depth_back,
+ * height_up, height_down, width_left and width_right, respectively.
+ */
 template <typename PoolProcess, typename T>
 class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
@@ -952,13 +972,13 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 };
 
 /*
-* Tensors are in NCDHW or NDHWC format.
-* Ksize, strides, paddings are three elements. These three elements represent
-* depth, height and width, respectively.
-* Paddings are six elements. These six elements represent depth_forth,
-* depth_back,
-* height_up, height_down, width_left and width_right, respectively.
-*/
+ * Tensors are in NCDHW or NDHWC format.
+ * Ksize, strides, paddings are three elements. These three elements represent
+ * depth, height and width, respectively.
+ * Paddings are six elements. These six elements represent depth_forth,
+ * depth_back,
+ * height_up, height_down, width_left and width_right, respectively.
+ */
 template <typename PoolProcess, typename T>
 class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
@@ -1064,13 +1084,13 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 };
 
 /*
-* tensors are in NCDHW or NDHWC format.
-* Ksize, strides, paddings are three elements. These three elements represent
-* depth, height and width, respectively.
-* Paddings are six elements. These six elements represent depth_forth,
-* depth_back,
-* height_up, height_down, width_left and width_right, respectively.
-*/
+ * tensors are in NCDHW or NDHWC format.
+ * Ksize, strides, paddings are three elements. These three elements represent
+ * depth, height and width, respectively.
+ * Paddings are six elements. These six elements represent depth_forth,
+ * depth_back,
+ * height_up, height_down, width_left and width_right, respectively.
+ */
 template <typename T>
 class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
  public:
@@ -1174,6 +1194,8 @@ class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 
 template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
 template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;
+template class MaxPool3dGradFunctor<platform::CUDADeviceContext,
+                                    paddle::platform::float16>;
 
 template class Pool3dFunctor<platform::CUDADeviceContext,
                              paddle::operators::math::MaxPool<float>, float>;
@@ -1196,6 +1218,23 @@ template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                  paddle::operators::math::AvgPoolGrad<double>,
                                  double>;
 
+template class Pool3dFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::MaxPool<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool3dFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::AvgPool<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool3dGradFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool3dGradFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
+    paddle::platform::float16>;
+
 template <typename T1, typename T2>
 __global__ void KernelMaxPool2dWithIdx(
     const int nthreads, const T1* input_data, const int channels,
diff --git a/paddle/fluid/operators/math/pooling.h b/paddle/fluid/operators/math/pooling.h
index 572295f138d599433f0d947a10df867cded9aa71..5a6ae224789a23319aa79a0e7159e183f8c28205 100644
--- a/paddle/fluid/operators/math/pooling.h
+++ b/paddle/fluid/operators/math/pooling.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <string>
 #include <vector>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -56,7 +57,7 @@ class MaxPoolGrad {
  public:
   DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
                              T* dx) {
-    *dx += dy * (x == y);
+    *dx += dy * static_cast<T>(x == y);
   }
 };
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index ba468b79605575b8957e1e963dabad9dff957eb3..5b0980a98513bfa0bb619ac182fbfca4961dada2 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/pool_op.h"
+
 #include <unordered_map>
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
@@ -219,11 +220,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
 #endif
 
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-  if (input_data_type == framework::proto::VarType::FP16) {
-    PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
-                      platform::errors::InvalidArgument(
-                          "Float16 can only be used when CUDNN is used"));
-  }
+
   return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
                                  library_);
 }
diff --git a/paddle/fluid/operators/pool_op.cu.cc b/paddle/fluid/operators/pool_op.cu.cc
index 37bc14e2cbb3437a750c416d39b7b914370961b0..6b1e9f93033aa7139dfba865c4dff27f67461ed6 100644
--- a/paddle/fluid/operators/pool_op.cu.cc
+++ b/paddle/fluid/operators/pool_op.cu.cc
@@ -18,16 +18,24 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(
     pool2d, ops::PoolKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PoolKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::PoolKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PoolKernel<paddle::platform::CUDADeviceContext,
+                    paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     pool2d_grad,
     ops::PoolGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PoolGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::PoolGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PoolGradKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     pool3d, ops::PoolKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PoolKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::PoolKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PoolKernel<paddle::platform::CUDADeviceContext,
+                    paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     pool3d_grad,
     ops::PoolGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PoolGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::PoolGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PoolGradKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::float16>);
diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h
index 677c724069cf49a4bf7c1dc298cf30828d7e6e71..71bef11b67225b3ac20907eeba4d10a0883e8e21 100644
--- a/paddle/fluid/operators/pool_op.h
+++ b/paddle/fluid/operators/pool_op.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <algorithm>
 #include <string>
 #include <vector>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -257,7 +258,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
       paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
-      set_constant(dev_ctx, in_x_grad, 0.0);
+      set_constant(dev_ctx, in_x_grad, static_cast<T>(0.0));
 
       switch (ksize.size()) {
         case 2: {
diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py
index 8553fa8b99a92723430c2e2bfbb84bb0bdc3b258..e6d41902a7c6d3e3ff7f5f91ddee013822cd0277 100644
--- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py
@@ -475,6 +475,41 @@ def create_test_cudnn_fp16_class(parent, check_grad=True):
     globals()[cls_name] = TestCUDNNFp16Case
 
 
+def create_test_fp16_class(parent, check_grad=True):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestFp16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = False
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            # TODO(wangzhongpu): support mkldnn op in dygraph mode
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(
+                        place,
+                        atol=1e-3,
+                        check_dygraph=(self.use_mkldnn == False))
+
+        def test_check_grad(self):
+            # TODO(wangzhongpu): support mkldnn op in dygraph mode
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(
+                    place) and self.pool_type != "max" and check_grad:
+                self.check_grad_with_place(
+                    place,
+                    set(['X']),
+                    'Out',
+                    max_relative_error=0.07,
+                    check_dygraph=(self.use_mkldnn == False))
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op")
"{0}_{1}".format(parent.__name__, "Fp16Op") + TestFp16Case.__name__ = cls_name + globals()[cls_name] = TestFp16Case + + create_test_cudnn_fp16_class(TestPool2D_Op) create_test_cudnn_fp16_class(TestCase1, check_grad=False) create_test_cudnn_fp16_class(TestCase2) @@ -482,6 +517,13 @@ create_test_cudnn_fp16_class(TestCase3) create_test_cudnn_fp16_class(TestCase4) create_test_cudnn_fp16_class(TestCase5) +create_test_fp16_class(TestPool2D_Op) +create_test_fp16_class(TestCase1, check_grad=False) +create_test_fp16_class(TestCase2) +create_test_fp16_class(TestCase3) +create_test_fp16_class(TestCase4) +create_test_fp16_class(TestCase5) + #--------------------test pool2d use ceil mode-------------------- diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_op.py b/python/paddle/fluid/tests/unittests/test_pool3d_op.py index fade1691210a4c8da151b67fda2bf356ece3a616..2d20cfc4cfc9b52c2987bec6157b221cbef8f5b5 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_op.py @@ -405,6 +405,25 @@ def create_test_cudnn_fp16_class(parent): globals()[cls_name] = TestCUDNNFp16Case +def create_test_fp16_class(parent): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestFp16Case(parent): + def init_kernel_type(self): + self.use_cudnn = False + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-2) + + cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op") + TestFp16Case.__name__ = cls_name + globals()[cls_name] = TestFp16Case + + create_test_cudnn_fp16_class(TestPool3D_Op) create_test_cudnn_fp16_class(TestCase1) create_test_cudnn_fp16_class(TestCase2) @@ -412,6 +431,13 @@ create_test_cudnn_fp16_class(TestCase3) create_test_cudnn_fp16_class(TestCase4) create_test_cudnn_fp16_class(TestCase5) +create_test_fp16_class(TestPool3D_Op) +create_test_fp16_class(TestCase1) +create_test_fp16_class(TestCase2) +create_test_fp16_class(TestCase3) +create_test_fp16_class(TestCase4) +create_test_fp16_class(TestCase5) + # ---- test ceil mode ------ def create_test_cudnn_use_ceil_class(parent):