未验证 提交 2b8fd704 编写于 作者: P pangyoki 提交者: GitHub

fix bug of top_k npu op (#36175)

上级 c79de728
...@@ -51,7 +51,9 @@ class TopkNPUKernel : public framework::OpKernel<T> { ...@@ -51,7 +51,9 @@ class TopkNPUKernel : public framework::OpKernel<T> {
indices->mutable_data<int64_t>(ctx.GetPlace()); indices->mutable_data<int64_t>(ctx.GetPlace());
// prepare assit // prepare assit
auto dim = input->dims().size(); auto size = input->dims().size();
// dim is the last dimension of input
auto dim = input->dims()[size - 1];
framework::Tensor assist_seq_tensor; framework::Tensor assist_seq_tensor;
assist_seq_tensor.Resize({2 * dim}); assist_seq_tensor.Resize({2 * dim});
assist_seq_tensor.mutable_data<T>(ctx.GetPlace()); assist_seq_tensor.mutable_data<T>(ctx.GetPlace());
......
...@@ -22,6 +22,7 @@ from op_test import OpTest ...@@ -22,6 +22,7 @@ from op_test import OpTest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from test_top_k_v2_op_npu import numpy_topk
paddle.enable_static() paddle.enable_static()
SEED = 2021 SEED = 2021
...@@ -87,5 +88,40 @@ class TestTopkV2(OpTest): ...@@ -87,5 +88,40 @@ class TestTopkV2(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
class TestTopkV3(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "top_k"
self.init_dtype()
self.set_input_data()
self.set_attrs()
output, indices = numpy_topk(
self.input_data, axis=self.axis, k=self.k, largest=True)
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis}
self.outputs = {'Out': output, 'Indices': indices}
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place)
def set_attrs(self):
self.k = 3
self.axis = 1
def set_input_data(self):
self.input_data = np.random.choice(
10000, size=(10, 20), replace=False).astype(self.dtype)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册