Unverified commit 6115c14f authored by Leo Chen, committed by GitHub

Pool2d cuda kernel supports fp16 (#28316)

* pool2d cuda kernel supports fp16

* fix compile issue of template

* add ut
Parent f41104ef
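The change below lets the plain (non-cuDNN) CUDA pool2d/pool3d kernels accept float16 input and output, which previously required cuDNN. As a rough orientation before the diff, the sketch below shows how the new path could be exercised from Python; it is not part of this commit, and the fluid static-graph API used here (fluid.data, fluid.layers.pool2d, use_cudnn=False) is assumed from the Paddle release this PR targets.

import numpy as np
import paddle.fluid as fluid

# Pool a float16 tensor with the plain CUDA kernel (use_cudnn=False),
# the code path this commit enables. Requires a CUDA build and a GPU
# that supports float16.
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[1, 3, 8, 8], dtype='float16')
    y = fluid.layers.pool2d(x, pool_size=2, pool_type='avg',
                            pool_stride=2, use_cudnn=False)

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(startup_prog)
x_np = np.random.rand(1, 3, 8, 8).astype('float16')
out, = exe.run(main_prog, feed={'x': x_np}, fetch_list=[y])
print(out.dtype, out.shape)  # float16, (1, 3, 4, 4)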
@@ -14,6 +14,7 @@ limitations under the License. */
 #include <algorithm>
 #include <vector>
 #include "paddle/fluid/operators/math/pooling.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
@@ -126,7 +127,7 @@ __global__ void KernelPool2DGrad(
       phend = min(h_offset / stride_height + 1, output_height);
       pwend = min(w_offset / stride_width + 1, output_width);
     }
-    T gradient = 0;
+    T gradient = static_cast<T>(0.0);
     T input = input_data[index];
     int output_stride;
@@ -264,12 +265,12 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
 }
 
 /*
  * Tensors are in NCHW or NHWC format.
  * Ksize, strides are two elements. These two elements represent height
  * and width, respectively.
  * Paddings are four elements. These four elements represent height_up,
  * height_down, width_left and width_right, respectively.
  */
 template <typename PoolProcess, typename T>
 class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
@@ -351,12 +352,12 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
   }
 };
 
 /*
  * Tensors are in NCHW or NHWC format.
  * Ksize, strides are two elements. These two elements represent height
  * and width, respectively.
  * Paddings are four elements. These four elements represent height_up,
  * height_down, width_left and width_right, respectively.
  */
 template <typename PoolProcess, typename T>
 class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
@@ -448,12 +449,12 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 };
 
 /*
  * Tensors are in NCHW or NHWC format.
  * Ksize, strides are two elements. These two elements represent height
  * and width, respectively.
  * Paddings are four elements. These four elements represent height_up,
  * height_down, width_left and width_right, respectively.
  */
 template <typename T>
 class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
  public:
@@ -549,6 +550,8 @@ template class Pool2dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
 template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
 template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
+template class MaxPool2dGradFunctor<platform::CUDADeviceContext,
+                                    paddle::platform::float16>;
 template class Pool2dFunctor<platform::CUDADeviceContext,
                              paddle::operators::math::MaxPool<float>, float>;
@@ -571,6 +574,23 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                  paddle::operators::math::AvgPoolGrad<double>,
                                  double>;
 
+template class Pool2dFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::MaxPool<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool2dFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::AvgPool<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool2dGradFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool2dGradFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
+    paddle::platform::float16>;
+
 template <typename PoolProcess, typename T>
 __global__ void KernelPool3D(
     const int nthreads, const T* input_data, const int channels,
@@ -712,7 +732,7 @@ __global__ void KernelPool3DGrad(
       pwend = min((w_offset) / stride_width + 1, output_width);
     }
-    T gradient = 0;
+    T gradient = static_cast<T>(0.0);
     T input = input_data[index];
     int output_stride;
@@ -848,13 +868,13 @@ __global__ void KernelMaxPool3DGrad(
 }
 
 /*
  * Tensors are in NCDHW or NDHWC format.
  * Ksize, strides, paddings are three elements. These three elements represent
  * depth, height and width, respectively.
  * Paddings are six elements. These six elements represent depth_forth,
  * depth_back,
  * height_up, height_down, width_left and width_right, respectively.
  */
 template <typename PoolProcess, class T>
 class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
@@ -952,13 +972,13 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 };
 
 /*
  * Tensors are in NCDHW or NDHWC format.
  * Ksize, strides, paddings are three elements. These three elements represent
  * depth, height and width, respectively.
  * Paddings are six elements. These six elements represent depth_forth,
  * depth_back,
  * height_up, height_down, width_left and width_right, respectively.
  */
 template <typename PoolProcess, class T>
 class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
  public:
@@ -1064,13 +1084,13 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 };
 
 /*
  * tensors are in NCDHW or NDHWC format.
  * Ksize, strides, paddings are three elements. These three elements represent
  * depth, height and width, respectively.
  * Paddings are six elements. These six elements represent depth_forth,
  * depth_back,
  * height_up, height_down, width_left and width_right, respectively.
  */
 template <class T>
 class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
  public:
@@ -1174,6 +1194,8 @@ class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
 template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;
+template class MaxPool3dGradFunctor<platform::CUDADeviceContext,
+                                    paddle::platform::float16>;
 template class Pool3dFunctor<platform::CUDADeviceContext,
                              paddle::operators::math::MaxPool<float>, float>;
@@ -1196,6 +1218,23 @@ template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                  paddle::operators::math::AvgPoolGrad<double>,
                                  double>;
 
+template class Pool3dFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::MaxPool<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool3dFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::AvgPool<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool3dGradFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
+    paddle::platform::float16>;
+template class Pool3dGradFunctor<
+    platform::CUDADeviceContext,
+    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
+    paddle::platform::float16>;
+
 template <typename T1, typename T2>
 __global__ void KernelMaxPool2dWithIdx(
     const int nthreads, const T1* input_data, const int channels,
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -56,7 +57,7 @@ class MaxPoolGrad {
  public:
   DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
                              T* dx) {
-    *dx += dy * (x == y);
+    *dx += dy * static_cast<T>(x == y);
   }
 };
......
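The MaxPoolGrad change above adds an explicit cast because float16 has no implicit conversion from the bool produced by x == y; the rule itself is unchanged: the upstream gradient dy is routed only to input positions equal to the pooled maximum. Below is a NumPy illustration of the same masking rule in float16 (values and names invented for the example, not taken from this PR):

import numpy as np

x = np.array([[0.5, 1.25], [1.25, 0.75]], dtype=np.float16)  # one pooling window
y = x.max()                                                  # pooled output
dy = np.float16(0.1)                                         # upstream gradient
dx = np.zeros_like(x)
dx += dy * (x == y).astype(np.float16)  # mask cast to fp16, like static_cast<T>(x == y)
print(dx)  # gradient flows to every element equal to the max (both 1.25 entries here)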
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/pool_op.h"
 #include <unordered_map>
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
@@ -219,11 +220,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
 #endif
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-  if (input_data_type == framework::proto::VarType::FP16) {
-    PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
-                      platform::errors::InvalidArgument(
-                          "Float16 can only be used when CUDNN is used"));
-  }
   return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
                                  library_);
 }
......
@@ -18,16 +18,24 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(
     pool2d, ops::PoolKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PoolKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::PoolKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PoolKernel<paddle::platform::CUDADeviceContext,
+                    paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     pool2d_grad,
     ops::PoolGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PoolGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::PoolGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PoolGradKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     pool3d, ops::PoolKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PoolKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::PoolKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PoolKernel<paddle::platform::CUDADeviceContext,
+                    paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     pool3d_grad,
     ops::PoolGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PoolGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::PoolGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::PoolGradKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::float16>);
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <algorithm>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -257,7 +258,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
       paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
-      set_constant(dev_ctx, in_x_grad, 0.0);
+      set_constant(dev_ctx, in_x_grad, static_cast<T>(0.0));
       switch (ksize.size()) {
         case 2: {
......
@@ -475,6 +475,41 @@ def create_test_cudnn_fp16_class(parent, check_grad=True):
     globals()[cls_name] = TestCUDNNFp16Case
 
 
+def create_test_fp16_class(parent, check_grad=True):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestFp16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = False
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            # TODO(wangzhongpu): support mkldnn op in dygraph mode
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(
+                        place,
+                        atol=1e-3,
+                        check_dygraph=(self.use_mkldnn == False))
+
+        def test_check_grad(self):
+            # TODO(wangzhongpu): support mkldnn op in dygraph mode
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(
+                    place) and self.pool_type != "max" and check_grad:
+                self.check_grad_with_place(
+                    place,
+                    set(['X']),
+                    'Out',
+                    max_relative_error=0.07,
+                    check_dygraph=(self.use_mkldnn == False))
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op")
+    TestFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestFp16Case
+
+
 create_test_cudnn_fp16_class(TestPool2D_Op)
 create_test_cudnn_fp16_class(TestCase1, check_grad=False)
 create_test_cudnn_fp16_class(TestCase2)
@@ -482,6 +517,13 @@ create_test_cudnn_fp16_class(TestCase3)
 create_test_cudnn_fp16_class(TestCase4)
 create_test_cudnn_fp16_class(TestCase5)
 
+create_test_fp16_class(TestPool2D_Op)
+create_test_fp16_class(TestCase1, check_grad=False)
+create_test_fp16_class(TestCase2)
+create_test_fp16_class(TestCase3)
+create_test_fp16_class(TestCase4)
+create_test_fp16_class(TestCase5)
+
 #--------------------test pool2d use ceil mode--------------------
......
@@ -405,6 +405,25 @@ def create_test_cudnn_fp16_class(parent):
     globals()[cls_name] = TestCUDNNFp16Case
 
 
+def create_test_fp16_class(parent):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestFp16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = False
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place, atol=1e-2)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op")
+    TestFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestFp16Case
+
+
 create_test_cudnn_fp16_class(TestPool3D_Op)
 create_test_cudnn_fp16_class(TestCase1)
 create_test_cudnn_fp16_class(TestCase2)
@@ -412,6 +431,13 @@ create_test_cudnn_fp16_class(TestCase3)
 create_test_cudnn_fp16_class(TestCase4)
 create_test_cudnn_fp16_class(TestCase5)
 
+create_test_fp16_class(TestPool3D_Op)
+create_test_fp16_class(TestCase1)
+create_test_fp16_class(TestCase2)
+create_test_fp16_class(TestCase3)
+create_test_fp16_class(TestCase4)
+create_test_fp16_class(TestCase5)
+
 # ---- test ceil mode ------
 def create_test_cudnn_use_ceil_class(parent):
......
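The fp16 tests above compare against a float32 reference with atol=1e-3 (pool2d) and atol=1e-2 (pool3d), and allow max_relative_error=0.07 for gradients, because float16 accumulation rounds differently from float32. A small NumPy sketch (illustrative only, not part of the PR) of the kind of drift these tolerances absorb:

import numpy as np

np.random.seed(0)
window = np.random.rand(4, 4).astype(np.float16)  # one 4x4 average-pooling window

avg_fp16 = window.sum(dtype=np.float16) / np.float16(window.size)  # fp16 accumulation
avg_fp32 = window.astype(np.float32).mean()                        # float32 reference

print(float(avg_fp16), float(avg_fp32), abs(float(avg_fp16) - float(avg_fp32)))
# the gap is typically small but nonzero, so the tests use absolute/relative
# tolerances instead of exact equality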