未验证 提交 3e9ad093 编写于 作者: F FlyingQianMM 提交者: GitHub

fix index_select kernel configuration error where input numel is 0 (#41383)

上级 1888d874
...@@ -85,6 +85,9 @@ void IndexSelectGradKernel(const Context& ctx, ...@@ -85,6 +85,9 @@ void IndexSelectGradKernel(const Context& ctx,
phi::DataType::INT64)); phi::DataType::INT64));
int64_t numel = x_grad->numel(); int64_t numel = x_grad->numel();
if (numel == 0) {
return;
}
int64_t index_nums = index.numel(); int64_t index_nums = index.numel();
int64_t out_nums = out_grad.numel(); int64_t out_nums = out_grad.numel();
......
...@@ -72,6 +72,9 @@ void IndexSelectKernel(const Context& ctx, ...@@ -72,6 +72,9 @@ void IndexSelectKernel(const Context& ctx,
T* out_data = ctx.template Alloc<T>(output); T* out_data = ctx.template Alloc<T>(output);
int64_t numel = output->numel(); int64_t numel = output->numel();
if (numel == 0) {
return;
}
auto stream = ctx.stream(); auto stream = ctx.stream();
unsigned int block_dim = PADDLE_CUDA_NUM_THREADS; unsigned int block_dim = PADDLE_CUDA_NUM_THREADS;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册