From cbb956b328d2f5be8e46d978c207623f55ea7917 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Mon, 11 Jul 2022 10:17:24 +0800 Subject: [PATCH] Fix overwrite in where_index (#44181) --- paddle/phi/kernels/funcs/select_impl.cu.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/funcs/select_impl.cu.h b/paddle/phi/kernels/funcs/select_impl.cu.h index a036f27cc2b..831e0ca907b 100644 --- a/paddle/phi/kernels/funcs/select_impl.cu.h +++ b/paddle/phi/kernels/funcs/select_impl.cu.h @@ -395,7 +395,6 @@ void SelectKernel(const KPDevice &dev_ctx, paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace(); // 1.1 get stored data num of per block - int total_true_num = 0; // init const int kVecSize = 4; #ifdef PADDLE_WITH_XPU_KP int block = 64; @@ -424,6 +423,7 @@ void SelectKernel(const KPDevice &dev_ctx, DenseTensor cumsum_mem = phi::Empty(dev_ctx, dims_array); CT *cumsum_data = cumsum_mem.data(); // 2.2 get prefix of count_data for real out_index + CT total_true_num = static_cast(0); // init const int kCumVesize = 2; const int block_c = 256; const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c)); @@ -448,7 +448,7 @@ void SelectKernel(const KPDevice &dev_ctx, if (SelectData == 1) { out->Resize(phi::make_ddim(out_dim)); } else if (SelectData == 0) { // == 0 where_index - out_dim.push_back(rank); + out_dim.push_back(static_cast(rank)); out->Resize(phi::make_ddim(out_dim)); } auto out_data = out->mutable_data(cuda_place); -- GitLab