From 856536b9d7e5bb84678941907aed02c140a4c098 Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Mon, 8 Jul 2019 14:28:20 +0800 Subject: [PATCH] cherry-pick Fix topk cannot handle 1D vector bug (#18466) Add path to handle 1D vector --- paddle/fluid/operators/top_k_op.h | 21 +++++++++++++----- .../fluid/tests/unittests/test_top_k_op.py | 22 +++++++++++++++++++ 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index f7bac67300b..6b9db260605 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -29,6 +29,10 @@ template using EigenMatrix = framework::EigenMatrix; +template +using EigenVector = framework::EigenVector; + template class TopkKernel : public framework::OpKernel { public: @@ -57,17 +61,24 @@ class TopkKernel : public framework::OpKernel { framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); const size_t col = inputdims[inputdims.size() - 1]; Eigen::DSizes flat2dims(row, col); - // NOTE: eigen shape doesn't affect paddle tensor. - auto eg_input = EigenMatrix::Reshape(*input, inputdims.size() - 1); - +// NOTE: eigen shape doesn't affect paddle tensor. #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif for (size_t i = 0; i < row; i++) { std::vector> vec; vec.reserve(col); - for (size_t j = 0; j < col; j++) { - vec.push_back(std::pair(eg_input(i, j), j)); + // 1D vector + if (inputdims.size() == 1) { + auto eg_input = EigenVector::Flatten(*input); + for (size_t j = 0; j < col; j++) { + vec.push_back(std::pair(eg_input(j), j)); + } + } else { + auto eg_input = EigenMatrix::Reshape(*input, inputdims.size() - 1); + for (size_t j = 0; j < col; j++) { + vec.push_back(std::pair(eg_input(i, j), j)); + } } std::partial_sort( diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py index 9fbf59ed669..5327c0f5de5 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py @@ -87,6 +87,28 @@ class TestTopkOp3d(OpTest): self.check_output() +class TestTopkOp1(OpTest): + def setUp(self): + self.op_type = "top_k" + k = 2 + m = 2056 + input = np.random.random(m).astype("float32") + output = np.ndarray(k) + indices = np.ndarray(k).astype("int64") + + self.inputs = {'X': input} + self.attrs = {'k': k} + + row = input + output = -np.sort(-row)[:k] + indices = (-row).argsort()[:k] + + self.outputs = {'Out': output, 'Indices': indices} + + def test_check_output(self): + self.check_output() + + class TestTopkOp2(OpTest): def setUp(self): self.op_type = "top_k" -- GitLab