From bfbc25bdb87423d5334d826b8b87ce5e61e29d70 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Fri, 16 Mar 2018 12:09:24 -0700 Subject: [PATCH] add fp16 pool2d support --- paddle/fluid/operators/pool_cudnn_op.cu.cc | 18 ++-- paddle/fluid/operators/pool_op.cc | 10 ++- .../paddle/fluid/tests/unittests/op_test.py | 6 +- .../fluid/tests/unittests/test_conv2d_op.py | 41 +++++---- .../fluid/tests/unittests/test_pool2d_op.py | 89 ++++++++++++++++++- 5 files changed, 131 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 781d96981..b91a0c488 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -78,7 +78,8 @@ class PoolCUDNNOpKernel : public framework::OpKernel { // ------------------- cudnn pool algorithm --------------------- auto handle = ctx.cuda_device_context().cudnn_handle(); - T alpha = 1.0f, beta = 0.0f; + typename platform::CudnnDataType::ScalingParamType alpha = 1.0f, + beta = 0.0f; PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward( handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta, @@ -144,7 +145,8 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel { // ------------------- cudnn pool algorithm --------------------- auto handle = ctx.cuda_device_context().cudnn_handle(); - T alpha = 1.0f, beta = 0.0f; + typename platform::CudnnDataType::ScalingParamType alpha = 1.0f, + beta = 0.0f; if (input_grad) { T *input_grad_data = input_grad->mutable_data(ctx.GetPlace()); @@ -162,17 +164,19 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; -REGISTER_OP_KERNEL(pool2d, CUDNN, ::paddle::platform::CUDAPlace, +REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, - ops::PoolCUDNNOpKernel); -REGISTER_OP_KERNEL(pool2d_grad, CUDNN, ::paddle::platform::CUDAPlace, + ops::PoolCUDNNOpKernel, + ops::PoolCUDNNOpKernel); +REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, ops::PoolCUDNNGradOpKernel); -REGISTER_OP_KERNEL(pool3d, CUDNN, ::paddle::platform::CUDAPlace, +REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, ops::PoolCUDNNOpKernel); -REGISTER_OP_KERNEL(pool3d_grad, CUDNN, ::paddle::platform::CUDAPlace, +REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, ops::PoolCUDNNGradOpKernel); diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index d78da1001..b144ec5f7 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -124,11 +124,15 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( } #endif + auto input_data_type = framework::ToDataType(ctx.Input("X")->type()); + if (input_data_type == framework::proto::VarType::FP16) { + PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, + "float16 can only be used when CUDNN is used"); + } std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, + library_); } Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 6a42f763a..8393f7827 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -483,9 +483,9 @@ class OpTest(unittest.TestCase): input: input numpy array Returns: - input: if the dtype of input is np.float16, its dtype will be - changed to np.uint16 so that the internal memory will be - reinterpreted input as of dtype np.uint16. + input: The dtype of input will be changed to np.uint16 if + it is originally np.float16, such that the internal memory + of input will be reinterpreted as of dtype np.uint16. """ if input.dtype == np.float16: input.dtype = np.uint16 diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 7913b9824..dfd83fdb3 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -65,10 +65,10 @@ class TestConv2dOp(OpTest): def setUp(self): self.use_cudnn = False self.use_mkldnn = False + self.dtype = np.float32 self.init_op_type() self.init_group() self.init_dilation() - self.init_data_type() self.init_test_case() conv2d_param = { @@ -159,9 +159,6 @@ class TestConv2dOp(OpTest): f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3] - def init_data_type(self): - self.dtype = np.float32 - def init_dilation(self): self.dilations = [1, 1] @@ -246,8 +243,10 @@ class TestCUDNN(TestConv2dOp): self.op_type = "conv2d" -class TestFP16CUDNN(TestCUDNN): - def init_data_type(self): +class TestFP16CUDNN(TestConv2dOp): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d" self.dtype = np.float16 def test_check_output(self): @@ -263,8 +262,10 @@ class TestCUDNNWithPad(TestWithPad): self.op_type = "conv2d" -class TestFP16CUDNNWithPad(TestCUDNNWithPad): - def init_data_type(self): +class TestFP16CUDNNWithPad(TestWithPad): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d" self.dtype = np.float16 def test_check_output(self): @@ -280,8 +281,10 @@ class TestCUDNNWithStride(TestWithStride): self.op_type = "conv2d" -class TestFP16CUDNNWithStride(TestCUDNNWithStride): - def init_data_type(self): +class TestFP16CUDNNWithStride(TestWithStride): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d" self.dtype = np.float16 def test_check_output(self): @@ -297,8 +300,10 @@ class TestCUDNNWithGroup(TestWithGroup): self.op_type = "conv2d" -class TestFP16CUDNNWithGroup(TestCUDNNWithGroup): - def init_data_type(self): +class TestFP16CUDNNWithGroup(TestWithGroup): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d" self.dtype = np.float16 def test_check_output(self): @@ -314,8 +319,10 @@ class TestCUDNNWith1x1(TestWith1x1): self.op_type = "conv2d" -class TestFP16CUDNNWith1x1(TestCUDNNWith1x1): - def init_data_type(self): +class TestFP16CUDNNWith1x1(TestWith1x1): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d" self.dtype = np.float16 def test_check_output(self): @@ -331,8 +338,10 @@ class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): self.op_type = "conv2d" -class TestFP16CUDNNWithInput1x1Filter1x1(TestCUDNNWithInput1x1Filter1x1): - def init_data_type(self): +class TestFP16CUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "conv2d" self.dtype = np.float16 def test_check_output(self): diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py index 964d78f19..76b15e409 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py @@ -80,6 +80,7 @@ class TestPool2d_Op(OpTest): def setUp(self): self.use_cudnn = False self.use_mkldnn = False + self.dtype = np.float32 self.init_test_case() self.init_global_pool() self.init_op_type() @@ -87,11 +88,11 @@ class TestPool2d_Op(OpTest): self.init_ceil_mode() if self.global_pool: self.paddings = [0 for _ in range(len(self.paddings))] - input = np.random.random(self.shape).astype("float32") + input = np.random.random(self.shape).astype(self.dtype) output = self.pool2D_forward_naive(input, self.ksize, self.strides, self.paddings, self.global_pool, - self.ceil_mode).astype("float32") - self.inputs = {'X': input} + self.ceil_mode).astype(self.dtype) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} self.attrs = { 'strides': self.strides, @@ -105,7 +106,7 @@ class TestPool2d_Op(OpTest): 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter } - self.outputs = {'Out': output.astype('float32')} + self.outputs = {'Out': output} def test_check_output(self): if self.use_cudnn: @@ -115,6 +116,8 @@ class TestPool2d_Op(OpTest): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return if self.use_cudnn and self.pool_type != "max": place = core.CUDAPlace(0) self.check_grad_with_place( @@ -212,36 +215,114 @@ class TestCUDNNCase1(TestPool2d_Op): self.op_type = "pool2d" +class TestFP16CUDNNCase1(TestPool2d_Op): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "pool2d" + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + + class TestCUDNNCase2(TestCase1): def init_op_type(self): self.use_cudnn = True self.op_type = "pool2d" +class TestFP16CUDNNCase2(TestCase1): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "pool2d" + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + + class TestCUDNNCase3(TestCase2): def init_op_type(self): self.use_cudnn = True self.op_type = "pool2d" +class TestFP16CUDNNCase3(TestCase2): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "pool2d" + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + + class TestCUDNNCase4(TestCase3): def init_op_type(self): self.use_cudnn = True self.op_type = "pool2d" +class TestFP16CUDNNCase4(TestCase3): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "pool2d" + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + + class TestCUDNNCase5(TestCase4): def init_op_type(self): self.use_cudnn = True self.op_type = "pool2d" +class TestFP16CUDNNCase5(TestCase4): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "pool2d" + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + + class TestCUDNNCase6(TestCase5): def init_op_type(self): self.use_cudnn = True self.op_type = "pool2d" +class TestFP16CUDNNCase6(TestCase5): + def init_op_type(self): + self.use_cudnn = True + self.op_type = "pool2d" + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + + class TestCeilModeCase1(TestCUDNNCase1): def init_ceil_mode(self): self.ceil_mode = True -- GitLab