diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index 0ddbfdb4aa9e844adbb291e1c5612e96681831d6..a32aba4c1ff2f5e775aeb41f25b02322dbc6a64a 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -28,6 +28,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using DataLayout = platform::DataLayout; +template +using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static_cast(1024) * 1024 * 1024; @@ -134,8 +136,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv forward --------------------- - typename platform::CudnnDataType::ScalingParamType alpha = 1.0f, - beta = 0.0f; + ScalingParamType alpha = 1.0f, beta = 0.0f; for (int i = 0; i < groups; i++) { PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, @@ -282,8 +283,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { platform::CUDAPlace gpu = boost::get(ctx.GetPlace()); cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); // ------------------- cudnn conv backward data --------------------- - typename platform::CudnnDataType::ScalingParamType alpha = 1.0f, - beta = 0.0f; + ScalingParamType alpha = 1.0f, beta = 0.0f; if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 781d96981e4c033d9287ab3de9860dfd9fcd2875..39c862b03ad497dca5c38ccecff20be510ab60e5 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -24,6 +24,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor; using DataLayout = platform::DataLayout; using PoolingMode = platform::PoolingMode; +template +using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; template class PoolCUDNNOpKernel : public framework::OpKernel { @@ -78,8 +80,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel { // ------------------- cudnn pool algorithm --------------------- auto handle = ctx.cuda_device_context().cudnn_handle(); - T alpha = 1.0f, beta = 0.0f; - + ScalingParamType alpha = 1.0f, beta = 0.0f; PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward( handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta, cudnn_output_desc, output_data)); @@ -144,8 +145,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel { // ------------------- cudnn pool algorithm --------------------- auto handle = ctx.cuda_device_context().cudnn_handle(); - T alpha = 1.0f, beta = 0.0f; - + ScalingParamType alpha = 1.0f, beta = 0.0f; if (input_grad) { T *input_grad_data = input_grad->mutable_data(ctx.GetPlace()); // Because beta is zero, it is unnecessary to reset input_grad. @@ -162,17 +162,19 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; -REGISTER_OP_KERNEL(pool2d, CUDNN, ::paddle::platform::CUDAPlace, +REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, - ops::PoolCUDNNOpKernel); -REGISTER_OP_KERNEL(pool2d_grad, CUDNN, ::paddle::platform::CUDAPlace, + ops::PoolCUDNNOpKernel, + ops::PoolCUDNNOpKernel); +REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, ops::PoolCUDNNGradOpKernel); -REGISTER_OP_KERNEL(pool3d, CUDNN, ::paddle::platform::CUDAPlace, +REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, ops::PoolCUDNNOpKernel); -REGISTER_OP_KERNEL(pool3d_grad, CUDNN, ::paddle::platform::CUDAPlace, +REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, ops::PoolCUDNNGradOpKernel); diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index d78da10016a0e2b1d9a0ca9f3dfe4e8009bbe61d..b144ec5f7d315cb340dcd94b4a519bfcfd2a0e66 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -124,11 +124,15 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( } #endif + auto input_data_type = framework::ToDataType(ctx.Input("X")->type()); + if (input_data_type == framework::proto::VarType::FP16) { + PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, + "float16 can only be used when CUDNN is used"); + } std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, + library_); } Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 6a42f763a6c436f5d33569dc65f711c2930a8b2e..8393f7827b1c7d361ebea72f2cfc6033268772f0 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -483,9 +483,9 @@ class OpTest(unittest.TestCase): input: input numpy array Returns: - input: if the dtype of input is np.float16, its dtype will be - changed to np.uint16 so that the internal memory will be - reinterpreted input as of dtype np.uint16. + input: The dtype of input will be changed to np.uint16 if + it is originally np.float16, such that the internal memory + of input will be reinterpreted as of dtype np.uint16. """ if input.dtype == np.float16: input.dtype = np.uint16 diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 7913b98240fbd0aaa8d94911e68f61a0e416e7c8..4b6e3fb69a12095c77f343515fe3b6d1f3fccb14 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -63,12 +63,13 @@ def conv2d_forward_naive(input, filter, group, conv_param): class TestConv2dOp(OpTest): def setUp(self): + self.op_type = "conv2d" self.use_cudnn = False self.use_mkldnn = False - self.init_op_type() + self.dtype = np.float32 + self.init_kernel_type() self.init_group() self.init_dilation() - self.init_data_type() self.init_test_case() conv2d_param = { @@ -159,17 +160,14 @@ class TestConv2dOp(OpTest): f_c = self.input_size[1] / self.groups self.filter_size = [6, f_c, 3, 3] - def init_data_type(self): - self.dtype = np.float32 - def init_dilation(self): self.dilations = [1, 1] def init_group(self): self.groups = 1 - def init_op_type(self): - self.op_type = "conv2d" + def init_kernel_type(self): + pass class TestWithPad(TestConv2dOp): @@ -241,13 +239,13 @@ class TestWithInput1x1Filter1x1(TestConv2dOp): #----------------Conv2dCUDNN---------------- class TestCUDNN(TestConv2dOp): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "conv2d" -class TestFP16CUDNN(TestCUDNN): - def init_data_type(self): +class TestFP16CUDNN(TestConv2dOp): + def init_kernel_type(self): + self.use_cudnn = True self.dtype = np.float16 def test_check_output(self): @@ -258,13 +256,13 @@ class TestFP16CUDNN(TestCUDNN): class TestCUDNNWithPad(TestWithPad): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "conv2d" -class TestFP16CUDNNWithPad(TestCUDNNWithPad): - def init_data_type(self): +class TestFP16CUDNNWithPad(TestWithPad): + def init_kernel_type(self): + self.use_cudnn = True self.dtype = np.float16 def test_check_output(self): @@ -275,13 +273,13 @@ class TestFP16CUDNNWithPad(TestCUDNNWithPad): class TestCUDNNWithStride(TestWithStride): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "conv2d" -class TestFP16CUDNNWithStride(TestCUDNNWithStride): - def init_data_type(self): +class TestFP16CUDNNWithStride(TestWithStride): + def init_kernel_type(self): + self.use_cudnn = True self.dtype = np.float16 def test_check_output(self): @@ -292,13 +290,13 @@ class TestFP16CUDNNWithStride(TestCUDNNWithStride): class TestCUDNNWithGroup(TestWithGroup): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "conv2d" -class TestFP16CUDNNWithGroup(TestCUDNNWithGroup): - def init_data_type(self): +class TestFP16CUDNNWithGroup(TestWithGroup): + def init_kernel_type(self): + self.use_cudnn = True self.dtype = np.float16 def test_check_output(self): @@ -309,13 +307,13 @@ class TestFP16CUDNNWithGroup(TestCUDNNWithGroup): class TestCUDNNWith1x1(TestWith1x1): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "conv2d" -class TestFP16CUDNNWith1x1(TestCUDNNWith1x1): - def init_data_type(self): +class TestFP16CUDNNWith1x1(TestWith1x1): + def init_kernel_type(self): + self.use_cudnn = True self.dtype = np.float16 def test_check_output(self): @@ -326,13 +324,13 @@ class TestFP16CUDNNWith1x1(TestCUDNNWith1x1): class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "conv2d" -class TestFP16CUDNNWithInput1x1Filter1x1(TestCUDNNWithInput1x1Filter1x1): - def init_data_type(self): +class TestFP16CUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): + def init_kernel_type(self): + self.use_cudnn = True self.dtype = np.float16 def test_check_output(self): @@ -375,21 +373,18 @@ class TestDepthwiseConv2(TestConv2dOp): #----------------Conv2dMKLDNN---------------- class TestMKLDNN(TestConv2dOp): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "conv2d" class TestMKLDNNWithPad(TestWithPad): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "conv2d" class TestMKLDNNWithStride(TestWithStride): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "conv2d" if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_op.py b/python/paddle/fluid/tests/unittests/test_pool2d_op.py index 964d78f1966aa10e36eeaabe943d44e002d50293..764fa575fba1615de3171e848890b3836e640849 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_op.py @@ -78,20 +78,22 @@ def avg_pool2D_forward_naive(x, class TestPool2d_Op(OpTest): def setUp(self): + self.op_type = "pool2d" self.use_cudnn = False self.use_mkldnn = False + self.dtype = np.float32 self.init_test_case() self.init_global_pool() - self.init_op_type() + self.init_kernel_type() self.init_pool_type() self.init_ceil_mode() if self.global_pool: self.paddings = [0 for _ in range(len(self.paddings))] - input = np.random.random(self.shape).astype("float32") + input = np.random.random(self.shape).astype(self.dtype) output = self.pool2D_forward_naive(input, self.ksize, self.strides, self.paddings, self.global_pool, - self.ceil_mode).astype("float32") - self.inputs = {'X': input} + self.ceil_mode).astype(self.dtype) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} self.attrs = { 'strides': self.strides, @@ -105,7 +107,7 @@ class TestPool2d_Op(OpTest): 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter } - self.outputs = {'Out': output.astype('float32')} + self.outputs = {'Out': output} def test_check_output(self): if self.use_cudnn: @@ -115,6 +117,8 @@ class TestPool2d_Op(OpTest): self.check_output() def test_check_grad(self): + if self.dtype == np.float16: + return if self.use_cudnn and self.pool_type != "max": place = core.CUDAPlace(0) self.check_grad_with_place( @@ -128,8 +132,8 @@ class TestPool2d_Op(OpTest): self.strides = [1, 1] self.paddings = [0, 0] - def init_op_type(self): - self.op_type = "pool2d" + def init_kernel_type(self): + pass def init_pool_type(self): self.pool_type = "avg" @@ -149,9 +153,6 @@ class TestCase1(TestPool2d_Op): self.strides = [1, 1] self.paddings = [0, 0] - def init_op_type(self): - self.op_type = "pool2d" - def init_pool_type(self): self.pool_type = "avg" self.pool2D_forward_naive = avg_pool2D_forward_naive @@ -167,9 +168,6 @@ class TestCase2(TestPool2d_Op): self.strides = [1, 1] self.paddings = [1, 1] - def init_op_type(self): - self.op_type = "pool2d" - def init_pool_type(self): self.pool_type = "avg" self.pool2D_forward_naive = avg_pool2D_forward_naive @@ -179,27 +177,18 @@ class TestCase2(TestPool2d_Op): class TestCase3(TestPool2d_Op): - def init_op_type(self): - self.op_type = "pool2d" - def init_pool_type(self): self.pool_type = "max" self.pool2D_forward_naive = max_pool2D_forward_naive class TestCase4(TestCase1): - def init_op_type(self): - self.op_type = "pool2d" - def init_pool_type(self): self.pool_type = "max" self.pool2D_forward_naive = max_pool2D_forward_naive class TestCase5(TestCase2): - def init_op_type(self): - self.op_type = "pool2d" - def init_pool_type(self): self.pool_type = "max" self.pool2D_forward_naive = max_pool2D_forward_naive @@ -207,39 +196,105 @@ class TestCase5(TestCase2): #--------------------test pool2d-------------------- class TestCUDNNCase1(TestPool2d_Op): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "pool2d" + + +class TestFP16CUDNNCase1(TestPool2d_Op): + def init_kernel_type(self): + self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) class TestCUDNNCase2(TestCase1): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "pool2d" + + +class TestFP16CUDNNCase2(TestCase1): + def init_kernel_type(self): + self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) class TestCUDNNCase3(TestCase2): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "pool2d" + + +class TestFP16CUDNNCase3(TestCase2): + def init_kernel_type(self): + self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) class TestCUDNNCase4(TestCase3): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "pool2d" + + +class TestFP16CUDNNCase4(TestCase3): + def init_kernel_type(self): + self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) class TestCUDNNCase5(TestCase4): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "pool2d" + + +class TestFP16CUDNNCase5(TestCase4): + def init_kernel_type(self): + self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) class TestCUDNNCase6(TestCase5): - def init_op_type(self): + def init_kernel_type(self): self.use_cudnn = True - self.op_type = "pool2d" + + +class TestFP16CUDNNCase6(TestCase5): + def init_kernel_type(self): + self.use_cudnn = True + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) class TestCeilModeCase1(TestCUDNNCase1): @@ -264,39 +319,33 @@ class TestCeilModeCase4(TestCase2): #--------------------test pool2d MKLDNN-------------------- class TestMKLDNNCase1(TestPool2d_Op): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "pool2d" class TestMKLDNNCase2(TestCase1): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "pool2d" class TestMKLDNNCase3(TestCase2): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "pool2d" class TestMKLDNNCase4(TestCase3): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "pool2d" class TestMKLDNNCase5(TestCase4): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "pool2d" class TestMKLDNNCase6(TestCase5): - def init_op_type(self): + def init_kernel_type(self): self.use_mkldnn = True - self.op_type = "pool2d" if __name__ == '__main__':