Unverified commit 8cbeefea, authored by Zhang Zheng, committed by GitHub

Optimize performance of unique kernel (#52736)

* Optimize performance of unique kernel

* fix ci
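
In short, the diff below makes two changes to UniqueDimsCUDATensor: it removes the file-local IndexSelect helper, which round-tripped the tensor through host memory on every call (TensorToVector, a triple CPU loop, TensorFromVector), replacing its one use with the device-side phi::IndexSelectKernel; and it skips the transpose/TransCompute step entirely when axis == 0, since the requested axis is already the leading dimension and the input buffer can be shared via ShareDataWith.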
Parent c376a940
@@ -30,6 +30,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/unique_functor.h"
#include "paddle/phi/kernels/index_select_kernel.h"
namespace phi {
@@ -98,76 +99,6 @@ struct BinaryNotEqual {
}
};
// index_select() function for DenseTensor
template <typename Context, typename InT, typename IndexT>
void IndexSelect(const Context& context,
const DenseTensor& input,
const DenseTensor& index,
DenseTensor* output,
int dim) {
auto input_dim = input.dims();
auto input_dim_size = input_dim.size();
auto output_dim = output->dims();
auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}
auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];
auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}
auto index_size = index.dims()[0];
std::vector<InT> input_vec;
std::vector<IndexT> index_vec;
phi::TensorToVector(input, context, &input_vec);
phi::TensorToVector(index, context, &index_vec);
std::vector<InT> out_vec(output->numel());
for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE(
index_vec[i],
0,
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim],
index_vec[i]));
PADDLE_ENFORCE_LT(
index_vec[i],
input_dim[dim],
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim],
index_vec[i]));
}
for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;
for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_vec[j];
for (auto k = 0; k < slice_size; k++) {
out_vec[output_start_offset + j * slice_size + k] =
input_vec[input_start_offset + index_value * slice_size + k];
}
}
}
context.template Alloc<IndexT>(output);
phi::TensorFromVector(out_vec, context, output);
output->Resize(output_dim);
}
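// A minimal standalone sketch of the gather the removed helper above
// performed, assuming plain std::vector buffers (hypothetical name and
// signature, for illustration only). Note the data flow it implies inside
// the kernel: TensorToVector copied device memory to the host, this triple
// loop ran on the CPU, and TensorFromVector copied the result back -- two
// device<->host round trips that this commit eliminates by calling the
// device-side phi::IndexSelectKernel instead.
#include <cstdint>
#include <vector>

template <typename T, typename IndexT>
std::vector<T> IndexSelectHostSketch(
    const std::vector<T>& input,
    const std::vector<IndexT>& index,
    int64_t outer_nums,    // product of dims before `dim`
    int64_t dim_size,      // input extent along `dim`
    int64_t slice_size) {  // product of dims after `dim`
  const int64_t index_size = static_cast<int64_t>(index.size());
  std::vector<T> out(outer_nums * index_size * slice_size);
  for (int64_t i = 0; i < outer_nums; ++i) {
    for (int64_t j = 0; j < index_size; ++j) {
      for (int64_t k = 0; k < slice_size; ++k) {
        out[(i * index_size + j) * slice_size + k] =
            input[(i * dim_size + index[j]) * slice_size + k];
      }
    }
  }
  return out;
}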
// The core logic of computing Unique for a flattened DenseTensor
template <typename Context,
typename InT,
@@ -354,24 +285,29 @@ static void UniqueDimsCUDATensor(const Context& context,
int axis) {
// 1. Transpose & reshape
// Transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
std::vector<int> permute(in.dims().size());
std::iota(permute.begin(), permute.end(), 0);
permute[axis] = 0;
permute[0] = axis;
std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims()));
in_trans_dims_vec[axis] = in.dims()[0];
in_trans_dims_vec[0] = in.dims()[axis];
DenseTensor in_trans;
std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims()));
auto in_trans_dims = phi::make_ddim(in_trans_dims_vec);
in_trans.Resize(in_trans_dims);
context.template Alloc<InT>(&in_trans);
phi::funcs::TransCompute<Context, InT>(
in.dims().size(), // num of dims
context, // device
in, // original DenseTensor
&in_trans, // DenseTensor after reshape
permute); // index of axis
std::vector<int> permute(in.dims().size());
bool is_transpose = axis != 0;
if (is_transpose) {
std::iota(permute.begin(), permute.end(), 0);
permute[axis] = 0;
permute[0] = axis;
in_trans_dims_vec[axis] = in.dims()[0];
in_trans_dims_vec[0] = in.dims()[axis];
in_trans_dims = phi::make_ddim(in_trans_dims_vec);
in_trans.Resize(in_trans_dims);
context.template Alloc<InT>(&in_trans);
phi::funcs::TransCompute<Context, InT>(
in.dims().size(), // num of dims
context, // device
in, // original DenseTensor
&in_trans, // DenseTensor after reshape
permute); // index of axis
} else {
in_trans.ShareDataWith(in);
}
// Reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
auto in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
in_trans.Resize(in_trans_flat_dims);
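// Worked example for the transpose + reshape above, with hypothetical sizes
// (not taken from the diff): for axis == 1 and in.dims() == [2, 3, 4],
// permute is {1, 0, 2}, the transpose produces [3, 2, 4], and
// flatten_to_2d(., 1) views that buffer as [3, 8] -- so each of the 3 rows
// is one candidate element for Unique. The same shape arithmetic, standalone:
#include <cstdint>
#include <utility>
#include <vector>

std::pair<int64_t, int64_t> FlattenedShapeSketch(std::vector<int64_t> dims,
                                                 int axis) {
  std::swap(dims[0], dims[axis]);  // move `axis` to the front
  int64_t cols = 1;
  for (size_t i = 1; i < dims.size(); ++i) cols *= dims[i];
  return {dims[0], cols};  // {2, 3, 4} with axis = 1 -> {3, 8}
}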
@@ -407,22 +343,27 @@ static void UniqueDimsCUDATensor(const Context& context,
row);
// 3. Select indices and reshape back to get 'out'
DenseTensor out_trans;
std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
out_trans_dims_vec[0] = indices->numel();
out_trans.Resize(phi::make_ddim(out_trans_dims_vec));
context.template Alloc<InT>(&out_trans);
IndexSelect<Context, InT, IndexT>(context, in_trans, *indices, &out_trans, 0);
std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
out->Resize(phi::make_ddim(out_trans_dims_vec));
context.template Alloc<InT>(out);
std::vector<DenseTensor> out_trans_unbind = phi::funcs::Unbind(out_trans);
phi::funcs::ConcatFunctor<Context, InT> concat_functor;
concat_functor(context, out_trans_unbind, 0, &out_trans);
phi::funcs::TransCompute<Context, InT>(
out_trans.dims().size(), context, out_trans, out, permute);
if (is_transpose) {
DenseTensor out_trans;
out_trans.Resize(phi::make_ddim(out_trans_dims_vec));
context.template Alloc<InT>(&out_trans);
phi::IndexSelectKernel<InT, Context>(
context, in_trans, *indices, 0, &out_trans);
std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
out->Resize(phi::make_ddim(out_trans_dims_vec));
context.template Alloc<InT>(out);
phi::funcs::TransCompute<Context, InT>(
out_trans.dims().size(), context, out_trans, out, permute);
} else {
out->Resize(phi::make_ddim(out_trans_dims_vec));
context.template Alloc<InT>(out);
phi::IndexSelectKernel<InT, Context>(context, in_trans, *indices, 0, out);
}
}
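// For reference, a host-side sketch of what steps 2-3 compute on the 2-D
// view (sort rows, keep the first index of each run of equal rows, then
// gather those rows) -- assuming row-major float data. This mirrors the
// semantics only, not the thrust-based implementation in the kernel above.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> UniqueRowIndicesSketch(const std::vector<float>& data,
                                            int64_t rows, int64_t cols) {
  std::vector<int64_t> order(rows);
  std::iota(order.begin(), order.end(), 0);
  auto row_less = [&](int64_t a, int64_t b) {
    return std::lexicographical_compare(
        data.begin() + a * cols, data.begin() + (a + 1) * cols,
        data.begin() + b * cols, data.begin() + (b + 1) * cols);
  };
  auto row_equal = [&](int64_t a, int64_t b) {
    return std::equal(data.begin() + a * cols, data.begin() + (a + 1) * cols,
                      data.begin() + b * cols);
  };
  std::sort(order.begin(), order.end(), row_less);     // step 2: sort rows
  order.erase(std::unique(order.begin(), order.end(),  // one index per
                          row_equal),                  // unique row
              order.end());
  return order;  // step 3 gathers these rows via IndexSelect
}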
// functor for processing a flattened DenseTensor