diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index f7bac67300bd56b89d5b08238e78d625f4a773a6..6b9db260605c13af97325366c217e81fdfae08c7 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -29,6 +29,10 @@ template using EigenMatrix = framework::EigenMatrix; +template +using EigenVector = framework::EigenVector; + template class TopkKernel : public framework::OpKernel { public: @@ -57,17 +61,24 @@ class TopkKernel : public framework::OpKernel { framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); const size_t col = inputdims[inputdims.size() - 1]; Eigen::DSizes flat2dims(row, col); - // NOTE: eigen shape doesn't affect paddle tensor. - auto eg_input = EigenMatrix::Reshape(*input, inputdims.size() - 1); - +// NOTE: eigen shape doesn't affect paddle tensor. #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif for (size_t i = 0; i < row; i++) { std::vector> vec; vec.reserve(col); - for (size_t j = 0; j < col; j++) { - vec.push_back(std::pair(eg_input(i, j), j)); + // 1D vector + if (inputdims.size() == 1) { + auto eg_input = EigenVector::Flatten(*input); + for (size_t j = 0; j < col; j++) { + vec.push_back(std::pair(eg_input(j), j)); + } + } else { + auto eg_input = EigenMatrix::Reshape(*input, inputdims.size() - 1); + for (size_t j = 0; j < col; j++) { + vec.push_back(std::pair(eg_input(i, j), j)); + } } std::partial_sort( diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py index 9fbf59ed669766077a456b3d83b7162e495ae8ae..5327c0f5de5d9a806f993818608929b9e07f624e 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py @@ -87,6 +87,28 @@ class TestTopkOp3d(OpTest): self.check_output() +class TestTopkOp1(OpTest): + def setUp(self): + self.op_type = "top_k" + k = 2 + m = 2056 + input = np.random.random(m).astype("float32") + output = np.ndarray(k) + indices = np.ndarray(k).astype("int64") + + self.inputs = {'X': input} + self.attrs = {'k': k} + + row = input + output = -np.sort(-row)[:k] + indices = (-row).argsort()[:k] + + self.outputs = {'Out': output, 'Indices': indices} + + def test_check_output(self): + self.check_output() + + class TestTopkOp2(OpTest): def setUp(self): self.op_type = "top_k"