未验证 提交 d674ea95 编写于 作者: L lijin23 提交者: GitHub

optimize unique and index_put (#56582)

上级 75ee1a88
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
...@@ -106,7 +107,14 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices( ...@@ -106,7 +107,14 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
SplitWithNumKernel<int64_t, Context>( SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices); dev_ctx, nonzero_indices, rank, 1, integer_indices);
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
auto place = dev_ctx.GetPlace();
if (place.GetType() == phi::AllocationType::XPU) {
auto& pool = phi::DeviceContextPool::Instance();
auto* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
if (xpu_ctx->x_context()->xpu_stream) {
dev_ctx.Wait(); dev_ctx.Wait();
}
}
#endif #endif
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) || } else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
......
...@@ -65,7 +65,9 @@ void XPUDealWithIndices(const Context& dev_ctx, ...@@ -65,7 +65,9 @@ void XPUDealWithIndices(const Context& dev_ctx,
} }
StackKernel<int64_t, Context>(dev_ctx, tmp_indices_ptr, -1, out); StackKernel<int64_t, Context>(dev_ctx, tmp_indices_ptr, -1, out);
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait(); dev_ctx.Wait();
}
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -140,7 +142,9 @@ void IndexPutKernel(const Context& dev_ctx, ...@@ -140,7 +142,9 @@ void IndexPutKernel(const Context& dev_ctx,
index_shape, index_shape,
accumulate); accumulate);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put");
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait(); dev_ctx.Wait();
}
} }
} // namespace phi } // namespace phi
......
...@@ -228,17 +228,27 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx, ...@@ -228,17 +228,27 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
inverse_cpu[ori_idx_cpu[0]] = 0; inverse_cpu[ori_idx_cpu[0]] = 0;
IndexT unique_len = 1; IndexT unique_len = 1;
IndexT repeat_cnt = 1; IndexT repeat_cnt = 1;
if (axis_len > 1) {
DenseTensor adj_identical_cpu;
adj_identical_cpu.Resize({axis_len - 1});
bool* adj_identical_cpu_data =
dev_ctx.template HostAlloc<bool>(&adj_identical_cpu);
auto* adj_identical_xpu = RAII_GUARD.alloc_l3_or_gm<bool>(axis_len - 1);
r = xpu::reduce_all<bool>(dev_ctx.x_context(),
compare_results,
adj_identical_xpu,
{axis_len - 1, slice_size},
{1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_all");
memory_utils::Copy(phi::CPUPlace(),
adj_identical_cpu_data,
dev_ctx.GetPlace(),
adj_identical_xpu,
(axis_len - 1) * sizeof(bool));
for (IndexT i = 1; i < axis_len; ++i) { for (IndexT i = 1; i < axis_len; ++i) {
int cnt_cpu = 0; if (!adj_identical_cpu_data[i - 1]) {
int* cnt_xpu = RAII_GUARD.alloc_l3_or_gm<int>(1);
r = xpu::nonzero_count<bool>(dev_ctx.x_context(),
compare_results + (i - 1) * slice_size,
cnt_xpu,
slice_size);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count");
memory_utils::Copy(
phi::CPUPlace(), &cnt_cpu, dev_ctx.GetPlace(), cnt_xpu, sizeof(int));
if (cnt_cpu != slice_size) {
unique_axis.push_back(i); unique_axis.push_back(i);
indices_cpu.push_back(ori_idx_cpu[i]); indices_cpu.push_back(ori_idx_cpu[i]);
counts_cpu.push_back(repeat_cnt); counts_cpu.push_back(repeat_cnt);
...@@ -249,6 +259,7 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx, ...@@ -249,6 +259,7 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
} }
inverse_cpu[ori_idx_cpu[i]] = unique_len - 1; inverse_cpu[ori_idx_cpu[i]] = unique_len - 1;
} }
}
counts_cpu.push_back(repeat_cnt); counts_cpu.push_back(repeat_cnt);
DDim out_dims = x.dims(); DDim out_dims = x.dims();
out_dims[axis] = unique_len; out_dims[axis] = unique_len;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册