未验证 提交 856536b9 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

cherry-pick Fix topk cannot handle 1D vector bug (#18466)

Add path to handle 1D vector
上级 e616c3da
......@@ -29,6 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class TopkKernel : public framework::OpKernel<T> {
public:
......@@ -57,17 +61,24 @@ class TopkKernel : public framework::OpKernel<T> {
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t col = inputdims[inputdims.size() - 1];
Eigen::DSizes<int, 2> flat2dims(row, col);
// NOTE: eigen shape doesn't affect paddle tensor.
auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
// NOTE: eigen shape doesn't affect paddle tensor.
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (size_t i = 0; i < row; i++) {
std::vector<std::pair<T, size_t>> vec;
vec.reserve(col);
for (size_t j = 0; j < col; j++) {
vec.push_back(std::pair<T, size_t>(eg_input(i, j), j));
// 1D vector
if (inputdims.size() == 1) {
auto eg_input = EigenVector<T>::Flatten(*input);
for (size_t j = 0; j < col; j++) {
vec.push_back(std::pair<T, size_t>(eg_input(j), j));
}
} else {
auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
for (size_t j = 0; j < col; j++) {
vec.push_back(std::pair<T, size_t>(eg_input(i, j), j));
}
}
std::partial_sort(
......
......@@ -87,6 +87,28 @@ class TestTopkOp3d(OpTest):
self.check_output()
class TestTopkOp1(OpTest):
def setUp(self):
self.op_type = "top_k"
k = 2
m = 2056
input = np.random.random(m).astype("float32")
output = np.ndarray(k)
indices = np.ndarray(k).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
row = input
output = -np.sort(-row)[:k]
indices = (-row).argsort()[:k]
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
self.check_output()
class TestTopkOp2(OpTest):
def setUp(self):
self.op_type = "top_k"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册