提交 3a507b44 编写于 作者: C chengduoZH

add conv3d_trans_cudnn_op

上级 39715861
...@@ -55,6 +55,18 @@ function(op_library TARGET) ...@@ -55,6 +55,18 @@ function(op_library TARGET)
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()
if ("${TARGET}" STREQUAL "compare_op")
set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n")
endif()
# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
endif()
# pool_op contains several operators # pool_op contains several operators
if ("${TARGET}" STREQUAL "pool_op") if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1) set(pybind_flag 1)
...@@ -62,9 +74,11 @@ function(op_library TARGET) ...@@ -62,9 +74,11 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(pool2d);\n") file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif() endif()
if ("${TARGET}" STREQUAL "compare_op") # pool_cudnn_op contains several operators
if ("${TARGET}" STREQUAL "pool_cudnn_op")
set(pybind_flag 1) set(pybind_flag 1)
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n") # It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
endif() endif()
# pool_with_index_op contains several operators # pool_with_index_op contains several operators
...@@ -74,25 +88,18 @@ function(op_library TARGET) ...@@ -74,25 +88,18 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif() endif()
# conv_op contains several operators
if ("${TARGET}" STREQUAL "conv_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d);\n")
endif()
# conv_transpose_op contains several operators # conv_transpose_op contains several operators
if ("${TARGET}" STREQUAL "conv_transpose_op") if ("${TARGET}" STREQUAL "conv_transpose_op")
set(pybind_flag 1) set(pybind_flag 1)
# It's enough to just adding one operator to pybind # It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n") file(APPEND ${pybind_file} "USE_OP(conv2d_transpose);\n")
endif() endif()
# pool_cudnn_op contains several operators # conv_transpose_cudnn_op contains two operators
if ("${TARGET}" STREQUAL "pool_cudnn_op") if ("${TARGET}" STREQUAL "conv_transpose_cudnn_op")
set(pybind_flag 1) set(pybind_flag 1)
# It's enough to just adding one operator to pybind # It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") file(APPEND ${pybind_file} "USE_OP(conv2d_transpose_cudnn);\n")
endif() endif()
# save_restore_op contains several operators # save_restore_op contains several operators
......
...@@ -48,3 +48,14 @@ REGISTER_OP_CPU_KERNEL( ...@@ -48,3 +48,14 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad, conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>); ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
...@@ -237,3 +237,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn, ...@@ -237,3 +237,8 @@ REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>); ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad, REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>); ops::CudnnConvTransposeGradOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>);
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册