未验证 提交 d674ea95 编写于 作者: L lijin23 提交者: GitHub

optimize unique and index_put (#56582)

上级 75ee1a88
......@@ -15,6 +15,7 @@
#pragma once
#include <vector>
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
......@@ -106,7 +107,14 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
#ifdef PADDLE_WITH_XPU
auto place = dev_ctx.GetPlace();
if (place.GetType() == phi::AllocationType::XPU) {
auto& pool = phi::DeviceContextPool::Instance();
auto* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
if (xpu_ctx->x_context()->xpu_stream) {
dev_ctx.Wait();
}
}
#endif
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
......
......@@ -65,7 +65,9 @@ void XPUDealWithIndices(const Context& dev_ctx,
}
StackKernel<int64_t, Context>(dev_ctx, tmp_indices_ptr, -1, out);
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
}
template <typename T, typename Context>
......@@ -140,7 +142,9 @@ void IndexPutKernel(const Context& dev_ctx,
index_shape,
accumulate);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put");
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
}
} // namespace phi
......
......@@ -228,17 +228,27 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
inverse_cpu[ori_idx_cpu[0]] = 0;
IndexT unique_len = 1;
IndexT repeat_cnt = 1;
if (axis_len > 1) {
DenseTensor adj_identical_cpu;
adj_identical_cpu.Resize({axis_len - 1});
bool* adj_identical_cpu_data =
dev_ctx.template HostAlloc<bool>(&adj_identical_cpu);
auto* adj_identical_xpu = RAII_GUARD.alloc_l3_or_gm<bool>(axis_len - 1);
r = xpu::reduce_all<bool>(dev_ctx.x_context(),
compare_results,
adj_identical_xpu,
{axis_len - 1, slice_size},
{1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_all");
memory_utils::Copy(phi::CPUPlace(),
adj_identical_cpu_data,
dev_ctx.GetPlace(),
adj_identical_xpu,
(axis_len - 1) * sizeof(bool));
for (IndexT i = 1; i < axis_len; ++i) {
int cnt_cpu = 0;
int* cnt_xpu = RAII_GUARD.alloc_l3_or_gm<int>(1);
r = xpu::nonzero_count<bool>(dev_ctx.x_context(),
compare_results + (i - 1) * slice_size,
cnt_xpu,
slice_size);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count");
memory_utils::Copy(
phi::CPUPlace(), &cnt_cpu, dev_ctx.GetPlace(), cnt_xpu, sizeof(int));
if (cnt_cpu != slice_size) {
if (!adj_identical_cpu_data[i - 1]) {
unique_axis.push_back(i);
indices_cpu.push_back(ori_idx_cpu[i]);
counts_cpu.push_back(repeat_cnt);
......@@ -249,6 +259,7 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
}
inverse_cpu[ori_idx_cpu[i]] = unique_len - 1;
}
}
counts_cpu.push_back(repeat_cnt);
DDim out_dims = x.dims();
out_dims[axis] = unique_len;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册