未验证 提交 856536b9 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

cherry-pick Fix topk cannot handle 1D vector bug (#18466)

Add path to handle 1D vector
上级 e616c3da
...@@ -29,6 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -29,6 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class TopkKernel : public framework::OpKernel<T> { class TopkKernel : public framework::OpKernel<T> {
public: public:
...@@ -57,18 +61,25 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -57,18 +61,25 @@ class TopkKernel : public framework::OpKernel<T> {
framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t col = inputdims[inputdims.size() - 1]; const size_t col = inputdims[inputdims.size() - 1];
Eigen::DSizes<int, 2> flat2dims(row, col); Eigen::DSizes<int, 2> flat2dims(row, col);
// NOTE: eigen shape doesn't affect paddle tensor. // NOTE: eigen shape doesn't affect paddle tensor.
auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
#endif #endif
for (size_t i = 0; i < row; i++) { for (size_t i = 0; i < row; i++) {
std::vector<std::pair<T, size_t>> vec; std::vector<std::pair<T, size_t>> vec;
vec.reserve(col); vec.reserve(col);
// 1D vector
if (inputdims.size() == 1) {
auto eg_input = EigenVector<T>::Flatten(*input);
for (size_t j = 0; j < col; j++) {
vec.push_back(std::pair<T, size_t>(eg_input(j), j));
}
} else {
auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
for (size_t j = 0; j < col; j++) { for (size_t j = 0; j < col; j++) {
vec.push_back(std::pair<T, size_t>(eg_input(i, j), j)); vec.push_back(std::pair<T, size_t>(eg_input(i, j), j));
} }
}
std::partial_sort( std::partial_sort(
vec.begin(), vec.begin() + k, vec.end(), vec.begin(), vec.begin() + k, vec.end(),
......
...@@ -87,6 +87,28 @@ class TestTopkOp3d(OpTest): ...@@ -87,6 +87,28 @@ class TestTopkOp3d(OpTest):
self.check_output() self.check_output()
class TestTopkOp1(OpTest):
def setUp(self):
self.op_type = "top_k"
k = 2
m = 2056
input = np.random.random(m).astype("float32")
output = np.ndarray(k)
indices = np.ndarray(k).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
row = input
output = -np.sort(-row)[:k]
indices = (-row).argsort()[:k]
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
self.check_output()
class TestTopkOp2(OpTest): class TestTopkOp2(OpTest):
def setUp(self): def setUp(self):
self.op_type = "top_k" self.op_type = "top_k"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册