提交 e21e5646 编写于 作者: C chengduoZH

fix atomicAdd -> CudaAtomicAdd

上级 6c6474cb
...@@ -144,7 +144,7 @@ __global__ void KernelMaxPool2DGrad( ...@@ -144,7 +144,7 @@ __global__ void KernelMaxPool2DGrad(
if (maxIndex != -1) { if (maxIndex != -1) {
// atomic add // atomic add
atomicAdd(input_grad + maxIndex, output_grad[index]); platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]);
} }
} }
} }
...@@ -278,9 +278,7 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> { ...@@ -278,9 +278,7 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
}; };
template class MaxPool2dGradFunctor<platform::GPUPlace, float>; template class MaxPool2dGradFunctor<platform::GPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::GPUPlace, double>; // The template class MaxPool2dGradFunctor<platform::GPUPlace, double>;
// 64-bit floating-point version of atomicAdd() is only supported by devices of
// compute capability 6.x and higher.
template class Pool2dFunctor<platform::GPUPlace, template class Pool2dFunctor<platform::GPUPlace,
paddle::operators::math::MaxPool<float>, float>; paddle::operators::math::MaxPool<float>, float>;
...@@ -453,7 +451,7 @@ __global__ void KernelMaxPool3DGrad( ...@@ -453,7 +451,7 @@ __global__ void KernelMaxPool3DGrad(
} }
if (maxIdx != -1) { if (maxIdx != -1) {
// atomic add // atomic add
atomicAdd(input_grad + maxIdx, output_grad[index]); platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]);
} }
} }
} }
...@@ -609,9 +607,7 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> { ...@@ -609,9 +607,7 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
}; };
template class MaxPool3dGradFunctor<platform::GPUPlace, float>; template class MaxPool3dGradFunctor<platform::GPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::GPUPlace, double>; // The template class MaxPool3dGradFunctor<platform::GPUPlace, double>;
// 64-bit floating-point version of atomicAdd() is only supported by devices of
// compute capability 6.x and higher.
template class Pool3dFunctor<platform::GPUPlace, template class Pool3dFunctor<platform::GPUPlace,
paddle::operators::math::MaxPool<float>, float>; paddle::operators::math::MaxPool<float>, float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册