提交 3a507b44 编写于 作者: C chengduoZH

add conv3d_trans_cudnn_op

上级 39715861
......@@ -55,6 +55,18 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()
if ("${TARGET}" STREQUAL "compare_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
endif()
# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
endif()
# pool_op contains several operators
if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1)
......@@ -62,9 +74,11 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()
if ("${TARGET}" STREQUAL "compare_op")
# pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
endif()
# pool_with_index_op contains several operators
......@@ -74,25 +88,18 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif()
# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
endif()
# conv_transpose_op contains several operators
if ("${TARGET}" STREQUAL "conv_transpose_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n")
endif()
# pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op")
# conv_transpose_cudnn_op contains two operators
if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n")
endif()
# save_restore_op contains several operators
......
......@@ -48,3 +48,14 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
......@@ -237,3 +237,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册