未验证 提交 d674ea95 编写于 作者: L lijin23 提交者: GitHub

optimize unique and index_put (#56582)

上级 75ee1a88
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
...@@ -106,7 +107,14 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices( ...@@ -106,7 +107,14 @@ std::vector<const phi::DenseTensor*> DealWithBoolIndices(
SplitWithNumKernel<int64_t, Context>( SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices); dev_ctx, nonzero_indices, rank, 1, integer_indices);
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
dev_ctx.Wait(); auto place = dev_ctx.GetPlace();
if (place.GetType() == phi::AllocationType::XPU) {
auto& pool = phi::DeviceContextPool::Instance();
auto* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
if (xpu_ctx->x_context()->xpu_stream) {
dev_ctx.Wait();
}
}
#endif #endif
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) || } else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
......
...@@ -65,7 +65,9 @@ void XPUDealWithIndices(const Context& dev_ctx, ...@@ -65,7 +65,9 @@ void XPUDealWithIndices(const Context& dev_ctx,
} }
StackKernel<int64_t, Context>(dev_ctx, tmp_indices_ptr, -1, out); StackKernel<int64_t, Context>(dev_ctx, tmp_indices_ptr, -1, out);
dev_ctx.Wait(); if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -140,7 +142,9 @@ void IndexPutKernel(const Context& dev_ctx, ...@@ -140,7 +142,9 @@ void IndexPutKernel(const Context& dev_ctx,
index_shape, index_shape,
accumulate); accumulate);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put");
dev_ctx.Wait(); if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
} }
} // namespace phi } // namespace phi
......
...@@ -228,26 +228,37 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx, ...@@ -228,26 +228,37 @@ void XPUDimUniqueKernelImpl(const Context& dev_ctx,
inverse_cpu[ori_idx_cpu[0]] = 0; inverse_cpu[ori_idx_cpu[0]] = 0;
IndexT unique_len = 1; IndexT unique_len = 1;
IndexT repeat_cnt = 1; IndexT repeat_cnt = 1;
for (IndexT i = 1; i < axis_len; ++i) { if (axis_len > 1) {
int cnt_cpu = 0; DenseTensor adj_identical_cpu;
int* cnt_xpu = RAII_GUARD.alloc_l3_or_gm<int>(1); adj_identical_cpu.Resize({axis_len - 1});
r = xpu::nonzero_count<bool>(dev_ctx.x_context(), bool* adj_identical_cpu_data =
compare_results + (i - 1) * slice_size, dev_ctx.template HostAlloc<bool>(&adj_identical_cpu);
cnt_xpu, auto* adj_identical_xpu = RAII_GUARD.alloc_l3_or_gm<bool>(axis_len - 1);
slice_size); r = xpu::reduce_all<bool>(dev_ctx.x_context(),
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nonzero_count"); compare_results,
memory_utils::Copy( adj_identical_xpu,
phi::CPUPlace(), &cnt_cpu, dev_ctx.GetPlace(), cnt_xpu, sizeof(int)); {axis_len - 1, slice_size},
if (cnt_cpu != slice_size) { {1});
unique_axis.push_back(i); PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_all");
indices_cpu.push_back(ori_idx_cpu[i]);
counts_cpu.push_back(repeat_cnt); memory_utils::Copy(phi::CPUPlace(),
++unique_len; adj_identical_cpu_data,
repeat_cnt = 1; dev_ctx.GetPlace(),
} else { adj_identical_xpu,
++repeat_cnt; (axis_len - 1) * sizeof(bool));
for (IndexT i = 1; i < axis_len; ++i) {
if (!adj_identical_cpu_data[i - 1]) {
unique_axis.push_back(i);
indices_cpu.push_back(ori_idx_cpu[i]);
counts_cpu.push_back(repeat_cnt);
++unique_len;
repeat_cnt = 1;
} else {
++repeat_cnt;
}
inverse_cpu[ori_idx_cpu[i]] = unique_len - 1;
} }
inverse_cpu[ori_idx_cpu[i]] = unique_len - 1;
} }
counts_cpu.push_back(repeat_cnt); counts_cpu.push_back(repeat_cnt);
DDim out_dims = x.dims(); DDim out_dims = x.dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册