未验证 提交 329b095e 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Phi]Move topk kernel to phi (#40064)

* first commit

* cpu kernel

* first version

* fix compile error

* fix compile error

* delete v2

* fix

* fix

* add alias

* fix

* fix

* fix

* fix error

* fix

* fix

* fix

* fix format
上级 2b6da4de
......@@ -16,7 +16,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/kthvalue_op.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
......
......@@ -24,7 +24,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mode_op.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"
namespace paddle {
namespace operators {
......
......@@ -51,6 +51,19 @@ namespace operators {
using Tensor = framework::Tensor;
inline void GetDims(const phi::DDim& dim, int axis, int* pre, int* n,
int* post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
struct SegmentOffsetIter {
EIGEN_DEVICE_FUNC
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
......
......@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/top_k_v2_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......@@ -173,15 +174,3 @@ REGISTER_OPERATOR(top_k_v2, ops::TopkV2Op, ops::TopkV2OpMaker,
ops::TopkV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(top_k_v2_grad, ops::TopkV2OpGrad);
REGISTER_OP_CPU_KERNEL(top_k_v2,
ops::TopkV2Kernel<paddle::platform::CPUPlace, float>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, double>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, int32_t>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, int64_t>)
REGISTER_OP_CPU_KERNEL(
top_k_v2_grad, ops::TopkV2GradKernel<paddle::platform::CPUPlace, float>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, double>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, int32_t>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, int64_t>)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template <typename DeviceContext, typename T>
class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
// get the attributes
int k = static_cast<int>(ctx.Attr<int>("k"));
int axis = static_cast<int>(ctx.Attr<int>("axis"));
const bool& sorted = static_cast<bool>(ctx.Attr<bool>("sorted"));
const bool& largest = static_cast<bool>(ctx.Attr<bool>("largest"));
// get the input dims
const auto& in_dims = input->dims();
// calcluate the real axis
if (axis < 0) axis += in_dims.size();
auto* k_t = ctx.Input<Tensor>("K");
if (k_t) {
Tensor k_host;
framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host);
k = k_host.data<int>()[0];
framework::DDim output_dims = output->dims();
output_dims[axis] = k;
output->Resize(output_dims);
indices->Resize(output_dims);
}
const auto& out_dims = output->dims();
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
if (axis == in_dims.size() - 1) {
// if get the topK from the last axis
const int64_t& input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
if (k > input_width) k = input_width;
// The conclusion is drawn from the data through multiple sets of
// statistics
if (input_width >= 128 && k >= input_width * 0.75) {
if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
indices, largest)) {
// Successed, return.
return;
} else {
LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
"default topk kernel.";
}
}
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (GetDesiredBlockDim(input_width)) {
#ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 20,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, k, indices_data, input_data, input_width,
input_width, static_cast<int>(k), gridx, input_height,
largest));
#else
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, k, indices_data, input_data, input_width,
input_width, static_cast<int>(k), gridx, input_height,
largest));
#endif
default:
PADDLE_THROW(platform::errors::Fatal(
"the input data shape has error in the topk cuda kernel."));
}
} else {
// if get topK not from the last axis, will tranpose the tensor and get
// TopK
// first step, prepare the trans args for the tranpose
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
framework::DDim trans_dims(in_dims);
framework::DDim trans_out_dims(output->dims());
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
trans_out_dims[i] = out_dims[trans[i]];
}
// second step, tranpose the input
Tensor trans_input;
trans_input.mutable_data<T>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
const auto& dev_ctx = ctx.cuda_device_context();
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
&trans_input, trans);
// third step, calcluate the topk
// allocate the tmp cuda memory for the tmp result
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_out_dims, ctx.GetPlace());
Tensor trans_out;
trans_out.mutable_data<T>(trans_out_dims, ctx.GetPlace());
const int64_t input_height =
phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
if (k > input_width) k = input_width;
// The conclusion is drawn from the data through multiple sets of
// statistics
if (input_width >= 128 && k >= input_width * 0.75) {
if (SortTopk<T>(dev_ctx, &trans_input, input_width, input_height, k,
&trans_out, &trans_ind, largest)) {
// last step, tranpose back the indices and output
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
TransCompute<platform::CUDADeviceContext, T>(
ndims, dev_ctx, trans_out, output, trans);
return;
} else {
LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
"default topk kernel.";
}
}
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (GetDesiredBlockDim(input_width)) {
#ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 20,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
trans_out.data<T>(), k, trans_ind.data<int64_t>(),
trans_input.data<T>(), input_width, input_width,
static_cast<int>(k), gridx, input_height, largest));
#else
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
trans_out.data<T>(), k, trans_ind.data<int64_t>(),
trans_input.data<T>(), input_width, input_width,
static_cast<int>(k), gridx, input_height, largest));
#endif
default:
PADDLE_THROW(platform::errors::Fatal(
"the input data shape has error in the topk cuda kernel."));
}
// last step, tranpose back the indices and output
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, trans_out,
output, trans);
}
}
};
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
template <typename DeviceContext, typename T>
class TopkV2OpGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* x = context.Input<Tensor>("X");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<Tensor>("Indices");
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int axis = context.Attr<int>("axis");
const auto& in_dims = x->dims();
const auto& out_dims = indices->dims();
// get the real the axis and the k
if (axis < 0) axis += in_dims.size();
const int& k = out_dims[axis];
const int& raw_height = in_dims[axis];
// allocate the cuda memory for the x_grad
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
const T* out_grad_data = out_grad->data<T>();
const int64_t* indices_data = indices->data<int64_t>();
int pre, n, post;
GetDims(in_dims, axis, &pre, &n, &post);
// calcluate the block and grid num
auto& dev_ctx = context.cuda_device_context();
auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};
int block_size = ComputeBlockSize(post * k);
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
int grid_size = std::min(max_blocks, pre);
// lanuch the cuda kernel to assign the grad
AssignGradWithAxis<T><<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
out_grad_data, indices_data, x_grad_data, pre, post, n, k);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
top_k_v2,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
top_k_v2_grad, paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, float>,
paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, double>,
paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, int>,
paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, int64_t>,
paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, paddle::platform::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
The reason why we need the topk v2 is because the compatibility. We redefine
the NaN is maximum value
in the process of comparing. If do not add the topk v2, will affect the
inference result of model that traing
by the older version paddlepaddle.
*/
#pragma once
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/operators/transpose_op.h"
namespace paddle {
namespace operators {
inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
int* post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
template <typename T, typename Type>
static void FullTopK(Type input_height, Type input_width, int input_dim,
const framework::Tensor* input, T* t_out, Type* t_indices,
const int& k, const bool& largest, const bool& sorted) {
// when the k is small, will the partial sort
bool partial_sort_flag = (k * 64) < input_width;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
// Eigen::DSizes<int, 2> flat2dims(input_height, input_width);
for (Type i = 0; i < input_height; ++i) {
std::vector<std::pair<T, Type>> col_vec;
col_vec.reserve(input_width);
if (input_dim == 1) {
auto e_input = framework::EigenVector<T>::Flatten(*input);
for (Type j = 0; j < input_width; ++j) {
col_vec.emplace_back(std::pair<T, Type>(e_input(j), j));
}
} else {
auto e_input = framework::EigenMatrix<T>::Reshape(*input, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
col_vec.emplace_back(std::pair<T, Type>(e_input(i, j), j));
}
}
if (partial_sort_flag) {
std::partial_sort(
col_vec.begin(), col_vec.begin() + k, col_vec.end(),
[&largest](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (largest) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
} else {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
}
});
} else {
// use the nth-element to get the K-larger or K-small element
if (largest) {
std::nth_element(
col_vec.begin(), col_vec.begin() + k - 1, col_vec.end(),
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
});
// the nth-element will get the unorder elements, sort the element
if (sorted) {
std::sort(col_vec.begin(), col_vec.begin() + k - 1,
[&largest](const std::pair<T, Type>& l,
const std::pair<T, Type>& r) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
});
}
} else {
std::nth_element(
col_vec.begin(), col_vec.begin() + k - 1, col_vec.end(),
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
// the nth-element will get the unorder elements, sort the element
if (sorted) {
std::sort(
col_vec.begin(), col_vec.begin() + k - 1,
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
}
}
}
for (Type j = 0; j < k; ++j) {
t_out[i * k + j] = col_vec[j].first;
t_indices[i * k + j] = col_vec[j].second;
}
}
}
template <typename T, typename Type>
static void FullTopKAssign(const Type& input_height, const Type& input_width,
const int& input_dim, const framework::Tensor* input,
const framework::Tensor* indices, T* output_data,
const int& k) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
if (input_dim == 1) {
auto e_input = framework::EigenVector<T>::Flatten(*input);
auto e_indices = framework::EigenVector<Type>::Flatten(*indices);
for (Type j = 0; j < k; ++j) {
output_data[i * input_width + e_indices(j)] = e_input(j);
}
} else {
auto e_input = framework::EigenMatrix<T>::Reshape(*input, input_dim - 1);
auto e_indices =
framework::EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
for (Type j = 0; j < k; ++j) {
output_data[i * input_width + e_indices(i, j)] = e_input(i, j);
}
}
}
}
template <typename DeviceContext, typename T>
class TopkV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// Get the top k elements of each row of input tensor
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
auto* indices = context.Output<Tensor>("Indices");
const auto& in_dims = input->dims();
int k = static_cast<int>(context.Attr<int>("k"));
const auto& sorted = static_cast<bool>(context.Attr<bool>("sorted"));
const auto& largest = static_cast<bool>(context.Attr<bool>("largest"));
// axis < 0, cacluate the real axis
int axis = static_cast<int>(context.Attr<int>("axis"));
if (axis < 0) axis += in_dims.size();
// if K tensor is not null, will the use K tesnor as k
auto* k_t = context.Input<Tensor>("K");
if (k_t) {
k = k_t->data<int>()[0];
framework::DDim output_dims = output->dims();
// accroding to axis to set K value in the dim
output_dims[axis] = k;
output->Resize(output_dims);
indices->Resize(output_dims);
}
T* output_data = output->mutable_data<T>(context.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(context.GetPlace());
const auto& out_dims = output->dims();
if (axis + 1 == in_dims.size()) {
const int64_t& input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
FullTopK<T, int64_t>(input_height, input_width, in_dims.size(), input,
output_data, indices_data, k, largest, sorted);
} else {
// if the topk dims is not last dim, will tranpose and do topk
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
// get the trans input_dims, out_dims
framework::DDim trans_dims(in_dims);
framework::DDim trans_out_dims(output->dims());
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}
for (size_t i = 0; i < trans.size(); i++) {
trans_out_dims[i] = out_dims[trans[i]];
}
Tensor trans_inp;
trans_inp.mutable_data<T>(trans_dims, context.GetPlace());
int ndims = trans.size();
auto& dev_context =
context.template device_context<platform::CPUDeviceContext>();
// transpose the input value
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, *input,
&trans_inp, trans);
const int64_t input_height =
phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
// Allocate the temp tensor to the save the topk indices, values
Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_out_dims, context.GetPlace());
Tensor tmp_indices;
auto* t_ind =
tmp_indices.mutable_data<int64_t>(trans_out_dims, context.GetPlace());
// get the TopK value
FullTopK<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_inp, t_out, t_ind, k, largest, sorted);
// transpose back
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_context, tmp_indices, indices, trans);
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, tmp_out,
output, trans);
}
}
};
template <typename DeviceContext, typename T>
class TopkV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<Tensor>("Indices");
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int axis = static_cast<int>(context.Attr<int>("axis"));
const auto& in_dims = x->dims();
const auto& out_dims = indices->dims();
// axis < 0, get the real axis
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const size_t& k = out_dims[axis];
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
if (axis + 1 == in_dims.size()) {
// allocate the memory for the input_grad
// assign the out_grad to input_grad directly
const int64_t input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
// init the output grad with 0, because some input elements has no grad
memset(x_grad_data, 0, x_grad->numel() * sizeof(T));
// Assign the output_grad to input_grad
FullTopKAssign(input_height, input_width, in_dims.size(), out_grad,
indices, x_grad_data, k);
} else {
// can not assign grad to input_grad, must do the transpose
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(out_dims.size() - 1);
for (int i = axis + 1; i < out_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
framework::DDim trans_dims(out_dims);
framework::DDim trans_in_dims(in_dims);
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = out_dims[trans[i]];
trans_in_dims[i] = in_dims[trans[i]];
}
// transpose the out_grad, indices
Tensor trans_dO;
trans_dO.mutable_data<T>(trans_dims, context.GetPlace());
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, context.GetPlace());
int ndims = trans.size();
auto& dev_context =
context.template device_context<platform::CPUDeviceContext>();
// Do transpose
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, *out_grad,
&trans_dO, trans);
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_context, *indices, &trans_ind, trans);
const int64_t input_height = phi::product(
phi::slice_ddim(trans_in_dims, 0, trans_in_dims.size() - 1));
const int64_t input_width = trans_in_dims[trans_in_dims.size() - 1];
// Assign the out_grad to tranpose input_grad
Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_in_dims, context.GetPlace());
memset(t_out, 0, x_grad->numel() * sizeof(T));
FullTopKAssign<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_dO, &trans_ind, t_out, k);
// Transpose back
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, tmp_out,
x_grad, trans);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/top_k_v2_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
......
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/top_k_v2_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "xpu/refactor/math.h"
......
......@@ -52,7 +52,9 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
"reshape_grad",
"expand",
"expand_grad",
"sum"});
"sum",
"top_k",
"top_k_grad"});
class DefaultKernelSignatureMap {
public:
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/top_k_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Type>
static void FullTopKAssign(const Type& input_height,
const Type& input_width,
const int& input_dim,
const DenseTensor* input,
const DenseTensor* indices,
T* output_data,
const int& k) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
auto e_indices = EigenVector<Type>::Flatten(*indices);
for (Type j = 0; j < k; ++j) {
output_data[i * input_width + e_indices(j)] = e_input(j);
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
auto e_indices = EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
for (Type j = 0; j < k; ++j) {
output_data[i * input_width + e_indices(i, j)] = e_input(i, j);
}
}
}
}
template <typename T, typename Context>
void TopkGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& indices,
int k,
int axis,
bool largest,
bool sorted,
DenseTensor* x_grad) {
const auto& in_dims = x.dims();
const auto& out_dims = indices.dims();
// axis < 0, get the real axis
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
if (axis + 1 == in_dims.size()) {
// allocate the memory for the input_grad
// assign the out_grad to input_grad directly
const int64_t input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
// init the output grad with 0, because some input elements has no grad
memset(x_grad_data, 0, x_grad->numel() * sizeof(T));
// Assign the output_grad to input_grad
FullTopKAssign(input_height,
input_width,
in_dims.size(),
&out_grad,
&indices,
x_grad_data,
k);
} else {
// can not assign grad to input_grad, must do the transpose
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(out_dims.size() - 1);
for (int i = axis + 1; i < out_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
phi::DDim trans_dims(out_dims);
phi::DDim trans_in_dims(in_dims);
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = out_dims[trans[i]];
trans_in_dims[i] = in_dims[trans[i]];
}
// transpose the out_grad, indices
DenseTensor trans_dO;
DenseTensor trans_ind;
trans_dO.Resize(trans_dims);
trans_ind.Resize(trans_dims);
dev_ctx.template Alloc<T>(&trans_dO);
dev_ctx.template Alloc<int64_t>(&trans_ind);
int ndims = trans.size();
// Do transpose
funcs::TransCompute<phi::CPUContext, T>(
ndims, dev_ctx, out_grad, &trans_dO, trans);
funcs::TransCompute<phi::CPUContext, int64_t>(
ndims, dev_ctx, indices, &trans_ind, trans);
const int64_t input_height = phi::product(
phi::slice_ddim(trans_in_dims, 0, trans_in_dims.size() - 1));
const int64_t input_width = trans_in_dims[trans_in_dims.size() - 1];
// Assign the out_grad to tranpose input_grad
DenseTensor tmp_out;
tmp_out.Resize(trans_in_dims);
T* t_out = dev_ctx.template Alloc<T>(&tmp_out);
memset(t_out, 0, x_grad->numel() * sizeof(T));
FullTopKAssign<T, int64_t>(input_height,
input_width,
in_dims.size(),
&trans_dO,
&trans_ind,
t_out,
k);
// Transpose back
funcs::TransCompute<phi::CPUContext, T>(
ndims, dev_ctx, tmp_out, x_grad, trans);
}
}
} // namespace phi
PD_REGISTER_KERNEL(top_k_grad,
CPU,
ALL_LAYOUT,
phi::TopkGradKernel,
float,
double,
int32_t,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/top_k_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T, typename Type>
static void FullTopK(Type input_height,
Type input_width,
int input_dim,
const DenseTensor* input,
T* t_out,
Type* t_indices,
const int& k,
const bool& largest,
const bool& sorted) {
// when the k is small, will the partial sort
bool partial_sort_flag = (k * 64) < input_width;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
std::vector<std::pair<T, Type>> col_vec;
col_vec.reserve(input_width);
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
for (Type j = 0; j < input_width; ++j) {
col_vec.emplace_back(std::pair<T, Type>(e_input(j), j));
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
col_vec.emplace_back(std::pair<T, Type>(e_input(i, j), j));
}
}
if (partial_sort_flag) {
std::partial_sort(
col_vec.begin(),
col_vec.begin() + k,
col_vec.end(),
[&largest](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (largest) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
} else {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
}
});
} else {
// use the nth-element to get the K-larger or K-small element
if (largest) {
std::nth_element(
col_vec.begin(),
col_vec.begin() + k - 1,
col_vec.end(),
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
});
// the nth-element will get the unorder elements, sort the element
if (sorted) {
std::sort(col_vec.begin(),
col_vec.begin() + k - 1,
[&largest](const std::pair<T, Type>& l,
const std::pair<T, Type>& r) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
});
}
} else {
std::nth_element(
col_vec.begin(),
col_vec.begin() + k - 1,
col_vec.end(),
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
// the nth-element will get the unorder elements, sort the element
if (sorted) {
std::sort(
col_vec.begin(),
col_vec.begin() + k - 1,
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
}
}
}
for (Type j = 0; j < k; ++j) {
t_out[i * k + j] = col_vec[j].first;
t_indices[i * k + j] = col_vec[j].second;
}
}
}
template <typename T, typename Context>
void TopkKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& k_scalar,
int axis,
bool largest,
bool sorted,
DenseTensor* out,
DenseTensor* indices) {
const auto* input = &x;
// Get the top k elements of each row of input tensor
const auto& in_dims = input->dims();
// axis < 0, cacluate the real axis
if (axis < 0) {
axis += in_dims.size();
}
int k = k_scalar.to<int>();
if (k_scalar.FromTensor()) {
auto out_dims = out->dims();
// accroding to axis to set K value in the dim
out_dims[axis] = k;
out->Resize(out_dims);
indices->Resize(out_dims);
}
T* out_data = dev_ctx.template Alloc<T>(out);
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
const auto& out_dims = out->dims();
if (axis + 1 == in_dims.size()) {
const int64_t& input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
FullTopK<T, int64_t>(input_height,
input_width,
in_dims.size(),
input,
out_data,
indices_data,
k,
largest,
sorted);
} else {
// if the topk dims is not last dim, will tranpose and do topk
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
// get the trans input_dims, out_dims
phi::DDim trans_dims(in_dims);
phi::DDim trans_out_dims(out->dims());
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}
for (size_t i = 0; i < trans.size(); i++) {
trans_out_dims[i] = out_dims[trans[i]];
}
DenseTensor trans_inp;
trans_inp.Resize(trans_dims);
dev_ctx.template Alloc<T>(&trans_inp);
int ndims = trans.size();
// transpose the input value
funcs::TransCompute<phi::CPUContext, T>(
ndims, dev_ctx, *input, &trans_inp, trans);
const int64_t input_height =
phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
// Allocate the temp tensor to the save the topk indices, values
DenseTensor tmp_out;
DenseTensor tmp_indices;
tmp_out.Resize(trans_out_dims);
tmp_indices.Resize(trans_out_dims);
T* t_out = dev_ctx.template Alloc<T>(&tmp_out);
auto* t_ind = dev_ctx.template Alloc<int64_t>(&tmp_indices);
// get the TopK value
FullTopK<T, int64_t>(input_height,
input_width,
in_dims.size(),
&trans_inp,
t_out,
t_ind,
k,
largest,
sorted);
// transpose back
funcs::TransCompute<phi::CPUContext, int64_t>(
ndims, dev_ctx, tmp_indices, indices, trans);
funcs::TransCompute<phi::CPUContext, T>(
ndims, dev_ctx, tmp_out, out, trans);
}
}
} // namespace phi
PD_REGISTER_KERNEL(
top_k, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {}
......@@ -125,5 +125,43 @@ struct TensorSetConstantXPU {
};
#endif
template <typename Context, typename T>
inline void TransCompute(const int dim,
const Context& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
switch (dim) {
case 1:
Transpose<Context, T, 1> trans1;
trans1(dev_ctx, in, out, axis);
break;
case 2:
Transpose<Context, T, 2> trans2;
trans2(dev_ctx, in, out, axis);
break;
case 3:
Transpose<Context, T, 3> trans3;
trans3(dev_ctx, in, out, axis);
break;
case 4:
Transpose<Context, T, 4> trans4;
trans4(dev_ctx, in, out, axis);
break;
case 5:
Transpose<Context, T, 5> trans5;
trans5(dev_ctx, in, out, axis);
break;
case 6:
Transpose<Context, T, 6> trans6;
trans6(dev_ctx, in, out, axis);
break;
default:
// for dim >= 7 situation
TransposeNormal<Context, T> trans_normal;
trans_normal(dev_ctx, in, out, axis);
}
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/top_k_grad_kernel.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace ops = paddle::operators;
template <typename T, typename Context>
void TopkGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& indices,
int k,
int axis,
bool largest,
bool sorted,
DenseTensor* x_grad) {
const auto& in_dims = x.dims();
const auto& out_dims = indices.dims();
// get the real the axis and the k
if (axis < 0) {
axis += in_dims.size();
}
const int& raw_height = in_dims[axis];
// allocate the cuda memory for the x_grad
T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
const T* out_grad_data = out_grad.data<T>();
const int64_t* indices_data = indices.data<int64_t>();
int pre, n, post;
ops::GetDims(in_dims, axis, &pre, &n, &post);
// calcluate the block and grid num
auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};
int block_size = ComputeBlockSize(post * k);
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
int grid_size = std::min(max_blocks, pre);
// lanuch the cuda kernel to assign the grad
ops::AssignGradWithAxis<
T><<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
out_grad_data, indices_data, x_grad_data, pre, post, n, k);
}
} // namespace phi
PD_REGISTER_KERNEL(top_k_grad,
GPU,
ALL_LAYOUT,
phi::TopkGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/top_k_kernel.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace ops = paddle::operators;
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template <typename T, typename Context>
void TopkKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& k_scalar,
int axis,
bool largest,
bool sorted,
DenseTensor* out,
DenseTensor* indices) {
const auto* input = &x;
// get the input dims
const auto& in_dims = input->dims();
// calcluate the real axis
if (axis < 0) axis += in_dims.size();
int k = k_scalar.to<int>();
if (k_scalar.FromTensor()) {
phi::DDim out_dims = out->dims();
out_dims[axis] = k;
out->Resize(out_dims);
indices->Resize(out_dims);
}
const auto& out_dims = out->dims();
const T* input_data = input->data<T>();
T* output_data = dev_ctx.template Alloc<T>(out);
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
if (axis == in_dims.size() - 1) {
// if get the topK from the last axis
const int64_t& input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
if (k > input_width) {
k = input_width;
}
// The conclusion is drawn from the data through multiple sets of
// statistics
if (input_width >= 128 && k >= input_width * 0.75) {
if (ops::SortTopk<T>(
paddle::platform::CUDADeviceContext(dev_ctx.GetPlace()),
input,
input_width,
input_height,
k,
out,
indices,
largest)) {
// Successed, return.
return;
} else {
LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
"default topk kernel.";
}
}
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (ops::GetDesiredBlockDim(input_width)) {
#ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM(ops::KeMatrixTopK<
T,
20,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data,
k,
indices_data,
input_data,
input_width,
input_width,
static_cast<int>(k),
gridx,
input_height,
largest));
#else
FIXED_BLOCK_DIM(ops::KeMatrixTopK<
T,
5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data,
k,
indices_data,
input_data,
input_width,
input_width,
static_cast<int>(k),
gridx,
input_height,
largest));
#endif
default:
PADDLE_THROW(errors::Fatal(
"the input data shape has error in the topk cuda kernel."));
}
} else {
// if get topK not from the last axis, will tranpose the tensor and get
// TopK
// first step, prepare the trans args for the tranpose
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
phi::DDim trans_dims(in_dims);
phi::DDim trans_out_dims(out->dims());
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
trans_out_dims[i] = out_dims[trans[i]];
}
// second step, tranpose the input
DenseTensor trans_input;
trans_input.Resize(trans_dims);
dev_ctx.template Alloc<T>(&trans_input);
int ndims = trans.size();
funcs::TransCompute<phi::GPUContext, T>(
ndims, dev_ctx, *input, &trans_input, trans);
// third step, calcluate the topk
// allocate the tmp cuda memory for the tmp result
DenseTensor trans_ind;
DenseTensor trans_out;
trans_ind.Resize(trans_out_dims);
trans_out.Resize(trans_out_dims);
dev_ctx.template Alloc<int64_t>(&trans_ind);
dev_ctx.template Alloc<T>(&trans_out);
const int64_t input_height =
phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
if (k > input_width) k = input_width;
// The conclusion is drawn from the data through multiple sets of
// statistics
if (input_width >= 128 && k >= input_width * 0.75) {
if (ops::SortTopk<T>(
paddle::platform::CUDADeviceContext(dev_ctx.GetPlace()),
&trans_input,
input_width,
input_height,
k,
&trans_out,
&trans_ind,
largest)) {
// last step, tranpose back the indices and output
funcs::TransCompute<phi::GPUContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
funcs::TransCompute<phi::GPUContext, T>(
ndims, dev_ctx, trans_out, out, trans);
return;
} else {
LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
"default topk kernel.";
}
}
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (ops::GetDesiredBlockDim(input_width)) {
#ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM(ops::KeMatrixTopK<
T,
20,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
trans_out.data<T>(),
k,
trans_ind.data<int64_t>(),
trans_input.data<T>(),
input_width,
input_width,
static_cast<int>(k),
gridx,
input_height,
largest));
#else
FIXED_BLOCK_DIM(ops::KeMatrixTopK<
T,
5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
trans_out.data<T>(),
k,
trans_ind.data<int64_t>(),
trans_input.data<T>(),
input_width,
input_width,
static_cast<int>(k),
gridx,
input_height,
largest));
#endif
default:
PADDLE_THROW(errors::Fatal(
"the input data shape has error in the topk cuda kernel."));
}
// last step, tranpose back the indices and output
funcs::TransCompute<phi::GPUContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
funcs::TransCompute<phi::GPUContext, T>(
ndims, dev_ctx, trans_out, out, trans);
}
}
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
} // namespace phi
PD_REGISTER_KERNEL(top_k,
GPU,
ALL_LAYOUT,
phi::TopkKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TopkGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& indices,
int k,
int axis,
bool largest,
bool sorted,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void TopkKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& k_scalar,
int axis,
bool largest,
bool sorted,
DenseTensor* out,
DenseTensor* indices);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TopkOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("K")) {
return KernelSignature(
"top_k", {"X"}, {"K", "axis", "largest", "sorted"}, {"Out", "Indices"});
} else {
return KernelSignature(
"top_k", {"X"}, {"k", "axis", "largest", "sorted"}, {"Out", "Indices"});
}
}
KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("top_k_grad",
{GradVarName("Out"), "X", "Indices"},
{"k", "axis", "largest", "sorted"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(top_k_v2, top_k);
PD_REGISTER_BASE_KERNEL_NAME(top_k_v2_grad, top_k_grad);
PD_REGISTER_ARG_MAPPING_FN(top_k_v2, phi::TopkOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(top_k_v2_grad, phi::TopkGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册