From e21e5646a574b9e2fa299bacb3a8ee85472e84b5 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 10 Oct 2017 13:55:27 +0800 Subject: [PATCH] fix atomicAdd -> CudaAtomicAdd --- paddle/operators/math/pooling.cu | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 06263737a98..4d50121de42 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -144,7 +144,7 @@ __global__ void KernelMaxPool2DGrad( if (maxIndex != -1) { // atomic add - atomicAdd(input_grad + maxIndex, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]); } } } @@ -278,9 +278,7 @@ class MaxPool2dGradFunctor { }; template class MaxPool2dGradFunctor; -// template class MaxPool2dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool2dGradFunctor; template class Pool2dFunctor, float>; @@ -453,7 +451,7 @@ __global__ void KernelMaxPool3DGrad( } if (maxIdx != -1) { // atomic add - atomicAdd(input_grad + maxIdx, output_grad[index]); + platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]); } } } @@ -609,9 +607,7 @@ class MaxPool3dGradFunctor { }; template class MaxPool3dGradFunctor; -// template class MaxPool3dGradFunctor; // The -// 64-bit floating-point version of atomicAdd() is only supported by devices of -// compute capability 6.x and higher. +template class MaxPool3dGradFunctor; template class Pool3dFunctor, float>; -- GitLab