diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index a719da2560291dbc7e98aadfae41d4692d8afcad..46c2833030c936119e98adcdd338245bbdaddce7 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -61,6 +61,18 @@ function(op_library TARGET)
     set(pybind_flag 1)
   endif()
 
+  if ("${TARGET}" STREQUAL "compare_op")
+    set(pybind_flag 1)
+    file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
+  endif()
+
+  # conv_op contains several operators
+  if ("${TARGET}" STREQUAL "conv_op")
+    set(pybind_flag 1)
+    # It's enough to just adding one operator to pybind
+    file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
+  endif()
+
   # pool_op contains several operators
   if ("${TARGET}" STREQUAL "pool_op")
     set(pybind_flag 1)
@@ -68,9 +80,11 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
   endif()
 
-  if ("${TARGET}" STREQUAL "compare_op")
+  # pool_cudnn_op contains several operators
+  if ("${TARGET}" STREQUAL "pool_cudnn_op")
     set(pybind_flag 1)
-    file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
+    # It's enough to just adding one operator to pybind
+    file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
   endif()
 
   # pool_with_index_op contains several operators
@@ -80,25 +94,18 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
   endif()
 
-  # conv_op contains several operators
-  if ("${TARGET}" STREQUAL "conv_op")
-    set(pybind_flag 1)
-    # It's enough to just adding one operator to pybind
-    file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
-  endif()
-
   # conv_transpose_op contains several operators
   if ("${TARGET}" STREQUAL "conv_transpose_op")
     set(pybind_flag 1)
     # It's enough to just adding one operator to pybind
     file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n")
   endif()
-
-  # pool_cudnn_op contains several operators
-  if ("${TARGET}" STREQUAL "pool_cudnn_op")
+
+  # conv_transpose_cudnn_op contains two operators
+  if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op")
     set(pybind_flag 1)
     # It's enough to just adding one operator to pybind
-    file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
+    file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n")
   endif()
 
   # save_restore_op contains several operators
diff --git a/paddle/operators/conv_cudnn_op.cu.cc b/paddle/operators/conv_cudnn_op.cu.cc
index 2aec4a2760260623c4c7054c590afa8e1c6c3fea..4900f7b086c869b496c492743c71ab7047c5f672 100644
--- a/paddle/operators/conv_cudnn_op.cu.cc
+++ b/paddle/operators/conv_cudnn_op.cu.cc
@@ -226,9 +226,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     T alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
-      auto t = framework::EigenVector<T>::Flatten(*input_grad);
-      t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
-          t.constant(static_cast<T>(0));
+      // Because beta is zero, it is unnecessary to reset input_grad.
+
       for (int i = 0; i < groups; i++) {
         PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
             handle, &alpha, cudnn_filter_desc,
@@ -241,9 +240,8 @@
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
-      auto t = framework::EigenVector<T>::Flatten(*filter_grad);
-      t.device(ctx.GetEigenDevice<platform::GPUPlace>()) =
-          t.constant(static_cast<T>(0));
+      // Because beta is zero, it is unnecessary to reset filter_grad.
+
       for (int i = 0; i < groups; i++) {
         PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
             handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
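The "beta is zero" reasoning behind the removed zero-fills can be made concrete: cuDNN's backward calls blend their result into the destination buffer as `dst = alpha * result + beta * dst`, so with `beta == 0` the previous contents of `input_grad`/`filter_grad` are never read and the Eigen zero-reset was redundant work. A minimal sketch of that accumulation convention (plain Python, not Paddle or cuDNN code):

```python
def blend(alpha, beta, result, dst):
    """Mimic cuDNN's accumulation rule: dst = alpha * result + beta * dst."""
    return [alpha * r + beta * d for r, d in zip(result, dst)]

# With beta == 0 the old contents of dst are irrelevant, even if uninitialized:
assert blend(1.0, 0.0, [1.0, 2.0], [9.0, 9.0]) == [1.0, 2.0]
# Only a non-zero beta would let the stale buffer leak into the gradient:
assert blend(1.0, 1.0, [1.0, 2.0], [9.0, 9.0]) == [10.0, 11.0]
```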
diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
index 687d741cb22a081eab18c61752200b9fd48f68a7..7a36a9b21aa6a1b415ac5a232e65eda8051c87f8 100644
--- a/paddle/operators/conv_op.cc
+++ b/paddle/operators/conv_op.cc
@@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
             ops::ConvOpGrad);
 
 REGISTER_OP_CPU_KERNEL(conv2d,
-                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(
-    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
+    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
+    ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
 
 REGISTER_OP_CPU_KERNEL(conv3d,
-                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
+                       ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(
-    conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
+    conv3d_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
+    ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/operators/conv_op.cu.cc b/paddle/operators/conv_op.cu.cc
index 8e6f9da455b7291049aee57189dae15b8bcc2150..546451234a1ed1a4d3119cb175c6d37ae3f0aac1 100644
--- a/paddle/operators/conv_op.cu.cc
+++ b/paddle/operators/conv_op.cu.cc
@@ -17,11 +17,15 @@
 namespace ops = paddle::operators;
 
 REGISTER_OP_GPU_KERNEL(conv2d,
-                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
+                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
+                       ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(
-    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
+    conv2d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
+    ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
 
 REGISTER_OP_GPU_KERNEL(conv3d,
-                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>);
+                       ops::GemmConvKernel<paddle::platform::GPUPlace, float>,
+                       ops::GemmConvKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(
-    conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>);
+    conv3d_grad, ops::GemmConvGradKernel<paddle::platform::GPUPlace, float>,
+    ops::GemmConvGradKernel<paddle::platform::GPUPlace, double>);
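The registrations above now list one GemmConv kernel instantiation per element type, and the framework dispatches to whichever instantiation matches the input tensor's dtype, which is what makes float64 conv2d/conv3d usable. A rough Python sketch of that dispatch idea (hypothetical `register_kernel`/`run` helpers, not Paddle's real registry):

```python
kernels = {}  # (op_type, dtype) -> callable

def register_kernel(op_type, dtype, fn):
    kernels[(op_type, dtype)] = fn

def run(op_type, dtype, *args):
    # Lookup is keyed on the tensor's dtype, mirroring how adding the
    # <CPUPlace, double> registration enables the double-precision path.
    return kernels[(op_type, dtype)](*args)

register_kernel("conv2d", "float32", lambda x: ("fp32 conv", x))
register_kernel("conv2d", "float64", lambda x: ("fp64 conv", x))

print(run("conv2d", "float64", [1.0, 2.0]))  # dispatches to the double kernel
```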
diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cc b/paddle/operators/conv_transpose_cudnn_op.cc
similarity index 61%
rename from paddle/operators/conv2d_transpose_cudnn_op.cc
rename to paddle/operators/conv_transpose_cudnn_op.cc
index fce1357ce5af5f11ccc5941690431393301e6725..dbd1bc3c3bc2d026f13ddcf62919db6cf7d87bc5 100644
--- a/paddle/operators/conv2d_transpose_cudnn_op.cc
+++ b/paddle/operators/conv_transpose_cudnn_op.cc
@@ -23,7 +23,24 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
                               framework::OpAttrChecker* op_checker)
       : Conv2DTransposeOpMaker(proto, op_checker) {
     AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
-        .SetDefault(std::vector<int>{1, 1});
+        .SetDefault({1, 1});
+    AddAttr<int>("workspace_size_MB",
+                 "workspace size for cudnn, in MB, "
+                 "workspace is a section of GPU memory which will be "
+                 "allocated/freed each time the operator runs, larger "
+                 "workspace size can increase performance but also requires "
+                 "better hardward. This size should be carefully setted.")
+        .SetDefault(4096);
+  }
+};
+
+class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
+ public:
+  CudnnConv3DTransposeOpMaker(framework::OpProto* proto,
+                              framework::OpAttrChecker* op_checker)
+      : Conv3DTransposeOpMaker(proto, op_checker) {
+    AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
+        .SetDefault({1, 1, 1});
     AddAttr<int>("workspace_size_MB",
                  "workspace size for cudnn, in MB, "
                  "workspace is a section of GPU memory which will be "
                  "allocated/freed each time the operator runs, larger "
                  "workspace size can increase performance but also requires "
@@ -48,3 +65,14 @@ REGISTER_OP_CPU_KERNEL(
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose_cudnn_grad,
     ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
+
+REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
+            ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
+            ops::ConvTransposeOpGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    conv3d_transpose_cudnn,
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    conv3d_transpose_cudnn_grad,
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu.cc b/paddle/operators/conv_transpose_cudnn_op.cu.cc
similarity index 92%
rename from paddle/operators/conv2d_transpose_cudnn_op.cu.cc
rename to paddle/operators/conv_transpose_cudnn_op.cu.cc
index eff058afc6cc5dacf2a054a33f352824865c1924..e2ba77086e737a07471f14e483cbd32ab1d4ee12 100644
--- a/paddle/operators/conv2d_transpose_cudnn_op.cu.cc
+++ b/paddle/operators/conv_transpose_cudnn_op.cu.cc
@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
     ScopedTensorDescriptor output_desc;
     ScopedFilterDescriptor filter_desc;
     ScopedConvolutionDescriptor conv_desc;
-    DataLayout layout = DataLayout::kNCHW;
+    DataLayout layout;
+
+    if (strides.size() == 2U) {
+      layout = DataLayout::kNCHW;
+    } else {
+      layout = DataLayout::kNCDHW;
+    }
 
-    // N, M, H, W
+    // (N, M, H, W) or (N, M, D, H, W)
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()));
-    // N, C, O_h, O_w
+    // (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
         layout, framework::vectorize2int(output->dims()));
-    // M, C, K_h, K_w
+    // (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize2int(filter->dims()));
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    // Input: (N, M, H, W)
+    // Input: (N, M, H, W) or (N, M, D, H, W)
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize2int(input->dims()));
-    // Output: (N, C, O_H, O_W)
+    // Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
         layout, framework::vectorize2int(output_grad->dims()));
-    // Filter (M, C, K_H, K_W)
+    // Filter (M, C, K_h, K_w) or (M, C, K_d K_h, K_w)
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize2int(filter->dims()));
 
@@ -200,8 +206,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     T alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
-      math::set_constant(ctx.device_context(), input_grad, 0);
-
+      // Because beta is zero, it is unnecessary to reset input_grad.
       PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
           handle, &alpha, cudnn_output_desc, output_grad_data,
           cudnn_filter_desc, filter_data, cudnn_conv_desc, data_algo,
@@ -212,8 +217,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
     // ------------------- cudnn conv backward filter ---------------------
     if (filter_grad) {
       T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
-      math::set_constant(ctx.device_context(), filter_grad, 0);
-
+      // Because beta is zero, it is unnecessary to reset filter_grad.
       // Gradient with respect to the filter
       PADDLE_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
           handle, &alpha, cudnn_output_desc, output_grad_data, cudnn_input_desc,
@@ -234,3 +238,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
                        ops::CudnnConvTransposeOpKernel<float>);
 REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
                        ops::CudnnConvTransposeGradOpKernel<float>);
+
+REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
+                       ops::CudnnConvTransposeOpKernel<float>);
+REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
+                       ops::CudnnConvTransposeGradOpKernel<float>);
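In the forward kernel above, the tensor layout is now derived from the rank of the strides attribute: two strides imply a 4-D NCHW tensor (conv2d_transpose), anything else implies a 5-D NCDHW tensor (conv3d_transpose). A small sketch of that rule, mirroring the `strides.size() == 2U` check (hypothetical helper, illustrative only):

```python
def choose_layout(strides):
    # Two spatial strides -> 2-D transpose conv (NCHW); otherwise 3-D (NCDHW).
    return "NCHW" if len(strides) == 2 else "NCDHW"

assert choose_layout([1, 1]) == "NCHW"      # 4-D input:  (N, M, H, W)
assert choose_layout([1, 1, 1]) == "NCDHW"  # 5-D input:  (N, M, D, H, W)
```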
diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc
index 310e3f5c937bd1345663b2c2307610a485a027ef..3e55ef036a7fb976117054574d1347fa943acd55 100644
--- a/paddle/operators/conv_transpose_op.cc
+++ b/paddle/operators/conv_transpose_op.cc
@@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
 
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
 
 REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
             conv3d_transpose_grad, ops::ConvTransposeOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
     conv3d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(
     conv3d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/operators/conv_transpose_op.cu.cc b/paddle/operators/conv_transpose_op.cu.cc
index 401cddb379ced134b800d2a078fe130a2850fbb2..4165eb0c7b048b83bbd94c57b971530043b66545 100644
--- a/paddle/operators/conv_transpose_op.cu.cc
+++ b/paddle/operators/conv_transpose_op.cu.cc
@@ -18,14 +18,18 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_GPU_KERNEL(
     conv2d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
+    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(
     conv2d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
+    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
 
 REGISTER_OP_GPU_KERNEL(
     conv3d_transpose,
-    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>,
+    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(
     conv3d_transpose_grad,
-    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>,
+    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, double>);
diff --git a/paddle/operators/pool_cudnn_op.cu.cc b/paddle/operators/pool_cudnn_op.cu.cc
index 8711567b95fea355396173b5312d26d31f9ffb12..f9d8af3e1c5db49873979fdfeb17a32d16341a1a 100644
--- a/paddle/operators/pool_cudnn_op.cu.cc
+++ b/paddle/operators/pool_cudnn_op.cu.cc
@@ -135,8 +135,7 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
 
     if (input_grad) {
       T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
-      math::SetConstant<platform::GPUPlace, T> set_zero;
-      set_zero(ctx.device_context(), input_grad, static_cast<T>(0));
+      // Because beta is zero, it is unnecessary to reset input_grad.
 
       PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
           handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
diff --git a/paddle/platform/cudnn_helper.h b/paddle/platform/cudnn_helper.h
index ce3421a3cb840e4c1e872eea12dedc1150c85962..dd48605b9ed688e4656d4cd1ddf1f298d0a50a9e 100644
--- a/paddle/platform/cudnn_helper.h
+++ b/paddle/platform/cudnn_helper.h
@@ -63,9 +63,10 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
     }                                       \
   } while (false)
 
-enum class DataLayout {
+enum class DataLayout {  // Not use
   kNHWC,
   kNCHW,
+  kNCDHW,
   kNCHW_VECT_C,
 };
 
@@ -107,12 +108,15 @@ class CudnnDataType {
   }
 };
 
-inline cudnnTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
+inline cudnnTensorFormat_t GetCudnnTensorFormat(
+    const DataLayout& order) {  // Not use
   switch (order) {
     case DataLayout::kNHWC:
       return CUDNN_TENSOR_NHWC;
     case DataLayout::kNCHW:
       return CUDNN_TENSOR_NCHW;
+    case DataLayout::kNCDHW:
+      return CUDNN_TENSOR_NCHW;  // TODO(chengduoZH) : add CUDNN_TENSOR_NCDHW
     default:
       PADDLE_THROW("Unknown cudnn equivalent for order");
   }
@@ -139,7 +143,7 @@ class ScopedTensorDescriptor {
       strides[i] = dims[i + 1] * strides[i + 1];
     }
     // Update tensor descriptor dims setting if groups > 1
-    // FIXME(typhoonzero): Assume using NCHW order
+    // FIXME(typhoonzero): Assume using NCHW or NCDHW order
     std::vector<int> dims_with_group(dims.begin(), dims.end());  // copy
     if (groups > 1) {
       dims_with_group[1] = dims_with_group[1] / groups;
@@ -176,9 +180,10 @@ class ScopedFilterDescriptor {
                                               const cudnnDataType_t type,
                                               const std::vector<int>& kernel,
                                               const int groups = 1) {
-    // filter layout: MCHW, where M is the number of
+    // filter layout: MCHW(MCDHW), where M is the number of
     // output image channels, C is the number of input image channels,
-    // H and W is height and width of filter.
+    // D is the depth of the filter, H is the height of the filter, and W is the
+    // width of the filter.
     std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
     if (groups > 1) {
       // M /= groups
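The `ScopedTensorDescriptor` hunk only relaxes the comment, but the stride computation it refers to is what lets the same descriptor code serve both NCHW and NCDHW tensors: strides are derived from the dims as `strides[i] = dims[i + 1] * strides[i + 1]`, with the innermost stride equal to 1. A quick worked sketch of that dense-layout computation (Python, illustrative only):

```python
def packed_strides(dims):
    # strides[i] = dims[i + 1] * strides[i + 1], innermost stride is 1,
    # i.e. the dense element strides cuDNN expects for NCHW / NCDHW tensors.
    strides = [1] * len(dims)
    for i in range(len(dims) - 2, -1, -1):
        strides[i] = dims[i + 1] * strides[i + 1]
    return strides

# A 5-D NCDHW tensor with N=2, C=3, D=4, H=5, W=6:
assert packed_strides([2, 3, 4, 5, 6]) == [360, 120, 30, 6, 1]
```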
diff --git a/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py b/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py
index 59a32c40821f2109306e898a6a798fea52b1e0ca..8fd34b87bfea91307f52fdcbb9f71f2e1a9c6c56 100644
--- a/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py
+++ b/python/paddle/v2/fluid/tests/test_conv3d_transpose_op.py
@@ -108,5 +108,11 @@ class TestWithStride(TestConv3dTransposeOp):
         self.filter_size = [f_c, 6, 3, 3, 3]
 
 
+# ------------ test_cudnn ------------
+class TestCudnn(TestConv3dTransposeOp):
+    def init_op_type(self):
+        self.op_type = "conv3d_transpose_cudnn"
+
+
 if __name__ == '__main__':
     unittest.main()
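The new `TestCudnn` case follows the pattern already used by the 2-D transpose tests: subclass the existing test and override only the hook that sets `op_type`, so every inherited check re-runs against `conv3d_transpose_cudnn`. A stripped-down sketch of that pattern (generic `unittest`, not Paddle's `OpTest`):

```python
import unittest


class TestBaseOp(unittest.TestCase):
    def setUp(self):
        self.init_op_type()

    def init_op_type(self):
        self.op_type = "conv3d_transpose"

    def test_op_type_is_set(self):
        # Every subclass re-runs this check with its own op_type.
        self.assertTrue(self.op_type.startswith("conv3d_transpose"))


class TestCudnnVariant(TestBaseOp):
    def init_op_type(self):
        self.op_type = "conv3d_transpose_cudnn"


if __name__ == "__main__":
    unittest.main()
```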