Unverified · Commit 0615815d authored by niuliling123, committed by GitHub

Fix a bug in IndexKernel data overflow (#39891)

Parent: b56ac35c
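The patch below widens the element count and the offsets used by VectorizedIndexKernel from int to size_t. As a minimal sketch of the overflow class being fixed (a standalone host program with made-up numbers, not part of the patch), this is what happens when a count larger than INT_MAX, or an offset derived from it, is held in a 32-bit int:

// Illustration only: values involved once out->numel() exceeds INT_MAX (~2.1e9).
#include <cstddef>
#include <cstdio>
#include <limits>

int main() {
  const std::size_t numel = std::size_t{3} * 1024 * 1024 * 1024;  // 3G elements > INT_MAX
  const int vec_size = 4, block = 256;                            // hypothetical launch config

  // Held in a 32-bit int (the old type), the rounded offset wraps past INT_MAX
  // (negative on common two's-complement platforms).
  int as_int = static_cast<int>((numel / (vec_size * block)) * vec_size * block);
  // Held in size_t (the new type), it stays exact.
  std::size_t as_size_t = (numel / (vec_size * block)) * vec_size * block;

  std::printf("int: %d  size_t: %zu  INT_MAX: %d\n",
              as_int, as_size_t, std::numeric_limits<int>::max());
  return 0;
}

A wrapped, possibly negative bound of this kind is what breaks the vectorized main loop and the tail check for tensors this large; moving to size_t keeps the arithmetic exact.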
@@ -31,24 +31,24 @@ namespace operators {
 namespace kps = phi::kps;
 template <typename T, typename Functor, int VecSize>
-__global__ void VectorizedIndexKernel(T *out, int numel, int main_offset,
+__global__ void VectorizedIndexKernel(T *out, size_t numel, size_t main_offset,
                                       Functor func) {
-  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
-  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
-  int args[VecSize];
+  size_t data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
+  size_t stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+  size_t args[VecSize];
   T result[VecSize];
   for (; data_offset < main_offset; data_offset += stride) {
-    kps::InitWithDataIndex<int, VecSize, 1, 1>(&args[0], data_offset);
-    kps::ElementwiseUnary<int, T, VecSize, 1, 1, Functor>(&result[0], &args[0],
-                                                          func);
+    kps::InitWithDataIndex<size_t, VecSize, 1, 1>(&args[0], data_offset);
+    kps::ElementwiseUnary<size_t, T, VecSize, 1, 1, Functor>(&result[0],
+                                                             &args[0], func);
     kps::WriteData<T, VecSize, 1, 1, false>(out + data_offset, &result[0],
                                             BLOCK_NUM_X * VecSize);
   }
-  int num = numel - data_offset;
+  size_t num = numel - data_offset;
   if (num > 0) {
-    kps::InitWithDataIndex<int, VecSize, 1, 1>(&args[0], data_offset);
-    kps::ElementwiseUnary<int, T, VecSize, 1, 1, Functor>(&result[0], &args[0],
-                                                          func);
+    kps::InitWithDataIndex<size_t, VecSize, 1, 1>(&args[0], data_offset);
+    kps::ElementwiseUnary<size_t, T, VecSize, 1, 1, Functor>(&result[0],
+                                                             &args[0], func);
     kps::WriteData<T, VecSize, 1, 1, true>(out + data_offset, &result[0], num);
   }
 }
@@ -58,7 +58,7 @@ void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) {
   int numel = out->numel();
   T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
   if (numel <= 0) return;
-  int vec_size = paddle::platform::GetVectorizedSize((out->data<T>()));
+  int vec_size = paddle::platform::GetVectorizedSize(out_data);
 #ifdef PADDLE_WITH_XPU_KP
   int block = 64;
   int grid = 8;
@@ -70,8 +70,7 @@ void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) {
   int block = config.thread_per_block.x;
   auto stream = dev_ctx.stream();
 #endif
-  int main_offset = (numel / (vec_size * block)) * vec_size * block;
+  size_t main_offset = (numel / (vec_size * block)) * vec_size * block;
   switch (vec_size) {
     case 4:
       VectorizedIndexKernel<T, Functor, 4><<<grid, block, 0, stream>>>(
......
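For context on the surrounding (unchanged) launch logic: main_offset rounds numel down to a multiple of vec_size * block, the grid-stride loop writes those elements as full vectors, and the remainder is handled by the boundary-checked WriteData call guarded by num > 0. A small worked example of that split, using hypothetical numbers not taken from the patch:

// Hypothetical numbers, illustrating only the main_offset / tail split.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t numel = 1000;      // total elements to fill
  const int vec_size = 4, block = 64;  // 256 elements per fully vectorized step

  std::size_t main_offset = (numel / (vec_size * block)) * vec_size * block;  // 768
  std::size_t tail = numel - main_offset;                                     // 232

  // Elements [0, 768) take the unchecked WriteData path (full vectors);
  // the remaining 232 take the boundary-checked path.
  std::printf("main_offset=%zu tail=%zu\n", main_offset, tail);
  return 0;
}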