Commit c7379a73 authored by qingqing01, committed by whs

Fix top_k op (#14034)

1. Fix CUDA kernel when height is larger than 2048.
2. Support inputs with more than 2 dimensions.
3. Fix unit test when k is larger than 1.
4. Enhance unit testing.

test=develop
Parent d3e52551
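Taken together, fixes 2 and 3 change the op's reference semantics to: flatten every leading dimension into rows, then take the k largest entries of each row in descending order. The following NumPy sketch is illustrative only and not part of the commit; the helper name topk_reference is made up:

    import numpy as np

    def topk_reference(x, k):
        # Flatten all leading dims into rows; the width is the last dim.
        rows = x.reshape(-1, x.shape[-1])
        # Indices of the k largest entries per row, in descending value order.
        idx = np.argsort(-rows, axis=1)[:, :k]
        vals = np.take_along_axis(rows, idx, axis=1)
        out_shape = x.shape[:-1] + (k,)
        return vals.reshape(out_shape), idx.reshape(out_shape)

    vals, idx = topk_reference(np.random.random((32, 2, 84)).astype("float32"), 3)
    print(vals.shape, idx.shape)  # (32, 2, 3) (32, 2, 3)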
@@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
                              const T* src, int lds, int dim, int k,
                              int grid_dim, int num) {
   __shared__ Pair<T> sh_topk[BlockSize];
-  __shared__ int maxid[BlockSize / 2];
   const int tid = threadIdx.x;
   const int warp = threadIdx.x / 32;
   const int bid = blockIdx.x;
   for (int i = bid; i < num; i += grid_dim) {
-    output += i * output_stride;
-    indices += i * k;
+    int top_num = k;
+    __shared__ int maxid[BlockSize / 2];
+    T* out = output + i * output_stride;
+    int64_t* inds = indices + i * k;
     Pair<T> topk[MaxLength];
     int beam = MaxLength;
     Pair<T> max;
     bool is_empty = false;
     bool firststep = true;
-    for (int k = 0; k < MaxLength; k++) {
-      topk[k].set(-INFINITY, -1);
+    for (int j = 0; j < MaxLength; j++) {
+      topk[j].set(-INFINITY, -1);
     }
-    while (k) {
+    while (top_num) {
       ThreadGetTopK<T, MaxLength, BlockSize>(
           topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
       sh_topk[tid] = topk[0];
-      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
-                                           &indices, &beam, &k, tid, warp);
+      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
+                                           &beam, &top_num, tid, warp);
     }
   }
 }
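Why the rewrite above matters: the old kernel reused its argument k as the while-loop countdown, and BlockReduce decrements that countdown as results are emitted. Under the grid-stride loop, a block processes several rows whenever num exceeds the grid size, so from the second row on k was already zero and nothing was written; the base pointers output/indices were also advanced cumulatively across iterations. The fix takes a fresh per-row countdown top_num and per-row pointers out/inds. A minimal Python model of the countdown bug, where max() stands in for one ThreadGetTopK/BlockReduce round and all names are illustrative:

    def process_rows_old(rows, k, bid, grid_dim):
        # Grid-stride loop: block `bid` handles rows bid, bid+grid_dim, ...
        found = []
        for i in range(bid, len(rows), grid_dim):
            while k:              # countdown shared across rows: the bug
                found.append(max(rows[i]))
                k -= 1            # stands in for BlockReduce decrementing *k
        return found

    def process_rows_new(rows, k, bid, grid_dim):
        found = []
        for i in range(bid, len(rows), grid_dim):
            top_num = k           # fresh countdown per row: the fix
            while top_num:
                found.append(max(rows[i]))
                top_num -= 1
        return found

    rows = [[1, 2], [3, 4], [5, 6]]
    print(process_rows_old(rows, 1, 0, 1))  # [2]: later rows never emit
    print(process_rows_new(rows, 1, 0, 1))  # [2, 4, 6]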
@@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     size_t k = static_cast<int>(ctx.Attr<int>("k"));
     const T* input_data = input->data<T>();
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     // FIXME(typhoonzero): data is always converted to type T?
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-    size_t input_height = input->dims()[0];
-    size_t input_width = input->dims()[1];
+    framework::DDim inputdims = input->dims();
+    const size_t input_height = framework::product(
+        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+    const size_t input_width = inputdims[inputdims.size() - 1];
     if (k > input_width) k = input_width;
     // NOTE: pass lds and dim same to input width.
@@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     const int kMaxHeight = 2048;
     int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
     auto& dev_ctx = ctx.cuda_device_context();
     switch (GetDesiredBlockDim(input_width)) {
       FIXED_BLOCK_DIM(
           KeMatrixTopK<T, 5,
                        kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
-              output_data, output->dims()[1], indices_data, input_data,
-              input_width, input_width, static_cast<int>(k), gridx,
-              input_height));
+              output_data, k, indices_data, input_data, input_width,
+              input_width, static_cast<int>(k), gridx, input_height));
       default:
         PADDLE_THROW("Error");
     }
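The launch above keeps gridx capped at kMaxHeight (2048) blocks; correctness for taller inputs relies on the kernel's grid-stride loop from the first hunk, under which each block sweeps every gridx-th row. A small Python check of that coverage, with illustrative numbers (40000 matches the new TestTopkOp4):

    kMaxHeight = 2048
    input_height = 40000                     # as in TestTopkOp4
    gridx = min(input_height, kMaxHeight)

    def rows_of_block(b):
        # Rows visited by block b under the kernel's grid-stride loop.
        return list(range(b, input_height, gridx))

    print(rows_of_block(0)[:3], len(rows_of_block(0)))  # [0, 2048, 4096] 20
    covered = sorted(r for b in range(gridx) for r in rows_of_block(b))
    assert covered == list(range(input_height))  # every row covered exactly once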
...
@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     // Get the top k elements of each row of input tensor
-    // FIXME: only deal with matrix(2d tensor).
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> {
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-    auto eg_input = EigenMatrix<T>::From(*input);
-
     // reshape input to a flattern matrix(like flat_inner_dims)
     framework::DDim inputdims = input->dims();
     const size_t row = framework::product(
@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> {
     const size_t col = inputdims[inputdims.size() - 1];
     Eigen::DSizes<int, 2> flat2dims(row, col);
     // NOTE: eigen shape doesn't affect paddle tensor.
-    eg_input.reshape(flat2dims);
+    auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
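The point of the last change above: like NumPy's reshape, Eigen's tensor reshape() returns a new expression and leaves its operand untouched, so the removed eg_input.reshape(flat2dims) discarded its result and did nothing, which is presumably why the 2-D view is now built directly with EigenMatrix<T>::Reshape. A NumPy analogy, illustrative rather than Paddle code:

    import numpy as np

    x = np.random.random((2, 3, 4))
    x.reshape(6, 4)                    # result discarded: x is unchanged,
    print(x.shape)                     # (2, 3, 4)   like eg_input.reshape(flat2dims)

    flat = x.reshape(-1, x.shape[-1])  # bind the flattened view instead,
    print(flat.shape)                  # (6, 4)      like EigenMatrix<T>::Reshape
    print(flat.base is x)              # True: a view over the same storage, no copy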
...
@@ -21,22 +21,27 @@ from op_test import OpTest
 class TestTopkOp(OpTest):
     def setUp(self):
+        self.set_args()
         self.op_type = "top_k"
-        k = 1
-        input = np.random.random((32, 84)).astype("float32")
-        output = np.ndarray((32, k))
-        indices = np.ndarray((32, k)).astype("int64")
+        k = self.top_k
+        input = np.random.random((self.row, k)).astype("float32")
+        output = np.ndarray((self.row, k))
+        indices = np.ndarray((self.row, k)).astype("int64")

         self.inputs = {'X': input}
         self.attrs = {'k': k}

-        for rowid in range(32):
+        for rowid in range(self.row):
             row = input[rowid]
-            output[rowid] = np.sort(row)[-k:]
-            indices[rowid] = row.argsort()[-k:]
+            output[rowid] = np.sort(row)[::-1][:k]
+            indices[rowid] = row.argsort()[::-1][:k]

         self.outputs = {'Out': output, 'Indices': indices}

+    def set_args(self):
+        self.row = 32
+        self.top_k = 1
+
     def test_check_output(self):
         self.check_output()
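The expected-value change above is fix 3 from the commit message: np.sort(row)[-k:] lists the k largest values in ascending order, while the op returns them in descending order. The two agree only at k = 1, which is why the old test never caught the mismatch. A quick demonstration with illustrative values:

    import numpy as np

    row = np.array([0.1, 0.9, 0.4, 0.7])
    k = 3
    print(np.sort(row)[-k:])       # [0.4 0.7 0.9]  old expectation: ascending
    print(np.sort(row)[::-1][:k])  # [0.9 0.7 0.4]  new expectation: matches the op
    print(-np.sort(-row)[:k])      # [0.9 0.7 0.4]  equivalent idiom in TestTopkOp2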
@@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest):
         output = np.ndarray((64, k))
         indices = np.ndarray((64, k)).astype("int64")

-        # FIXME: should use 'X': input for a 3d input
-        self.inputs = {'X': input_flat_2d}
+        self.inputs = {'X': input}
         self.attrs = {'k': k}

         for rowid in range(64):
             row = input_flat_2d[rowid]
-            output[rowid] = np.sort(row)[-k:]
-            indices[rowid] = row.argsort()[-k:]
+            output[rowid] = np.sort(row)[::-1][:k]
+            indices[rowid] = row.argsort()[::-1][:k]

-        self.outputs = {'Out': output, 'Indices': indices}
+        self.outputs = {
+            'Out': output.reshape((32, 2, k)),
+            'Indices': indices.reshape((32, 2, k))
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestTopkOp2(OpTest):
+    def setUp(self):
+        self.op_type = "top_k"
+        k = 1
+        m = 2056
+        input = np.random.random((m, 84)).astype("float32")
+        output = np.ndarray((m, k))
+        indices = np.ndarray((m, k)).astype("int64")
+
+        self.inputs = {'X': input}
+        self.attrs = {'k': k}
+
+        for rowid in range(m):
+            row = input[rowid]
+            output[rowid] = -np.sort(-row)[:k]
+            indices[rowid] = (-row).argsort()[:k]
+
         self.outputs = {'Out': output, 'Indices': indices}
@@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest):
         self.check_output()


+class TestTopkOp3(TestTopkOp):
+    def set_args(self):
+        self.row = 2056
+        self.top_k = 3
+
+
+class TestTopkOp4(TestTopkOp):
+    def set_args(self):
+        self.row = 40000
+        self.top_k = 1
+
+
 if __name__ == "__main__":
     unittest.main()