diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 687d741cb22a081eab18c61752200b9fd48f68a7..7a36a9b21aa6a1b415ac5a232e65eda8051c87f8 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -225,11 +225,15 @@ REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, ops::ConvOpGrad); REGISTER_OP_CPU_KERNEL(conv2d, - ops::GemmConvKernel); + ops::GemmConvKernel, + ops::GemmConvKernel); REGISTER_OP_CPU_KERNEL( - conv2d_grad, ops::GemmConvGradKernel); + conv2d_grad, ops::GemmConvGradKernel, + ops::GemmConvGradKernel); REGISTER_OP_CPU_KERNEL(conv3d, - ops::GemmConvKernel); + ops::GemmConvKernel, + ops::GemmConvKernel); REGISTER_OP_CPU_KERNEL( - conv3d_grad, ops::GemmConvGradKernel); + conv3d_grad, ops::GemmConvGradKernel, + ops::GemmConvGradKernel); diff --git a/paddle/operators/conv_op.cu.cc b/paddle/operators/conv_op.cu.cc index 8e6f9da455b7291049aee57189dae15b8bcc2150..546451234a1ed1a4d3119cb175c6d37ae3f0aac1 100644 --- a/paddle/operators/conv_op.cu.cc +++ b/paddle/operators/conv_op.cu.cc @@ -17,11 +17,15 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(conv2d, - ops::GemmConvKernel); + ops::GemmConvKernel, + ops::GemmConvKernel); REGISTER_OP_GPU_KERNEL( - conv2d_grad, ops::GemmConvGradKernel); + conv2d_grad, ops::GemmConvGradKernel, + ops::GemmConvGradKernel); REGISTER_OP_GPU_KERNEL(conv3d, - ops::GemmConvKernel); + ops::GemmConvKernel, + ops::GemmConvKernel); REGISTER_OP_GPU_KERNEL( - conv3d_grad, ops::GemmConvGradKernel); + conv3d_grad, ops::GemmConvGradKernel, + ops::GemmConvGradKernel); diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index 310e3f5c937bd1345663b2c2307610a485a027ef..3e55ef036a7fb976117054574d1347fa943acd55 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -185,17 +185,21 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker, REGISTER_OP_CPU_KERNEL( conv2d_transpose, - ops::GemmConvTransposeKernel); + ops::GemmConvTransposeKernel, + ops::GemmConvTransposeKernel); REGISTER_OP_CPU_KERNEL( conv2d_transpose_grad, - ops::GemmConvTransposeGradKernel); + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel); REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker, conv3d_transpose_grad, ops::ConvTransposeOpGrad); REGISTER_OP_CPU_KERNEL( conv3d_transpose, - ops::GemmConvTransposeKernel); + ops::GemmConvTransposeKernel, + ops::GemmConvTransposeKernel); REGISTER_OP_CPU_KERNEL( conv3d_transpose_grad, - ops::GemmConvTransposeGradKernel); + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel); diff --git a/paddle/operators/conv_transpose_op.cu.cc b/paddle/operators/conv_transpose_op.cu.cc index 401cddb379ced134b800d2a078fe130a2850fbb2..4165eb0c7b048b83bbd94c57b971530043b66545 100644 --- a/paddle/operators/conv_transpose_op.cu.cc +++ b/paddle/operators/conv_transpose_op.cu.cc @@ -18,14 +18,18 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( conv2d_transpose, - ops::GemmConvTransposeKernel); + ops::GemmConvTransposeKernel, + ops::GemmConvTransposeKernel); REGISTER_OP_GPU_KERNEL( conv2d_transpose_grad, - ops::GemmConvTransposeGradKernel); + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel); REGISTER_OP_GPU_KERNEL( conv3d_transpose, - ops::GemmConvTransposeKernel); + ops::GemmConvTransposeKernel, + ops::GemmConvTransposeKernel); REGISTER_OP_GPU_KERNEL( conv3d_transpose_grad, - ops::GemmConvTransposeGradKernel); + ops::GemmConvTransposeGradKernel, + ops::GemmConvTransposeGradKernel);