From c7379a7320c7af42b8e40b1293169dbdc39930c6 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Wed, 24 Oct 2018 14:28:29 +0800 Subject: [PATCH] Fix top_k op (#14034) 1. Fix CUDA kernel when height is large than 2048. 2. Support input with more than 2D. 3. Fix unit test when k is large than 1. 4. Enhence unit testing. test=develop --- paddle/fluid/operators/top_k_op.cu | 32 +++++----- paddle/fluid/operators/top_k_op.h | 5 +- .../fluid/tests/unittests/test_top_k_op.py | 64 +++++++++++++++---- 3 files changed, 70 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 8e4a07556fb..0cad224ca88 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, const T* src, int lds, int dim, int k, int grid_dim, int num) { __shared__ Pair sh_topk[BlockSize]; - __shared__ int maxid[BlockSize / 2]; const int tid = threadIdx.x; const int warp = threadIdx.x / 32; const int bid = blockIdx.x; for (int i = bid; i < num; i += grid_dim) { - output += i * output_stride; - indices += i * k; - + int top_num = k; + __shared__ int maxid[BlockSize / 2]; + T* out = output + i * output_stride; + int64_t* inds = indices + i * k; Pair topk[MaxLength]; int beam = MaxLength; Pair max; bool is_empty = false; bool firststep = true; - for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + for (int j = 0; j < MaxLength; j++) { + topk[j].set(-INFINITY, -1); } - while (k) { + while (top_num) { ThreadGetTopK( topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid); sh_topk[tid] = topk[0]; - BlockReduce(sh_topk, maxid, topk, &output, - &indices, &beam, &k, tid, warp); + BlockReduce(sh_topk, maxid, topk, &out, &inds, + &beam, &top_num, tid, warp); } } } @@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel { size_t k = static_cast(ctx.Attr("k")); const T* input_data = input->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); // FIXME(typhoonzero): data is always converted to type T? int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - size_t input_height = input->dims()[0]; - size_t input_width = input->dims()[1]; + framework::DDim inputdims = input->dims(); + const size_t input_height = framework::product( + framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); + const size_t input_width = inputdims[inputdims.size() - 1]; + if (k > input_width) k = input_width; // NOTE: pass lds and dim same to input width. @@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel { const int kMaxHeight = 2048; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; auto& dev_ctx = ctx.cuda_device_context(); - switch (GetDesiredBlockDim(input_width)) { FIXED_BLOCK_DIM( KeMatrixTopK<<>>( - output_data, output->dims()[1], indices_data, input_data, - input_width, input_width, static_cast(k), gridx, - input_height)); + output_data, k, indices_data, input_data, input_width, + input_width, static_cast(k), gridx, input_height)); default: PADDLE_THROW("Error"); } diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index 054dd481994..76ece57b399 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { // Get the top k elements of each row of input tensor - // FIXME: only deal with matrix(2d tensor). auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); @@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel { T* output_data = output->mutable_data(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - auto eg_input = EigenMatrix::From(*input); - // reshape input to a flattern matrix(like flat_inner_dims) framework::DDim inputdims = input->dims(); const size_t row = framework::product( @@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel { const size_t col = inputdims[inputdims.size() - 1]; Eigen::DSizes flat2dims(row, col); // NOTE: eigen shape doesn't affect paddle tensor. - eg_input.reshape(flat2dims); + auto eg_input = EigenMatrix::Reshape(*input, inputdims.size() - 1); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py index e54e170f7f1..69b29db83a4 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py @@ -21,22 +21,27 @@ from op_test import OpTest class TestTopkOp(OpTest): def setUp(self): + self.set_args() self.op_type = "top_k" - k = 1 - input = np.random.random((32, 84)).astype("float32") - output = np.ndarray((32, k)) - indices = np.ndarray((32, k)).astype("int64") + k = self.top_k + input = np.random.random((self.row, k)).astype("float32") + output = np.ndarray((self.row, k)) + indices = np.ndarray((self.row, k)).astype("int64") self.inputs = {'X': input} self.attrs = {'k': k} - for rowid in range(32): + for rowid in range(self.row): row = input[rowid] - output[rowid] = np.sort(row)[-k:] - indices[rowid] = row.argsort()[-k:] + output[rowid] = np.sort(row)[::-1][:k] + indices[rowid] = row.argsort()[::-1][:k] self.outputs = {'Out': output, 'Indices': indices} + def set_args(self): + self.row = 32 + self.top_k = 1 + def test_check_output(self): self.check_output() @@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest): output = np.ndarray((64, k)) indices = np.ndarray((64, k)).astype("int64") - # FIXME: should use 'X': input for a 3d input - self.inputs = {'X': input_flat_2d} + self.inputs = {'X': input} self.attrs = {'k': k} for rowid in range(64): row = input_flat_2d[rowid] - output[rowid] = np.sort(row)[-k:] - indices[rowid] = row.argsort()[-k:] + output[rowid] = np.sort(row)[::-1][:k] + indices[rowid] = row.argsort()[::-1][:k] + + self.outputs = { + 'Out': output.reshape((32, 2, k)), + 'Indices': indices.reshape((32, 2, k)) + } + + def test_check_output(self): + self.check_output() + + +class TestTopkOp2(OpTest): + def setUp(self): + self.op_type = "top_k" + k = 1 + m = 2056 + input = np.random.random((m, 84)).astype("float32") + output = np.ndarray((m, k)) + indices = np.ndarray((m, k)).astype("int64") + + self.inputs = {'X': input} + self.attrs = {'k': k} + + for rowid in range(m): + row = input[rowid] + output[rowid] = -np.sort(-row)[:k] + indices[rowid] = (-row).argsort()[:k] self.outputs = {'Out': output, 'Indices': indices} @@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest): self.check_output() +class TestTopkOp3(TestTopkOp): + def set_args(self): + self.row = 2056 + self.top_k = 3 + + +class TestTopkOp4(TestTopkOp): + def set_args(self): + self.row = 40000 + self.top_k = 1 + + if __name__ == "__main__": unittest.main() -- GitLab