diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc
index 71919667ba1bbaf3456b074c261dd714078681bf..9dfdcfcf13969ca030466c1ae66a28817e6947d1 100644
--- a/paddle/fluid/operators/pad2d_op.cc
+++ b/paddle/fluid/operators/pad2d_op.cc
@@ -661,5 +661,8 @@ REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
                   ops::Pad2dOpGradMaker);
 REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad,
                   ops::Pad2dOpGradNoNeedBufferVarsInference);
-REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>);
-REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel<float>);
+REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>,
+                       ops::Pad2dCPUKernel<double>, ops::Pad2dCPUKernel<int>,
+                       ops::Pad2dCPUKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel<float>,
+                       ops::Pad2dGradCPUKernel<double>);
diff --git a/paddle/fluid/operators/pad2d_op.cu b/paddle/fluid/operators/pad2d_op.cu
index 72eca08b06b144335424a669241b5754beda758d..05fad5b3bbce10ead1cef0dc3aab79a2b1437673 100644
--- a/paddle/fluid/operators/pad2d_op.cu
+++ b/paddle/fluid/operators/pad2d_op.cu
@@ -215,8 +215,9 @@ __global__ void Pad2DGradReflectNCHW(const int out_size, T* d_in_data,
     in_w = max(in_w, -in_w);
     in_h = min(in_h, 2 * in_height - in_h - 2);
     in_w = min(in_w, 2 * in_width - in_w - 2);
-    atomicAdd(&d_in_data[(nc * in_height + in_h) * in_width + in_w],
-              d_out_data[out_index]);
+    platform::CudaAtomicAdd(
+        &d_in_data[(nc * in_height + in_h) * in_width + in_w],
+        d_out_data[out_index]);
   }
 }
 
@@ -240,7 +241,7 @@ __global__ void Pad2DGradReflectNHWC(const int out_size, T* d_in_data,
     in_w = max(in_w, -in_w);
     in_h = min(in_h, in_height * 2 - in_h - 2);
     in_w = min(in_w, in_width * 2 - in_w - 2);
-    atomicAdd(
+    platform::CudaAtomicAdd(
         &d_in_data[((n * in_height + in_h) * in_width + in_w) * channels + c],
         d_out_data[out_index]);
   }
@@ -260,8 +261,9 @@ __global__ void Pad2DGradEdgeNCHW(const int out_size, T* d_in_data,
     nc /= out_height;
     const int in_h = min(in_height - 1, max(out_h - pad_top, 0));
     const int in_w = min(in_width - 1, max(out_w - pad_left, 0));
-    atomicAdd(&d_in_data[(nc * in_height + in_h) * in_width + in_w],
-              d_out_data[out_index]);
+    platform::CudaAtomicAdd(
+        &d_in_data[(nc * in_height + in_h) * in_width + in_w],
+        d_out_data[out_index]);
   }
 }
 
@@ -281,7 +283,7 @@ __global__ void Pad2DGradEdgeNHWC(const int out_size, T* d_in_data,
     n /= out_height;
     const int in_h = min(in_height - 1, max(out_h - pad_top, 0));
     const int in_w = min(in_width - 1, max(out_w - pad_left, 0));
-    atomicAdd(
+    platform::CudaAtomicAdd(
         &d_in_data[((n * in_height + in_h) * in_width + in_w) * channels + c],
         d_out_data[out_index]);
   }
@@ -459,5 +461,8 @@ class Pad2dGradCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<float>);
-REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<float>);
+REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<float>,
+                        ops::Pad2dCUDAKernel<double>, ops::Pad2dCUDAKernel<int>,
+                        ops::Pad2dCUDAKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<float>,
+                        ops::Pad2dGradCUDAKernel<double>);