diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 709f7de2e43093114d096cbfca5b5d49293a6d3e..71740b8b0c63886d4a6d2665de671664f10f1729 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -55,6 +55,18 @@ function(op_library TARGET) set(pybind_flag 1) endif() + if ("${TARGET}" STREQUAL "compare_op") + set(pybind_flag 1) + file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n") + endif() + + # conv_op contains several operators + if ("${TARGET}" STREQUAL "conv_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(conv2d);\n") + endif() + # pool_op contains several operators if ("${TARGET}" STREQUAL "pool_op") set(pybind_flag 1) @@ -62,9 +74,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() - if ("${TARGET}" STREQUAL "compare_op") + # pool_cudnn_op contains several operators + if ("${TARGET}" STREQUAL "pool_cudnn_op") set(pybind_flag 1) - file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n") + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") endif() # pool_with_index_op contains several operators @@ -74,25 +88,18 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") endif() - # conv_op contains several operators - if ("${TARGET}" STREQUAL "conv_op") - set(pybind_flag 1) - # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(conv2d);\n") - endif() - # conv_transpose_op contains several operators if ("${TARGET}" STREQUAL "conv_transpose_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n") endif() - - # pool_cudnn_op contains several operators - if ("${TARGET}" STREQUAL "pool_cudnn_op") + + # conv_transpose_cudnn_op contains two operators + if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op") set(pybind_flag 1) # It's enough to just adding one operator to pybind - file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") + file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n") endif() # save_restore_op contains several operators diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cc b/paddle/operators/conv_transpose_cudnn_op.cc similarity index 82% rename from paddle/operators/conv2d_transpose_cudnn_op.cc rename to paddle/operators/conv_transpose_cudnn_op.cc index fce1357ce5af5f11ccc5941690431393301e6725..7ec3319cd0cd41a812026aae09542d92825c1a69 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cc +++ b/paddle/operators/conv_transpose_cudnn_op.cc @@ -48,3 +48,14 @@ REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL( conv2d_transpose_cudnn_grad, ops::GemmConvTransposeGradKernel); + +REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp, + ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad, + ops::ConvTransposeOpGrad); + +REGISTER_OP_CPU_KERNEL( + conv3d_transpose_cudnn, + ops::GemmConvTransposeKernel); +REGISTER_OP_CPU_KERNEL( + conv3d_transpose_cudnn_grad, + ops::GemmConvTransposeGradKernel); diff --git a/paddle/operators/conv2d_transpose_cudnn_op.cu b/paddle/operators/conv_transpose_cudnn_op.cu similarity index 97% rename from paddle/operators/conv2d_transpose_cudnn_op.cu rename to paddle/operators/conv_transpose_cudnn_op.cu index 694526ec01214acf2ec6a3d68d3cf072739ac185..cd31896f2c06e6b7e11841eb5e8533d717e76b16 100644 --- a/paddle/operators/conv2d_transpose_cudnn_op.cu +++ b/paddle/operators/conv_transpose_cudnn_op.cu @@ -237,3 +237,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn, ops::CudnnConvTransposeOpKernel); REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad, ops::CudnnConvTransposeGradOpKernel); + +REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn, + ops::CudnnConvTransposeOpKernel); +REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad, + ops::CudnnConvTransposeGradOpKernel);