From 01c37c8e02ecbd17cecb20195203b898775b3994 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Tue, 22 Dec 2020 13:50:46 +0800 Subject: [PATCH] refine the compiler error for half2 operation (#29816) --- paddle/fluid/operators/elementwise/elementwise_add_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index e78b0c03fc..731cef3d36 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -183,6 +183,7 @@ template __global__ void VecFP16MatrixColReduce(const __half2 *__restrict__ in, __half2 *__restrict__ out, size_t width, size_t height) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) int idx = threadIdx.x + blockIdx.x * blockDim.x; int by = blockIdx.y; __half2 zero = __half2half2(static_cast<__half>(0)); @@ -196,6 +197,7 @@ __global__ void VecFP16MatrixColReduce(const __half2 *__restrict__ in, atomicAdd(&(out[idx]), sum); } +#endif } template @@ -363,7 +365,6 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel { int max_blocks = std::max(max_physical_threads / (block_x * block_y), 1); int theory_block = (width + blocks.x - 1) / blocks.x; dim3 grids(std::min(theory_block, max_blocks)); -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) if (std::is_same::value && width < 2048 && width % 2 == 0 && height % 64 == 0) { auto &dev_ctx = @@ -381,7 +382,6 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel { width, height); return; } -#endif if (width / height < 32) { MatrixColReduce<<>>( -- GitLab