未验证 提交 cbb956b3 作者: niuliling123 提交者: GitHub

Fix overwrite in where_index (#44181)

上级 69a4a39f
...@@ -395,7 +395,6 @@ void SelectKernel(const KPDevice &dev_ctx, ...@@ -395,7 +395,6 @@ void SelectKernel(const KPDevice &dev_ctx,
paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace(); paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
// 1.1 get stored data num of per block // 1.1 get stored data num of per block
int total_true_num = 0; // init
const int kVecSize = 4; const int kVecSize = 4;
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
int block = 64; int block = 64;
...@@ -424,6 +423,7 @@ void SelectKernel(const KPDevice &dev_ctx, ...@@ -424,6 +423,7 @@ void SelectKernel(const KPDevice &dev_ctx,
DenseTensor cumsum_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array); DenseTensor cumsum_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
CT *cumsum_data = cumsum_mem.data<CT>(); CT *cumsum_data = cumsum_mem.data<CT>();
// 2.2 get prefix of count_data for real out_index // 2.2 get prefix of count_data for real out_index
CT total_true_num = static_cast<CT>(0); // init
const int kCumVesize = 2; const int kCumVesize = 2;
const int block_c = 256; const int block_c = 256;
const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c)); const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c));
...@@ -448,7 +448,7 @@ void SelectKernel(const KPDevice &dev_ctx, ...@@ -448,7 +448,7 @@ void SelectKernel(const KPDevice &dev_ctx,
if (SelectData == 1) { if (SelectData == 1) {
out->Resize(phi::make_ddim(out_dim)); out->Resize(phi::make_ddim(out_dim));
} else if (SelectData == 0) { // == 0 where_index } else if (SelectData == 0) { // == 0 where_index
out_dim.push_back(rank); out_dim.push_back(static_cast<int64_t>(rank));
out->Resize(phi::make_ddim(out_dim)); out->Resize(phi::make_ddim(out_dim));
} }
auto out_data = out->mutable_data<OutT>(cuda_place); auto out_data = out->mutable_data<OutT>(cuda_place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册