diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h
index d228174846a1ec182615ff1b6c67493a69541266..09da00d7cca14771d66d78d40a7f23ac3168eaf9 100644
--- a/paddle/phi/kernels/funcs/index_put_utils.h
+++ b/paddle/phi/kernels/funcs/index_put_utils.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <vector>
+#include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/common/place.h"
@@ -106,7 +107,14 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
       SplitWithNumKernel<int64_t, Context>(
           dev_ctx, nonzero_indices, rank, 1, integer_indices);
 #ifdef PADDLE_WITH_XPU
-      dev_ctx.Wait();
+      auto place = dev_ctx.GetPlace();
+      if (place.GetType() == phi::AllocationType::XPU) {
+        auto& pool = phi::DeviceContextPool::Instance();
+        auto* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
+        if (xpu_ctx->x_context()->xpu_stream) {
+          dev_ctx.Wait();
+        }
+      }
 #endif
 
     } else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
diff --git a/paddle/phi/kernels/xpu/index_put_kernel.cc b/paddle/phi/kernels/xpu/index_put_kernel.cc
index f059da4bb4574664f4de2b912a272b28e85067b5..4197b9698cb3c1573a6590120582043dd3de9c2a 100644
--- a/paddle/phi/kernels/xpu/index_put_kernel.cc
+++ b/paddle/phi/kernels/xpu/index_put_kernel.cc
@@ -65,7 +65,9 @@ void XPUDealWithIndices(const Context& dev_ctx,
   }
 
   StackKernel<int64_t, Context>(dev_ctx, tmp_indices_ptr, -1, out);
-  dev_ctx.Wait();
+  if (dev_ctx.x_context()->xpu_stream) {
+    dev_ctx.Wait();
+  }
 }
 
 template <typename T, typename Context>
@@ -140,7 +142,9 @@ void IndexPutKernel(const Context& dev_ctx,
                     index_shape,
                     accumulate);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put");
-  dev_ctx.Wait();
+  if (dev_ctx.x_context()->xpu_stream) {
+    dev_ctx.Wait();
+  }
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/unique_kernel.cc b/paddle/phi/kernels/xpu/unique_kernel.cc
index 18ad41b14e88958b3021cf8d9c24bec24e0a1e24..6f2d8f470a2120f8b6f95c6c3842c365fc094067 100644
--- a/paddle/phi/kernels/xpu/unique_kernel.cc
+++ b/paddle/phi/kernels/xpu/unique_kernel.cc
@@ -228,26 +228,37 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
   inverse_cpu[ori_idx_cpu[0]] = 0;
   IndexT unique_len = 1;
   IndexT repeat_cnt = 1;
-  for (IndexT i = 1; i < axis_len; ++i) {
-    int cnt_cpu = 0;
-    int* cnt_xpu = RAII_GUARD.alloc_l3_or_gm<int>(1);
-    r = xpu::nonzero_count(dev_ctx.x_context(),
-                           compare_results + (i - 1) * slice_size,
-                           cnt_xpu,
-                           slice_size);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count");
-    memory_utils::Copy(
-        phi::CPUPlace(), &cnt_cpu, dev_ctx.GetPlace(), cnt_xpu, sizeof(int));
-    if (cnt_cpu != slice_size) {
-      unique_axis.push_back(i);
-      indices_cpu.push_back(ori_idx_cpu[i]);
-      counts_cpu.push_back(repeat_cnt);
-      ++unique_len;
-      repeat_cnt = 1;
-    } else {
-      ++repeat_cnt;
+  if (axis_len > 1) {
+    DenseTensor adj_identical_cpu;
+    adj_identical_cpu.Resize({axis_len - 1});
+    bool* adj_identical_cpu_data =
+        dev_ctx.template HostAlloc<bool>(&adj_identical_cpu);
+    auto* adj_identical_xpu = RAII_GUARD.alloc_l3_or_gm<bool>(axis_len - 1);
+    r = xpu::reduce_all(dev_ctx.x_context(),
+                        compare_results,
+                        adj_identical_xpu,
+                        {axis_len - 1, slice_size},
+                        {1});
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_all");
+
+    memory_utils::Copy(phi::CPUPlace(),
+                       adj_identical_cpu_data,
+                       dev_ctx.GetPlace(),
+                       adj_identical_xpu,
+                       (axis_len - 1) * sizeof(bool));
+
+    for (IndexT i = 1; i < axis_len; ++i) {
+      if (!adj_identical_cpu_data[i - 1]) {
+        unique_axis.push_back(i);
+        indices_cpu.push_back(ori_idx_cpu[i]);
+        counts_cpu.push_back(repeat_cnt);
+        ++unique_len;
+        repeat_cnt = 1;
+      } else {
+        ++repeat_cnt;
+      }
+      inverse_cpu[ori_idx_cpu[i]] = unique_len - 1;
     }
-    inverse_cpu[ori_idx_cpu[i]] = unique_len - 1;
   }
   counts_cpu.push_back(repeat_cnt);
   DDim out_dims = x.dims();
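
Note on the unique_kernel.cc hunk: the removed loop issued one xpu::nonzero_count launch plus one blocking device-to-host copy per adjacent pair of slices, i.e. axis_len - 1 synchronizations, while the new code collapses all pairwise comparisons with a single xpu::reduce_all and a single copy, leaving only host-side bookkeeping in the loop. (The dev_ctx.Wait() guards in the other two files look like the same theme; presumably a null xpu_stream means work runs on the synchronous default stream, so the explicit wait is redundant there — an inference from the patch, not stated in it.) Below is a minimal CPU-only sketch of that bookkeeping, assuming compare_results holds the elementwise equality of each adjacent pair of already-sorted slices and ori_idx maps sorted positions back to original indices, as in the kernel; DedupAdjacent and UniqueBookkeeping are hypothetical names, not part of the patch.

// CPU-only sketch of the batched de-duplication bookkeeping; illustrative
// only, no device work. Assumes axis_len >= 1.
#include <cstdint>
#include <vector>

struct UniqueBookkeeping {           // hypothetical container, not in the patch
  std::vector<int64_t> unique_axis;  // sorted position where each unique slice starts
  std::vector<int64_t> indices;      // original index of each unique slice
  std::vector<int64_t> counts;       // repeat count per unique slice
  std::vector<int64_t> inverse;      // unique id for every original slice
};

UniqueBookkeeping DedupAdjacent(const std::vector<bool>& compare_results,
                                const std::vector<int64_t>& ori_idx,
                                int64_t axis_len,
                                int64_t slice_size) {
  // Step 1: collapse each row of slice_size elementwise comparisons into one
  // flag per adjacent pair. In the patch this is the single xpu::reduce_all
  // over shape {axis_len - 1, slice_size}, reduced along dim 1.
  std::vector<bool> adj_identical(axis_len > 1 ? axis_len - 1 : 0, true);
  for (int64_t i = 0; i + 1 < axis_len; ++i)
    for (int64_t j = 0; j < slice_size; ++j)
      adj_identical[i] = adj_identical[i] && compare_results[i * slice_size + j];

  // Step 2: one host-side pass over the pair flags builds the outputs; this
  // mirrors the loop kept in the kernel, now free of device round trips.
  UniqueBookkeeping out;
  out.inverse.assign(axis_len, 0);
  out.unique_axis.push_back(0);
  out.indices.push_back(ori_idx[0]);
  out.inverse[ori_idx[0]] = 0;
  int64_t unique_len = 1;
  int64_t repeat_cnt = 1;
  for (int64_t i = 1; i < axis_len; ++i) {
    if (!adj_identical[i - 1]) {  // slice i differs from slice i - 1
      out.unique_axis.push_back(i);
      out.indices.push_back(ori_idx[i]);
      out.counts.push_back(repeat_cnt);
      ++unique_len;
      repeat_cnt = 1;
    } else {                      // identical to the previous slice
      ++repeat_cnt;
    }
    out.inverse[ori_idx[i]] = unique_len - 1;
  }
  out.counts.push_back(repeat_cnt);
  return out;
}

The result is identical to the old per-pair nonzero_count path; the difference is purely in synchronization count, with the host blocking once instead of axis_len - 1 times.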