未验证 提交 667082c0 编写于 作者: C carryyu 提交者: GitHub

fix P40 topk: Make the optimized topk compatible with P40. (#46547)

* fix P40 topk: Make the optimized topk compatible with P40.

* fix P40 topk: Make the optimized topk compatible with P40.

* fix P40 topk: Make the optimized topk compatible with P40.
上级 d71f1b3f
...@@ -354,18 +354,13 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[], ...@@ -354,18 +354,13 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
} }
if (--(*k) == 0) break; if (--(*k) == 0) break;
if (MaxLength < 5) {
if (*beam >= MaxLength) break;
} else {
unsigned mask = 0u; unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true); CREATE_SHFL_MASK(mask, true);
if (tid_max / 32 == wid) { if (tid_max / 32 == wid) {
if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == if (platform::CudaShuffleSync(mask, *beam, tid_max % 32, 32) == MaxLength)
MaxLength)
break; break;
} }
} }
}
} }
/** /**
......
...@@ -127,7 +127,8 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -127,7 +127,8 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
input_height)); input_height));
default: default:
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"the input k has error in the topk cuda kernel.")); "the input k has error when use getMaxLength function to get the "
"maxLength."));
}); });
default: default:
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
......
...@@ -205,7 +205,8 @@ void TopkKernel(const Context& dev_ctx, ...@@ -205,7 +205,8 @@ void TopkKernel(const Context& dev_ctx,
largest)); largest));
default: default:
PADDLE_THROW( PADDLE_THROW(
errors::Fatal("the input k has error in the topk cuda kernel.")); errors::Fatal("the input k has error when use getMaxLength "
"function to get the maxLength."));
}); });
#endif #endif
default: default:
...@@ -313,7 +314,8 @@ void TopkKernel(const Context& dev_ctx, ...@@ -313,7 +314,8 @@ void TopkKernel(const Context& dev_ctx,
largest)); largest));
default: default:
PADDLE_THROW( PADDLE_THROW(
errors::Fatal("the input k has error in the topk cuda kernel.")); errors::Fatal("the input k has error when use getMaxLength "
"function to get the maxLength."));
}); });
#endif #endif
default: default:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册