未验证 提交 667082c0 编写于 作者: C carryyu 提交者: GitHub

fix P40 topk: Make the optimized topk compatible with P40. (#46547)

* fix P40 topk: Make the optimized topk compatible with P40.

* fix P40 topk: Make the optimized topk compatible with P40.

* fix P40 topk: Make the optimized topk compatible with P40.
上级 d71f1b3f
...@@ -354,16 +354,11 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[], ...@@ -354,16 +354,11 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
} }
if (--(*k) == 0) break; if (--(*k) == 0) break;
if (MaxLength < 5) { unsigned mask = 0u;
if (*beam >= MaxLength) break; CREATE_SHFL_MASK(mask, true);
} else { if (tid_max / 32 == wid) {
unsigned mask = 0u; if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == MaxLength)
CREATE_SHFL_MASK(mask, true); break;
if (tid_max / 32 == wid) {
if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) ==
MaxLength)
break;
}
} }
} }
} }
......
...@@ -127,7 +127,8 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -127,7 +127,8 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
input_height)); input_height));
default: default:
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"the input k has error in the topk cuda kernel.")); "the input k has error when use getMaxLength function to get the "
"maxLength."));
}); });
default: default:
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
......
...@@ -205,7 +205,8 @@ void TopkKernel(const Context& dev_ctx, ...@@ -205,7 +205,8 @@ void TopkKernel(const Context& dev_ctx,
largest)); largest));
default: default:
PADDLE_THROW( PADDLE_THROW(
errors::Fatal("the input k has error in the topk cuda kernel.")); errors::Fatal("the input k has error when use getMaxLength "
"function to get the maxLength."));
}); });
#endif #endif
default: default:
...@@ -313,7 +314,8 @@ void TopkKernel(const Context& dev_ctx, ...@@ -313,7 +314,8 @@ void TopkKernel(const Context& dev_ctx,
largest)); largest));
default: default:
PADDLE_THROW( PADDLE_THROW(
errors::Fatal("the input k has error in the topk cuda kernel.")); errors::Fatal("the input k has error when use getMaxLength "
"function to get the maxLength."));
}); });
#endif #endif
default: default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册