From f82384113b4c91c0642b38e4b8103566a8a8941a Mon Sep 17 00:00:00 2001
From: whs
Date: Tue, 25 Aug 2020 20:12:39 +0800
Subject: [PATCH] Fix atomicAdd in grid sample op and affine grid op (#26647)

test=develop
---
 paddle/fluid/operators/affine_grid_op.cu  | 12 ++++++------
 paddle/fluid/operators/grid_sampler_op.cu |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/affine_grid_op.cu b/paddle/fluid/operators/affine_grid_op.cu
index 66e3c741e5d..7aaaa0002c5 100644
--- a/paddle/fluid/operators/affine_grid_op.cu
+++ b/paddle/fluid/operators/affine_grid_op.cu
@@ -86,14 +86,14 @@ __global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
 
     int theta_offset = n * 6; // 2 * 3;
     T out_grad_x = out_grad[index * 2];
-    atomicAdd(theta_grad + theta_offset, out_grad_x * h_coor);
-    atomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor);
-    atomicAdd(theta_grad + theta_offset + 2, out_grad_x);
+    platform::CudaAtomicAdd(theta_grad + theta_offset, out_grad_x * h_coor);
+    platform::CudaAtomicAdd(theta_grad + theta_offset + 1, out_grad_x * w_coor);
+    platform::CudaAtomicAdd(theta_grad + theta_offset + 2, out_grad_x);
 
     T out_grad_y = out_grad[index * 2 + 1];
-    atomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor);
-    atomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
-    atomicAdd(theta_grad + theta_offset + 5, out_grad_y);
+    platform::CudaAtomicAdd(theta_grad + theta_offset + 3, out_grad_y * h_coor);
+    platform::CudaAtomicAdd(theta_grad + theta_offset + 4, out_grad_y * w_coor);
+    platform::CudaAtomicAdd(theta_grad + theta_offset + 5, out_grad_y);
   }
 }
 
diff --git a/paddle/fluid/operators/grid_sampler_op.cu b/paddle/fluid/operators/grid_sampler_op.cu
index 7e1e7b1e692..999f990448c 100644
--- a/paddle/fluid/operators/grid_sampler_op.cu
+++ b/paddle/fluid/operators/grid_sampler_op.cu
@@ -31,7 +31,7 @@ static __forceinline__ __device__ void atomic_add(T* data, int h, int w,
                                                    int sH, int sW, int H, int W,
                                                    T delta) {
   if (in_bounds(h, w, H, W)) {
-    atomicAdd(data + h * sH + w * sW, delta);
+    platform::CudaAtomicAdd(data + h * sH + w * sW, delta);
   }
 }
 
-- 
GitLab