From 3a507b44bdf41f082145e8c028adfb976c8571ac Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 13 Nov 2017 17:55:08 +0800 Subject: [PATCH] add conv3d_trans_cudnn_op --- paddle/operators/CMakeLists.txt | 33 +++++++++++-------- ...cudnn_op.cc => conv_transpose_cudnn_op.cc} | 11 +++++++ ...cudnn_op.cu => conv_transpose_cudnn_op.cu} | 5 +++ 3 files changed, 36 insertions(+), 13 deletions(-) rename paddle/operators/{conv2d_transpose_cudnn_op.cc => conv_transpose_cudnn_op.cc} (82%) rename paddle/operators/{conv2d_transpose_cudnn_op.cu => conv_transpose_cudnn_op.cu} (97%) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 709f7de2e..71740b8b0 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -55,6 +55,18 @@ function(op_library TARGET) set(pybind_flag 1) endif() + if ("${TARGET}" STREQUAL "compare_op") + set(pybind_flag 1) + file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n") + endif() + + # conv_op contains several operators + if ("${TARGET}" STREQUAL "conv_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(conv2d);\n") + endif() + # pool_op contains several operators if ("${TARGET}" STREQUAL "pool_op") set(pybind_flag 1) @@ -62,9 +74,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() - if ("${TARGET}" STREQUAL "compare_op") + # pool_cudnn_op contains several operators + if ("${TARGET}" STREQUAL "pool_cudnn_op") set(pybind_flag 1) - file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n") + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") endif() # pool_with_index_op contains several operators @@ -74,25 +88,18 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") endif() - # conv_op contains several operators - if ("${TARGET}" STREQUAL "conv_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(conv2d);\n") - endif() - # conv_transpose_op contains several operators if ("${TARGET}" STREQUAL "conv_transpose_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n") endif() - - # pool_cudnn_op contains several operators - if ("${TARGET}" STREQUAL "pool_cudnn_op") + + # conv_transpose_cudnn_op contains two operators + if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") + file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n") endif() # save_restore_op contains several operators diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cc b/paddle/operators/conv_transpose_cudnn_op.cc similarity index 82% rename from paddle/operators/conv2d_transpose_cudnn_op.cc rename to paddle/operators/conv_transpose_cudnn_op.cc index fce1357ce..7ec3319cd 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cc +++ b/paddle/operators/conv_transpose_cudnn_op.cc @@ -48,3 +48,14 @@ REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL( conv2d_transpose_cudnn_grad, ops::GemmConvTransposeGradKernel); + +REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp, + ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad, + ops::ConvTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv3d_transpose_cudnn, + ops::GemmConvTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv3d_transpose_cudnn_grad, + ops::GemmConvTransposeGradKernel); diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv_transpose_cudnn_op.cu similarity index 97% rename from paddle/operators/conv2d_transpose_cudnn_op.cu rename to paddle/operators/conv_transpose_cudnn_op.cu index 694526ec0..cd31896f2 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cu +++ b/paddle/operators/conv_transpose_cudnn_op.cu @@ -237,3 +237,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn, ops::CudnnConvTransposeOpKernel); REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad, ops::CudnnConvTransposeGradOpKernel); + +REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn, + ops::CudnnConvTransposeOpKernel); +REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad, + ops::CudnnConvTransposeGradOpKernel); -- GitLab