Unverified commit 286eca2d authored by W wawltor, committed by GitHub

Update the code for top_k v2

Add top_k v2 for the PaddlePaddle 2.0 API
Parent f8238411
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdio.h>
#include <cstdio>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
// specialize cub's NumericTraits so the radix sort can handle float16
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t,
paddle::platform::float16> {};
} // namespace cub
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
struct SegmentOffsetIter {
EIGEN_DEVICE_FUNC
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
return idx * num_cols_;
}
int num_cols_;
};
// Iterator mapping a flattened index to its column
struct ColumnIndexIter {
explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
const Eigen::array<int, 1>& ix) const {
return ix[0] % num_cols_;
}
int num_cols_;
};
inline static int GetDesiredBlockDim(int dim) {
if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
}
template <typename T>
__global__ void InitIndex(T* indices, T num_rows, T num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (int64_t j = row_id; j < num_rows; j += gridDim.x) {
for (int64_t i = col_id; i < num_cols; i += blockDim.x) {
indices[j * num_cols + i] = i;
}
}
}
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}
__device__ __forceinline__ void set(T value, int64_t id) {
v = value;
this->id = id;
}
__device__ __forceinline__ void operator=(const Pair<T>& in) {
v = in.v;
id = in.id;
}
__device__ __forceinline__ bool operator<(const T value) const {
return (v < value);
}
__device__ __forceinline__ bool operator>(const T value) const {
return (v > value);
}
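// For equal values, the entry with the smaller id compares as the larger one.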
__device__ __forceinline__ bool operator<(const Pair<T>& in) const {
return (v < in.v) || ((v == in.v) && (id > in.id));
}
__device__ __forceinline__ bool operator>(const Pair<T>& in) const {
return (v > in.v) || ((v == in.v) && (id < in.id));
}
T v;
int64_t id;
};
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p,
int beam_size, const bool& largest) {
for (int k = beam_size - 2; k >= 0; k--) {
if (largest) {
if (topk[k] < p) {
topk[k + 1] = topk[k];
} else {
topk[k + 1] = p;
return;
}
} else {
if (topk[k] > p) {
topk[k + 1] = topk[k];
} else {
topk[k + 1] = p;
return;
}
}
}
topk[0] = p;
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
int dim, int beam_size,
const bool& largest) {
while (idx < dim) {
if (largest) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
AddTo<T>(topk, tmp, beam_size, largest);
}
} else {
if (topk[beam_size - 1] > src[idx]) {
Pair<T> tmp(src[idx], idx);
AddTo<T>(topk, tmp, beam_size, largest);
}
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
int dim, const Pair<T>& max,
int beam_size, const bool& largest) {
while (idx < dim) {
if (largest) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
if (tmp < max) {
AddTo<T>(topk, tmp, beam_size, largest);
}
}
} else {
if (topk[beam_size - 1] > src[idx]) {
Pair<T> tmp(src[idx], idx);
if (tmp > max) {
AddTo<T>(topk, tmp, beam_size, largest);
}
}
}
idx += BlockSize;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
int beam_size, const T* src,
bool* firstStep, bool* is_empty,
Pair<T>* max, int dim,
const int tid, bool largest) {
if (*beam > 0) {
int length = (*beam) < beam_size ? *beam : beam_size;
if (*firstStep) {
*firstStep = false;
GetTopK<T, BlockSize>(topk, src, tid, dim, length, largest);
} else {
for (int k = 0; k < MaxLength; k++) {
if (k < MaxLength - (*beam)) {
topk[k] = topk[k + *beam];
} else {
if (largest) {
topk[k].set(-static_cast<T>(INFINITY), -1);
} else {
// for smallest-k, refill with +INFINITY so stale slots never win
topk[k].set(static_cast<T>(INFINITY), -1);
}
}
}
if (!(*is_empty)) {
GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
length, largest);
}
}
*max = topk[MaxLength - 1];
if ((*max).v == -static_cast<T>(1)) *is_empty = true;
*beam = 0;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
Pair<T> topk[], T** topVal,
int64_t** topIds, int* beam, int* k,
const int tid, const int warp,
const bool& largest) {
while (true) {
__syncthreads();
if (tid < BlockSize / 2) {
if (largest) {
if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
maxid[tid] = tid + BlockSize / 2;
} else {
maxid[tid] = tid;
}
} else {
if (sh_topk[tid] > sh_topk[tid + BlockSize / 2]) {
maxid[tid] = tid + BlockSize / 2;
} else {
maxid[tid] = tid;
}
}
}
__syncthreads();
for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
if (tid < stride) {
if (largest) {
if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
maxid[tid] = maxid[tid + stride];
}
} else {
if (sh_topk[maxid[tid]] > sh_topk[maxid[tid + stride]]) {
maxid[tid] = maxid[tid + stride];
}
}
}
__syncthreads();
}
__syncthreads();
if (tid == 0) {
**topVal = sh_topk[maxid[0]].v;
**topIds = sh_topk[maxid[0]].id;
(*topVal)++;
(*topIds)++;
}
if (tid == maxid[0]) (*beam)++;
if (--(*k) == 0) break;
__syncthreads();
if (tid == maxid[0]) {
if (*beam < MaxLength) {
sh_topk[tid] = topk[*beam];
}
}
// NOTE(zcd): temporary solution
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (maxid[0] / 32 == warp) {
if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) ==
MaxLength)
break;
}
}
}
/**
 * Each block computes one sample.
 * In a block:
 * 1. every thread gets the top MaxLength values;
 * 2. merge them into sh_topk, block-reduce and get the max value;
 * 3. go back to the second step until one thread's topk values are exhausted;
 * 4. go back to the first step until all k values are found.
 */
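// A minimal launch sketch (illustrative only, with assumed names): for a
// row-major [num, dim] matrix, one block per row, MaxLength = 5 and
// BlockSize = 256:
//   KeMatrixTopK<float, 5, 256><<<grid_dim, 256, 0, stream>>>(
//       output, /*output_stride=*/k, indices, src, /*lds=*/dim, dim, k,
//       grid_dim, num, /*largest=*/true);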
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k,
int grid_dim, int num, bool largest = true) {
__shared__ Pair<T> sh_topk[BlockSize];
const int tid = threadIdx.x;
const int warp = threadIdx.x / 32;
const int bid = blockIdx.x;
for (int i = bid; i < num; i += grid_dim) {
int top_num = k;
__shared__ int maxid[BlockSize / 2];
T* out = output + i * output_stride;
int64_t* inds = indices + i * k;
Pair<T> topk[MaxLength];
int beam = MaxLength;
Pair<T> max;
bool is_empty = false;
bool firststep = true;
for (int j = 0; j < MaxLength; j++) {
if (largest) {
topk[j].set(-static_cast<T>(INFINITY), -1);
} else {
topk[j].set(static_cast<T>(INFINITY), -1);
}
}
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, k, src + i * lds,
&firststep, &is_empty, &max, dim,
tid, largest);
sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
&beam, &top_num, tid, warp, largest);
}
}
}
template <typename T, int MaxLength, int BlockSize>
__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
size_t rows, size_t cols, size_t k) {
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
x_grad[i * cols + j] = 0;
}
for (size_t j = 0; j < k; ++j) {
size_t idx = indices[i * k + j];
x_grad[i * cols + idx] = out_grad[i * k + j];
}
}
}
// assign the grad along the given axis
template <typename T>
__global__ void AssignGradWithAxis(const T* grad_out, const int64_t* indices,
T* grad_in, int pre, int post,
int raw_height, int k) {
// raw_height is the length of topk axis
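// e.g. (sketch) with pre = 2, post = 3, raw_height = 4 and k = 2, element
// (i, t, j) of the [pre, k, post] grad_out is scattered to position
// (i, indices[i, t, j], j) of the [pre, raw_height, post] grad_in.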
for (int i = blockIdx.x; i < pre; i += gridDim.x) {
const int base_index = i * post * k;
const int base_grad = i * post * raw_height;
for (int j = threadIdx.x; j < raw_height * post; j += blockDim.x) {
grad_in[base_grad + j] = static_cast<T>(0);
}
// wait until the whole row is zeroed before scattering the gradients
__syncthreads();
for (int j = threadIdx.x; j < k * post; j += blockDim.x) {
const int64_t idx_ij = indices[base_index + j];
const int64_t in_ij = base_grad + (idx_ij * post) + (j % post);
grad_in[in_ij] = grad_out[base_index + j];
}
}
}
}
// use the radix sort for the topk
template <typename T>
bool SortTopk(const platform::CUDADeviceContext& ctx,
const framework::Tensor* input_tensor, const int64_t num_cols,
const int64_t num_rows, const int k,
framework::Tensor* out_tensor, framework::Tensor* indices_tensor,
bool largest = true) {
auto cu_stream = ctx.stream();
Tensor input_indices;
const std::vector<int64_t> dims = {num_rows, num_cols};
auto dim = framework::make_ddim(dims);
input_indices.Resize(dim);
input_indices.mutable_data<int64_t>(ctx.GetPlace());
size_t temp_storage_bytes = -1;
auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};
int block_size = ComputeBlockSize(num_cols);
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
// in practice num_rows is smaller than the max grid size
unsigned int grid_size = num_rows < maxGridDimX
? static_cast<unsigned int>(num_rows)
: maxGridDimX;
// Initialize the index array
InitIndex<int64_t><<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols);
// create a counting iterator over the input
cub::CountingInputIterator<int64_t> counting_iter(0);
// segment_offsets_t is used to move to the next row
cub::TransformInputIterator<int64_t, SegmentOffsetIter,
cub::CountingInputIterator<int64_t>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
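// e.g. with num_cols = 3, segment_offsets_t yields {0, 3, 6, ...}, so segment
// i covers the flattened range [3 * i, 3 * (i + 1)), i.e. one row.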
T* sorted_values_ptr;
int64_t* sorted_indices_ptr;
Tensor temp_values;
Tensor temp_indices;
const T* input = input_tensor->data<T>();
T* values = out_tensor->data<T>();
int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());
if (k == num_cols) {
// Doing a full sort.
sorted_values_ptr = values;
sorted_indices_ptr = indices;
} else {
temp_values.Resize(dim);
temp_indices.Resize(dim);
sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
}
// Query the temp storage size; allocating a fixed buffer up front could save
// time.
if (largest) {
auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr, temp_storage_bytes, input, sorted_values_ptr,
input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
if (err != cudaSuccess) {
LOG(ERROR)
<< "TopKOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
"temp_storage_bytes, status: "
<< cudaGetErrorString(err);
return false;
}
} else {
auto err = cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, input, sorted_values_ptr,
input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
if (err != cudaSuccess) {
LOG(ERROR) << "TopKOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairs to calculate "
"temp_storage_bytes, status: "
<< cudaGetErrorString(err);
return false;
}
}
Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
if (largest) {
auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage.data<uint8_t>(), temp_storage_bytes, input,
sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
0, sizeof(T) * 8, cu_stream);
if (err != cudaSuccess) {
LOG(ERROR) << "TopKOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to "
"sort input, "
"temp_storage_bytes: "
<< temp_storage_bytes
<< ", status: " << cudaGetErrorString(err);
return false;
}
} else {
auto err = cub::DeviceSegmentedRadixSort::SortPairs(
temp_storage.data<uint8_t>(), temp_storage_bytes, input,
sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
0, sizeof(T) * 8, cu_stream);
if (err != cudaSuccess) {
LOG(ERROR) << "TopKOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairs to "
"sort input, "
"temp_storage_bytes: "
<< temp_storage_bytes
<< ", status: " << cudaGetErrorString(err);
return false;
}
}
auto& dev = *ctx.eigen_device();
if (k < num_cols) {
// copy sliced data to output.
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
auto e_indices = EigenMatrix<int64_t>::From(*indices_tensor, dim);
auto e_tmp_indices = EigenMatrix<int64_t>::From(temp_indices);
std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
auto dim = framework::make_ddim(odims);
auto e_values = EigenMatrix<T>::From(*out_tensor, dim);
auto e_tmp_values = EigenMatrix<T>::From(temp_values);
e_indices.device(dev) = e_tmp_indices.slice(slice_indices, slice_sizes);
e_values.device(dev) = e_tmp_values.slice(slice_indices, slice_sizes);
}
return true;
}
} // namespace operators
} // namespace paddle
@@ -12,474 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstdio>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
// specialize cub's NumericTraits so the radix sort can handle float16
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t,
paddle::platform::float16> {};
} // namespace cub
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}
__device__ __forceinline__ void set(T value, int64_t id) {
v = value;
this->id = id;
}
__device__ __forceinline__ void operator=(const Pair<T>& in) {
v = in.v;
id = in.id;
}
__device__ __forceinline__ bool operator<(const T value) const {
return (v < value);
}
__device__ __forceinline__ bool operator<(const Pair<T>& in) const {
return (v < in.v) || ((v == in.v) && (id > in.id));
}
__device__ __forceinline__ bool operator>(const Pair<T>& in) const {
return (v > in.v) || ((v == in.v) && (id < in.id));
}
T v;
int64_t id;
};
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p,
int beam_size) {
for (int k = beam_size - 2; k >= 0; k--) {
if (topk[k] < p) {
topk[k + 1] = topk[k];
} else {
topk[k + 1] = p;
return;
}
}
topk[0] = p;
}
template <typename T, int beam_size>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p) {
for (int k = beam_size - 2; k >= 0; k--) {
if (topk[k] < p) {
topk[k + 1] = topk[k];
} else {
topk[k + 1] = p;
return;
}
}
topk[0] = p;
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
int dim, int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
AddTo<T>(topk, tmp, beam_size);
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
int dim, const Pair<T>& max,
int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
if (tmp < max) {
AddTo<T>(topk, tmp, beam_size);
}
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
int idx, int dim, int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < val[idx]) {
Pair<T> tmp(val[idx], col[idx]);
AddTo<T>(topk, tmp, beam_size);
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
int idx, int dim, const Pair<T>& max,
int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < val[idx]) {
Pair<T> tmp(val[idx], col[idx]);
if (tmp < max) {
AddTo<T>(topk, tmp, beam_size);
}
}
idx += BlockSize;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
int beam_size, const T* src,
bool* firstStep, bool* is_empty,
Pair<T>* max, int dim,
const int tid) {
if (*beam > 0) {
int length = (*beam) < beam_size ? *beam : beam_size;
if (*firstStep) {
*firstStep = false;
GetTopK<T, BlockSize>(topk, src, tid, dim, length);
} else {
for (int k = 0; k < MaxLength; k++) {
if (k < MaxLength - (*beam)) {
topk[k] = topk[k + *beam];
} else {
topk[k].set(-static_cast<T>(INFINITY), -1);
}
}
if (!(*is_empty)) {
GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
length);
}
}
*max = topk[MaxLength - 1];
if ((*max).v == -static_cast<T>(1)) *is_empty = true;
*beam = 0;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
int beam_size, const T* val,
int* col, bool* firstStep,
bool* is_empty, Pair<T>* max,
int dim, const int tid) {
if (*beam > 0) {
int length = (*beam) < beam_size ? *beam : beam_size;
if (*firstStep) {
*firstStep = false;
GetTopK<T, BlockSize>(topk, val, col, tid, dim, length);
} else {
for (int k = 0; k < MaxLength; k++) {
if (k < MaxLength - *beam) {
topk[k] = topk[k + *beam];
} else {
topk[k].set(-static_cast<T>(INFINITY), -1);
}
}
if (!(*is_empty)) {
GetTopK<T, BlockSize>(topk + MaxLength - *beam, val, col, tid, dim, *max,
length);
}
}
*max = topk[MaxLength - 1];
if ((*max).v == -1) *is_empty = true;
*beam = 0;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
Pair<T> topk[], T** topVal,
int64_t** topIds, int* beam, int* k,
const int tid, const int warp) {
while (true) {
__syncthreads();
if (tid < BlockSize / 2) {
if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
maxid[tid] = tid + BlockSize / 2;
} else {
maxid[tid] = tid;
}
}
__syncthreads();
for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
if (tid < stride) {
if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
maxid[tid] = maxid[tid + stride];
}
}
__syncthreads();
}
__syncthreads();
if (tid == 0) {
**topVal = sh_topk[maxid[0]].v;
**topIds = sh_topk[maxid[0]].id;
(*topVal)++;
(*topIds)++;
}
if (tid == maxid[0]) (*beam)++;
if (--(*k) == 0) break;
__syncthreads();
if (tid == maxid[0]) {
if (*beam < MaxLength) {
sh_topk[tid] = topk[*beam];
}
}
// NOTE(zcd): temporary solution
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (maxid[0] / 32 == warp) {
if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) ==
MaxLength)
break;
}
}
}
/**
 * Each block computes one sample.
 * In a block:
 * 1. every thread gets the top MaxLength values;
 * 2. merge them into sh_topk, block-reduce and get the max value;
 * 3. go back to the second step until one thread's topk values are exhausted;
 * 4. go back to the first step until all k values are found.
 */
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k,
int grid_dim, int num) {
__shared__ Pair<T> sh_topk[BlockSize];
const int tid = threadIdx.x;
const int warp = threadIdx.x / 32;
const int bid = blockIdx.x;
for (int i = bid; i < num; i += grid_dim) {
int top_num = k;
__shared__ int maxid[BlockSize / 2];
T* out = output + i * output_stride;
int64_t* inds = indices + i * k;
Pair<T> topk[MaxLength];
int beam = MaxLength;
Pair<T> max;
bool is_empty = false;
bool firststep = true;
for (int j = 0; j < MaxLength; j++) {
topk[j].set(-static_cast<T>(INFINITY), -1);
}
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(
topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
&beam, &top_num, tid, warp);
}
}
}
template <typename T, int MaxLength, int BlockSize>
__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
size_t rows, size_t cols, size_t k) {
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
x_grad[i * cols + j] = 0;
}
for (size_t j = 0; j < k; ++j) {
size_t idx = indices[i * k + j];
x_grad[i * cols + idx] = out_grad[i * k + j];
}
}
}
inline static int GetDesiredBlockDim(int dim) {
if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
}
// Iterator computing the offset of each row (segment)
struct SegmentOffsetIter {
EIGEN_DEVICE_FUNC
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
return idx * num_cols_;
}
int num_cols_;
};
// Iterator mapping a flattened index to its column
struct ColumnIndexIter {
explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
const Eigen::array<int, 1>& ix) const {
return ix[0] % num_cols_;
}
int num_cols_;
};
__global__ void InitIndex(int64_t* indices, int64_t num_rows,
int64_t num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (int64_t j = row_id; j < num_rows; j += gridDim.x) {
for (int64_t i = col_id; i < num_cols; i += blockDim.x) {
indices[j * num_cols + i] = i;
}
}
}
template <typename T>
bool SortTopk(const platform::CUDADeviceContext& ctx,
const framework::Tensor* input_tensor, const int64_t num_cols,
const int64_t num_rows, const int k,
framework::Tensor* out_tensor,
framework::Tensor* indices_tensor) {
auto cu_stream = ctx.stream();
Tensor input_indices;
const std::vector<int64_t> dims = {num_rows, num_cols};
auto dim = framework::make_ddim(dims);
input_indices.Resize(dim);
input_indices.mutable_data<int64_t>(ctx.GetPlace());
size_t temp_storage_bytes = -1;
auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};
int block_size = ComputeBlockSize(num_cols);
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
// in practice num_rows is smaller than the max grid size
unsigned int grid_size = num_rows < maxGridDimX
? static_cast<unsigned int>(num_rows)
: maxGridDimX;
// Initialize the index array
InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols);
// create a counting iterator over the input
cub::CountingInputIterator<int64_t> counting_iter(0);
// segment_offsets_t is used to move to the next row
cub::TransformInputIterator<int64_t, SegmentOffsetIter,
cub::CountingInputIterator<int64_t>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
T* sorted_values_ptr;
int64_t* sorted_indices_ptr;
Tensor temp_values;
Tensor temp_indices;
const T* input = input_tensor->data<T>();
T* values = out_tensor->data<T>();
int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());
if (k == num_cols) {
// Doing a full sort.
sorted_values_ptr = values;
sorted_indices_ptr = indices;
} else {
temp_values.Resize(dim);
temp_indices.Resize(dim);
sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
}
// Query the temp storage size; allocating a fixed buffer up front could save
// time.
auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr, temp_storage_bytes, input, sorted_values_ptr,
input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
if (err != cudaSuccess) {
LOG(ERROR)
<< "TopKOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
"temp_storage_bytes, status: "
<< cudaGetErrorString(err);
return false;
}
Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage.data<uint8_t>(), temp_storage_bytes, input,
sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
0, sizeof(T) * 8, cu_stream);
if (err != cudaSuccess) {
LOG(ERROR)
<< "TopKOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
"temp_storage_bytes: "
<< temp_storage_bytes << ", status: " << cudaGetErrorString(err);
return false;
}
auto& dev = *ctx.eigen_device();
if (k < num_cols) {
// copy sliced data to output.
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
auto e_indices = EigenMatrix<int64_t>::From(*indices_tensor, dim);
auto e_tmp_indices = EigenMatrix<int64_t>::From(temp_indices);
std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
auto dim = framework::make_ddim(odims);
auto e_values = EigenMatrix<T>::From(*out_tensor, dim);
auto e_tmp_values = EigenMatrix<T>::From(temp_values);
e_indices.device(dev) = e_tmp_indices.slice(slice_indices, slice_sizes);
e_values.device(dev) = e_tmp_values.slice(slice_indices, slice_sizes);
}
return true;
}
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
@@ -523,7 +70,6 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const int64_t input_width = inputdims[inputdims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
if ((input_width <= 1024 || k >= 128 || k == input_width)) {
if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
indices)) {
@@ -576,7 +122,6 @@ class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
const size_t col = xdims[xdims.size() - 1];
const auto& dev_ctx = context.cuda_device_context();
const int kMaxHeight = 2048;
int gridx = row < kMaxHeight ? row : kMaxHeight;
switch (GetDesiredBlockDim(col)) {
@@ -595,7 +140,6 @@ class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
top_k,
paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/top_k_v2_op.h"
#include <memory>
namespace paddle {
namespace operators {
class TopkV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TopkV2Op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of TopkV2Op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Indices"),
"Output(Indices) of TopkV2Op should not be null.");
auto input_dims = ctx->GetInputDim("X");
const int dim_size = input_dims.size();
const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true,
"the axis of topk must be in [-%d, %d), "
"but the received axis is %d",
dim_size, dim_size, axis);
if (axis < 0) axis += dim_size;
PADDLE_ENFORCE_GE(
k, 1, "the attribute k of topk must be >= 1, but received %d.", k);
PADDLE_ENFORCE_GE(input_dims.size(), 1,
"the input of topk must have at least 1 dimension");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
input_dims[axis], k,
"the input of topk must have at least %d columns along axis %d", k, axis);
}
framework::DDim dims = input_dims;
dims[axis] = k;
ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(),
layout_, library_);
}
};
class TopkV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input of Topk op");
AddInput("K",
"(Tensor) Number of top elements to look for along "
"the last dimension (along each row for matrices).")
.AsDispensable();
AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
AddComment(R"DOC(
Top K operator
If the input is a vector (1d tensor), this operator finds the k largest
entries in the vector and outputs their values and indices as vectors.
Thus values[j] is the j-th largest entry in input, and its index is indices[j].
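For example (an illustrative case), given input = [1, 3, 2] and k = 2, the
output values are [3, 2] and the indices are [1, 2].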
For matrices, this operator computes the top k entries in each row. )DOC");
AddAttr<int>("k",
"(int, default 1) Number of top elements to look for along "
"the tensor).")
.SetDefault(1);
AddAttr<int>("axis",
"the axis to sort and get the k indices, value."
"if not set, will get k value in last axis.")
.SetDefault(-1);
AddAttr<bool>("largest",
"control flag whether to return largest or smallest")
.SetDefault(true);
AddAttr<bool>("sorted",
"control flag whether to return elements in sorted order")
.SetDefault(true);
}
};
class TopkV2OpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Indices"), true,
platform::errors::InvalidArgument("Input(Indices) should be not null"));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Grad Input(Out) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument("Grad Output(X) should be not null"));
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class TopkV2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("top_k_v2_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
op->SetInput("Indices", this->Output("Indices"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(top_k_v2, ops::TopkV2Op, ops::TopkV2OpMaker,
ops::TopkV2GradOpMaker<paddle::framework::OpDesc>,
ops::TopkV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(top_k_v2_grad, ops::TopkV2OpGrad);
REGISTER_OP_CPU_KERNEL(top_k_v2,
ops::TopkV2Kernel<paddle::platform::CPUPlace, float>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, double>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, int32_t>,
ops::TopkV2Kernel<paddle::platform::CPUPlace, int64_t>)
REGISTER_OP_CPU_KERNEL(
top_k_v2_grad, ops::TopkV2GradKernel<paddle::platform::CPUPlace, float>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, double>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, int32_t>,
ops::TopkV2GradKernel<paddle::platform::CPUPlace, int64_t>)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/p_norm_op.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template <typename DeviceContext, typename T>
class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
// get the attributes
int k = static_cast<int>(ctx.Attr<int>("k"));
int axis = static_cast<int>(ctx.Attr<int>("axis"));
const bool& sorted = static_cast<bool>(ctx.Attr<bool>("sorted"));
const bool& largest = static_cast<bool>(ctx.Attr<bool>("largest"));
// get the input dims
const auto& in_dims = input->dims();
// calculate the real axis
if (axis < 0) axis += in_dims.size();
auto* k_t = ctx.Input<Tensor>("K");
if (k_t) {
Tensor k_host;
framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host);
k = k_host.data<int>()[0];
framework::DDim output_dims = output->dims();
output_dims[axis] = k;
output->Resize(output_dims);
indices->Resize(output_dims);
}
const auto& out_dims = output->dims();
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
if (axis == in_dims.size() - 1) {
// get the topk from the last axis
const int64_t& input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
if (k > input_width) k = input_width;
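// Heuristic used below: take the cub-based sort path for narrow inputs or
// large k; otherwise fall back to the per-row topk kernel.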
if ((input_width <= 1024 || k >= 128 || k == input_width)) {
if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
indices, largest)) {
// Succeeded, return.
return;
} else {
LOG(INFO) << "TopKOP: cub sorting failed, falling back to the "
"default topk kernel.";
}
}
// NOTE: pass lds and dim the same as the input width.
// NOTE: the old matrix implementation's stride differs from eigen's.
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (GetDesiredBlockDim(input_width)) {
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, k, indices_data, input_data, input_width,
input_width, static_cast<int>(k), gridx, input_height,
largest));
default:
PADDLE_THROW(platform::errors::Fatal(
"the input data shape has error in the topk cuda kernel."));
}
} else {
// if the topk is not along the last axis, transpose the tensor and then
// get the topk
// first step, prepare the permutation for the transpose
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
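// e.g. for a rank-3 input with axis = 1, trans = {0, 2, 1}: the topk axis is
// swapped with the last axis.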
framework::DDim trans_dims(in_dims);
framework::DDim trans_out_dims(output->dims());
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
trans_out_dims[i] = out_dims[trans[i]];
}
// second step, transpose the input
Tensor trans_input;
trans_input.mutable_data<T>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
const auto& dev_ctx = ctx.cuda_device_context();
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
&trans_input, trans);
// third step, calculate the topk
// allocate temporary cuda memory for the intermediate result
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_out_dims, ctx.GetPlace());
Tensor trans_out;
trans_out.mutable_data<T>(trans_out_dims, ctx.GetPlace());
const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
if (k > input_width) k = input_width;
if ((input_width <= 1024 || k >= 128 || k == input_width)) {
if (SortTopk<T>(dev_ctx, &trans_input, input_width, input_height, k,
&trans_out, &trans_ind, largest)) {
// last step, transpose back the indices and output
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
TransCompute<platform::CUDADeviceContext, T>(
ndims, dev_ctx, trans_out, output, trans);
return;
} else {
LOG(INFO) << "TopKOP: cub sorting failed, falling back to the "
"default topk kernel.";
}
}
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
switch (GetDesiredBlockDim(input_width)) {
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
trans_out.data<T>(), k, trans_ind.data<int64_t>(),
trans_input.data<T>(), input_width, input_width,
static_cast<int>(k), gridx, input_height, largest));
default:
PADDLE_THROW(platform::errors::Fatal(
"the input data shape has error in the topk cuda kernel."));
}
// last step, transpose back the indices and output
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, trans_out,
output, trans);
}
}
};
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
template <typename DeviceContext, typename T>
class TopkV2OpGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
auto* x = context.Input<Tensor>("X");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<Tensor>("Indices");
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int axis = context.Attr<int>("axis");
const auto& in_dims = x->dims();
const auto& out_dims = indices->dims();
// get the real axis and k
if (axis < 0) axis += in_dims.size();
const int k = out_dims[axis];
const int raw_height = in_dims[axis];
// allocate the cuda memory for the x_grad
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
const T* out_grad_data = out_grad->data<T>();
const int64_t* indices_data = indices->data<int64_t>();
int pre, n, post;
GetDims(in_dims, axis, &pre, &n, &post);
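// pre and post are the products of the dims before and after the topk axis;
// n is the size of the topk axis itself.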
// calculate the block and grid sizes
auto& dev_ctx = context.cuda_device_context();
auto ComputeBlockSize = [](int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};
int block_size = ComputeBlockSize(post * k);
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
int grid_size = std::min(max_blocks, pre);
// launch the cuda kernel to assign the grad
AssignGradWithAxis<T><<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
out_grad_data, indices_data, x_grad_data, pre, post, n, k);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
top_k_v2,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::TopkV2OpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
top_k_v2_grad, paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, float>,
paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, double>,
paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, int>,
paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, int64_t>,
paddle::operators::TopkV2OpGradCUDAKernel<
paddle::platform::CUDADeviceContext, paddle::platform::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
The reason we need top_k_v2 is compatibility: in the v2 comparison, NaN is
redefined as the maximum value. If we changed top_k itself instead of adding
top_k_v2, it would affect the inference results of models trained with older
versions of PaddlePaddle.
*/
#pragma once
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/operators/transpose_op.h"
namespace paddle {
namespace operators {
template <typename T, typename Type>
static void FullTopK(Type input_height, Type input_width, int input_dim,
const framework::Tensor* input, T* t_out, Type* t_indices,
const int& k, const bool& largest, const bool& sorted) {
// when k is small, use partial sort
bool partial_sort_flag = (k * 64) < input_width;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
std::vector<std::pair<T, Type>> col_vec;
col_vec.reserve(input_width);
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
for (Type j = 0; j < input_width; ++j) {
col_vec.emplace_back(std::pair<T, Type>(e_input(j), j));
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
col_vec.emplace_back(std::pair<T, Type>(e_input(i, j), j));
}
}
if (partial_sort_flag) {
std::partial_sort(
col_vec.begin(), col_vec.begin() + k, col_vec.end(),
[&largest](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (largest) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
} else {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
}
});
} else {
// use nth_element to partition out the k largest or smallest elements
if (largest) {
std::nth_element(
col_vec.begin(), col_vec.begin() + k - 1, col_vec.end(),
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
});
// nth_element leaves the top-k elements unordered; sort them if required
if (sorted) {
std::sort(col_vec.begin(), col_vec.begin() + k - 1,
[&largest](const std::pair<T, Type>& l,
const std::pair<T, Type>& r) {
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
});
}
} else {
std::nth_element(
col_vec.begin(), col_vec.begin() + k - 1, col_vec.end(),
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
// nth_element leaves the top-k elements unordered; sort them if required
if (sorted) {
std::sort(
col_vec.begin(), col_vec.begin() + k - 1,
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
}
}
}
for (Type j = 0; j < k; ++j) {
t_out[i * k + j] = col_vec[j].first;
t_indices[i * k + j] = col_vec[j].second;
}
}
}
template <typename T, typename Type>
static void FullTopKAssign(const Type& input_height, const Type& input_width,
const int& input_dim, const framework::Tensor* input,
const framework::Tensor* indices, T* output_data,
const int& k) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
auto e_indices = EigenVector<Type>::Flatten(*indices);
for (Type j = 0; j < k; ++j) {
output_data[i * input_width + e_indices(j)] = e_input(j);
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
auto e_indices = EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
for (Type j = 0; j < k; ++j) {
output_data[i * input_width + e_indices(i, j)] = e_input(i, j);
}
}
}
}
template <typename DeviceContext, typename T>
class TopkV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// Get the top k elements of each row of input tensor
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
auto* indices = context.Output<Tensor>("Indices");
const auto& in_dims = input->dims();
int k = static_cast<int>(context.Attr<int>("k"));
const auto& sorted = static_cast<bool>(context.Attr<bool>("sorted"));
const auto& largest = static_cast<bool>(context.Attr<bool>("largest"));
// if axis < 0, calculate the real axis
int axis = static_cast<int>(context.Attr<int>("axis"));
if (axis < 0) axis += in_dims.size();
// if the K tensor is not null, use it as k
auto* k_t = context.Input<Tensor>("K");
if (k_t) {
k = k_t->data<int>()[0];
framework::DDim output_dims = output->dims();
// set the dim at axis to k according to the K tensor
output_dims[axis] = k;
output->Resize(output_dims);
indices->Resize(output_dims);
}
T* output_data = output->mutable_data<T>(context.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(context.GetPlace());
const auto& out_dims = output->dims();
if (axis + 1 == in_dims.size()) {
const int64_t& input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
FullTopK<T, int64_t>(input_height, input_width, in_dims.size(), input,
output_data, indices_data, k, largest, sorted);
} else {
// if the topk axis is not the last dim, transpose and then do topk
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
// get the trans input_dims, out_dims
framework::DDim trans_dims(in_dims);
framework::DDim trans_out_dims(output->dims());
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}
for (size_t i = 0; i < trans.size(); i++) {
trans_out_dims[i] = out_dims[trans[i]];
}
Tensor trans_inp;
trans_inp.mutable_data<T>(trans_dims, context.GetPlace());
int ndims = trans.size();
auto& dev_context =
context.template device_context<platform::CPUDeviceContext>();
// transpose the input value
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, *input,
&trans_inp, trans);
const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
// Allocate temp tensors to save the topk indices and values
Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_out_dims, context.GetPlace());
Tensor tmp_indices;
auto* t_ind =
tmp_indices.mutable_data<int64_t>(trans_out_dims, context.GetPlace());
// get the TopK value
FullTopK<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_inp, t_out, t_ind, k, largest, sorted);
// transpose back
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_context, tmp_indices, indices, trans);
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, tmp_out,
output, trans);
}
}
};
template <typename DeviceContext, typename T>
class TopkV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<Tensor>("Indices");
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int axis = static_cast<int>(context.Attr<int>("axis"));
const auto& in_dims = x->dims();
const auto& out_dims = indices->dims();
// if axis < 0, get the real axis
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
const size_t k = out_dims[axis];
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
if (axis + 1 == in_dims.size()) {
// allocate the memory for the input_grad
// assign the out_grad to input_grad directly
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
// init the output grad with 0, because some input elements have no grad
memset(x_grad_data, 0, x_grad->numel() * sizeof(T));
// Assign the output_grad to input_grad
FullTopKAssign(input_height, input_width, in_dims.size(), out_grad,
indices, x_grad_data, k);
} else {
// cannot assign the grad to input_grad directly; must transpose first
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(out_dims.size() - 1);
for (int i = axis + 1; i < out_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
framework::DDim trans_dims(out_dims);
framework::DDim trans_in_dims(in_dims);
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = out_dims[trans[i]];
trans_in_dims[i] = in_dims[trans[i]];
}
// transpose the out_grad, indices
Tensor trans_dO;
trans_dO.mutable_data<T>(trans_dims, context.GetPlace());
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, context.GetPlace());
int ndims = trans.size();
auto& dev_context =
context.template device_context<platform::CPUDeviceContext>();
// Do transpose
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, *out_grad,
&trans_dO, trans);
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_context, *indices, &trans_ind, trans);
const int64_t input_height = framework::product(
framework::slice_ddim(trans_in_dims, 0, trans_in_dims.size() - 1));
const int64_t input_width = trans_in_dims[trans_in_dims.size() - 1];
// Assign the out_grad to the transposed input_grad
Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_in_dims, context.GetPlace());
memset(t_out, 0, x_grad->numel() * sizeof(T));
FullTopKAssign<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_dO, &trans_ind, t_out, k);
// Transpose back
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, tmp_out,
x_grad, trans);
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
def numpy_topk(x, k=1, axis=-1, largest=True):
if axis < 0:
axis = len(x.shape) + axis
if largest:
indices = np.argsort(-x, axis=axis)
else:
indices = np.argsort(x, axis=axis)
if largest:
value = -np.sort(-x, axis=axis)
else:
value = np.sort(x, axis=axis)
indices = indices.take(indices=range(0, k), axis=axis)
value = value.take(indices=range(0, k), axis=axis)
return value, indices
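# A quick illustrative check of the reference implementation (assumed values,
# not part of the original tests):
#   numpy_topk(np.array([[1., 3., 2.]]), k=2, axis=-1)
#   -> (array([[3., 2.]]), array([[1, 2]]))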
class TestTopkOp(OpTest):
def init_args(self):
self.k = 3
self.axis = 1
self.largest = True
def setUp(self):
self.op_type = "top_k_v2"
self.dtype = np.float64
self.input_data = np.random.rand(10, 20)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=self.largest)
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
paddle.enable_static()
self.check_output()
def test_check_grad(self):
paddle.enable_static()
self.check_grad(set(['X']), 'Out')
class TestTopOp1(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 0
self.largest = True
class TestTopOp2(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 0
self.largest = False
class TestTopOp3(TestTopkOp):
def init_args(self):
self.k = 4
self.axis = 0
self.largest = False
class TestTopOp4(TestTopkOp):
def init_args(self):
self.k = 4
self.axis = 0
self.largest = False
class TestTopkOp5(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 1
self.largest = True
def setUp(self):
self.op_type = "top_k_v2"
self.dtype = np.float64
self.input_data = np.random.rand(10, 10, 5)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=self.largest)
self.outputs = {'Out': output, 'Indices': indices}
class TestTopkOp6(TestTopkOp):
def init_args(self):
self.k = 3
self.axis = 1
self.largest = True
def setUp(self):
self.op_type = "top_k_v2"
self.dtype = np.float64
self.input_data = np.random.rand(10, 10, 5)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=self.largest)
self.outputs = {'Out': output, 'Indices': indices}
class TestTopKAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
self.input_data = np.random.rand(6, 7, 8)
self.large_input_data = np.random.rand(2, 1030)
def run_dygraph(self, place):
paddle.disable_static(place)
input_tensor = paddle.to_tensor(self.input_data)
large_input_tensor = paddle.to_tensor(self.large_input_data)
# test case 1: default arguments
paddle_result = paddle.topk(input_tensor, k=2)
numpy_result = numpy_topk(self.input_data, k=2)
self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 2: explicit positive axis
paddle_result = paddle.topk(input_tensor, k=2, axis=1)
numpy_result = numpy_topk(self.input_data, k=2, axis=1)
self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 3: k given as a Tensor
k_tensor = paddle.to_tensor(np.array([2]))
paddle_result = paddle.topk(input_tensor, k=k_tensor, axis=1)
numpy_result = numpy_topk(self.input_data, k=2, axis=1)
self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 4: largest=False returns the k smallest values
        paddle_result = paddle.topk(input_tensor, k=2, axis=1, largest=False)
numpy_result = numpy_topk(self.input_data, k=2, axis=1, largest=False)
self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 5: negative axis
        paddle_result = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
numpy_result = numpy_topk(self.input_data, k=2, axis=-1, largest=False)
self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 6: large input that exercises the partial-sort path
paddle_result = paddle.topk(large_input_tensor, k=1, axis=-1)
numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
self.assertTrue(np.allclose(paddle_result[0].numpy(), numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[1].numpy(), numpy_result[1]))
        # case 7: sorted=False leaves the output order unspecified, so re-sort
        # the returned values before comparing against the reference
paddle_result = paddle.topk(input_tensor, k=2, axis=1, sorted=False)
sort_paddle = numpy_topk(
np.array(paddle_result[0].numpy()), axis=1, k=2)
numpy_result = numpy_topk(self.input_data, k=2, axis=1)
        self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))

    def run_static(self, place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
input_tensor = paddle.static.data(
name="x", shape=[6, 7, 8], dtype="float64")
large_input_tensor = paddle.static.data(
name="large_x", shape=[2, 1030], dtype="float64")
k_tensor = paddle.static.data(name="k", shape=[1], dtype="int32")
result1 = paddle.topk(input_tensor, k=2)
result2 = paddle.topk(input_tensor, k=2, axis=-1)
result3 = paddle.topk(input_tensor, k=k_tensor, axis=1)
result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False)
result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
result6 = paddle.topk(large_input_tensor, k=1, axis=-1)
result7 = paddle.topk(input_tensor, k=2, axis=1, sorted=False)
exe = paddle.static.Executor(place)
paddle_result = exe.run(
feed={
"x": self.input_data,
"large_x": self.large_input_data,
"k": np.array([2]).astype("int32")
},
fetch_list=[
result1[0], result1[1], result2[0], result2[1], result3[0],
result3[1], result4[0], result4[1], result5[0], result5[1],
result6[0], result6[1], result7[0], result7[1]
])
numpy_result = numpy_topk(self.input_data, k=2)
self.assertTrue(np.allclose(paddle_result[0], numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[1], numpy_result[1]))
numpy_result = numpy_topk(self.input_data, k=2, axis=-1)
self.assertTrue(np.allclose(paddle_result[2], numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[3], numpy_result[1]))
numpy_result = numpy_topk(self.input_data, k=2, axis=1)
self.assertTrue(np.allclose(paddle_result[4], numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[5], numpy_result[1]))
numpy_result = numpy_topk(
self.input_data, k=2, axis=1, largest=False)
self.assertTrue(np.allclose(paddle_result[6], numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[7], numpy_result[1]))
numpy_result = numpy_topk(
self.input_data, k=2, axis=-1, largest=False)
self.assertTrue(np.allclose(paddle_result[8], numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[9], numpy_result[1]))
numpy_result = numpy_topk(self.large_input_data, k=1, axis=-1)
self.assertTrue(np.allclose(paddle_result[10], numpy_result[0]))
self.assertTrue(np.allclose(paddle_result[11], numpy_result[1]))
sort_paddle = numpy_topk(paddle_result[12], axis=1, k=2)
numpy_result = numpy_topk(self.input_data, k=2, axis=1)
            self.assertTrue(np.allclose(sort_paddle[0], numpy_result[0]))

    def test_cases(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.run_dygraph(place)
            self.run_static(place)


if __name__ == "__main__":
unittest.main()
@@ -21,7 +21,6 @@ from ..fluid import core, layers
from ..fluid.layers import argmin #DEFINE_ALIAS
from ..fluid.layers import has_inf #DEFINE_ALIAS
from ..fluid.layers import has_nan #DEFINE_ALIAS
from ..fluid.layers import topk #DEFINE_ALIAS
__all__ = [
'argmax',
@@ -756,3 +755,100 @@ def masked_select(x, mask, name=None):
type='masked_select', inputs={'X': x,
'Mask': mask}, outputs={'Y': out})
    return out


def topk(x, k, axis=None, largest=True, sorted=True, name=None):
"""
This OP is used to find values and indices of the k largest or smallest at the optional axis.
If the input is a 1-D Tensor, finds the k largest or smallest values and indices.
If the input is a Tensor with higher rank, this operator computes the top k values and indices along the :attr:`axis`.
Args:
x(Tensor): Tensor, an input N-D Tensor with type float32, float64, int32, int64.
k(int, Tensor): The number of top elements to look for along the axis.
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
as axis + R. Default is -1.
largest(bool, optional) : largest is a flag, if set to true,
algorithm will sort by descending order, otherwise sort by
ascending order. Default is True.
sorted(bool, optional): controls whether to return the elements in sorted order, default value is True. In gpu device, it always return the sorted value.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
data_1 = np.array([1, 4, 5, 7])
tensor_1 = paddle.to_tensor(data_1)
value_1, indices_1 = paddle.topk(tensor_1, k=1)
print(value_1.numpy())
# [7]
print(indices_1.numpy())
# [3]
data_2 = np.array([[1, 4, 5, 7], [2, 6, 2, 5]])
tensor_2 = paddle.to_tensor(data_2)
value_2, indices_2 = paddle.topk(tensor_2, k=1)
print(value_2.numpy())
# [[7]
# [6]]
print(indices_2.numpy())
# [[3]
# [1]]
value_3, indices_3 = paddle.topk(tensor_2, k=1, axis=-1)
print(value_3.numpy())
# [[7]
# [6]]
print(indices_3.numpy())
# [[3]
# [1]]
value_4, indices_4 = paddle.topk(tensor_2, k=1, axis=0)
print(value_4.numpy())
# [[2 6 5 7]]
print(indices_4.numpy())
# [[1 1 0 0]]
"""
if in_dygraph_mode():
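        # In dygraph mode a Tensor k is materialized to a plain Python int so
        # it can be passed to the C++ op as an attribute.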
k = k.numpy().item(0) if isinstance(k, Variable) else k
if axis is None:
out, indices = core.ops.top_k_v2(x, 'k',
int(k), 'largest', largest,
'sorted', sorted)
else:
out, indices = core.ops.top_k_v2(x, 'k',
int(k), 'axis', axis, 'largest',
largest, 'sorted', sorted)
        return out, indices

    helper = LayerHelper("top_k_v2", **locals())
inputs = {"X": [x]}
attrs = {}
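    # In static mode, a Variable k must be fed at runtime through the "K"
    # input, while a plain int k can be baked in as the "k" attribute.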
if isinstance(k, Variable):
inputs['K'] = [k]
else:
attrs = {'k': k}
attrs['largest'] = largest
attrs['sorted'] = sorted
if axis is not None:
attrs['axis'] = axis
values = helper.create_variable_for_type_inference(dtype=x.dtype)
indices = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="top_k_v2",
inputs=inputs,
outputs={"Out": [values],
"Indices": [indices]},
attrs=attrs)
indices.stop_gradient = True
return values, indices
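

# Minimal usage sketch (illustrative, not part of this patch): the new API in
# dygraph mode, assuming a paddle build with the top_k_v2 kernel registered.
#
#   import numpy as np
#   import paddle
#   paddle.disable_static()
#   x = paddle.to_tensor(np.array([[1, 4, 5, 7], [2, 6, 2, 5]]))
#   values, indices = paddle.topk(x, k=2, axis=1)
#   # values.numpy()  -> [[7, 5], [6, 5]]
#   # indices.numpy() -> [[3, 2], [1, 3]]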