未验证 提交 cbb956b3 编写于 作者: N niuliling123 提交者: GitHub

Fix overwrite in where_index (#44181)

上级 69a4a39f
......@@ -395,7 +395,6 @@ void SelectKernel(const KPDevice &dev_ctx,
paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();
// 1.1 get stored data num of per block
int total_true_num = 0; // init
const int kVecSize = 4;
#ifdef PADDLE_WITH_XPU_KP
int block = 64;
......@@ -424,6 +423,7 @@ void SelectKernel(const KPDevice &dev_ctx,
DenseTensor cumsum_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
CT *cumsum_data = cumsum_mem.data<CT>();
// 2.2 get prefix of count_data for real out_index
CT total_true_num = static_cast<CT>(0); // init
const int kCumVesize = 2;
const int block_c = 256;
const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c));
......@@ -448,7 +448,7 @@ void SelectKernel(const KPDevice &dev_ctx,
if (SelectData == 1) {
out->Resize(phi::make_ddim(out_dim));
} else if (SelectData == 0) { // == 0 where_index
out_dim.push_back(rank);
out_dim.push_back(static_cast<int64_t>(rank));
out->Resize(phi::make_ddim(out_dim));
}
auto out_data = out->mutable_data<OutT>(cuda_place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册