Unverified commit 0fd6e2a1, authored by wangshengxiang, committed by GitHub

[XPU] add smallest mode for top_k (#51053)

Parent 8ac05c09
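This commit removes the XPU-side guard that rejected `largest=False` and forwards the flag to `xpu::sorted_topk`, so `top_k` on XPU now supports smallest mode. At the Python level the mode is selected via `paddle.topk`; a minimal usage sketch (the example tensor and printed values are illustrative, and running on XPU assumes an XPU build of Paddle):

```python
import paddle

x = paddle.to_tensor([1.0, 4.0, 5.0, 7.0])

# Before this commit the XPU kernel raised an error for largest=False;
# it now returns the k smallest values along with their indices.
values, indices = paddle.topk(x, k=2, largest=False)
print(values)   # [1., 4.]
print(indices)  # [0, 1]
```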
@@ -43,12 +43,6 @@ void TopkKernel(const Context& dev_ctx,
       errors::External(
           "XPU API does not support unsorted topk operation currently."
           " Operator will be supported in future update."));
-  PADDLE_ENFORCE_EQ(
-      largest,
-      true,
-      errors::External(
-          "XPU API does not support smallest topk operation currently."
-          " Operator will be supported in future update."));
   if (in_dims.size() == 0) {
     int r = xpu::copy<XPUType>(dev_ctx.x_context(),
                                reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -77,7 +71,8 @@ void TopkKernel(const Context& dev_ctx,
                          indices_int_data,
                          row,
                          col,
-                         k);
+                         k,
+                         largest);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_topk");
     r = xpu::cast<int32_t, int64_t>(dev_ctx.x_context(),
@@ -140,7 +135,8 @@ void TopkKernel(const Context& dev_ctx,
                          trans_idx_int32_data,
                          row,
                          col,
-                         k);
+                         k,
+                         largest);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_topk");
     r = xpu::cast<int32_t, int64_t>(dev_ctx.x_context(),
......
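In both call sites above, the kernel now forwards `largest` to `xpu::sorted_topk` instead of enforcing `largest == true`. For reference, the intended semantics of a sorted top-k with a `largest` flag can be sketched in NumPy (an illustrative model of the behavior, not the XPU implementation):

```python
import numpy as np

def sorted_topk_ref(x, k, largest=True):
    """Return the k largest (or smallest) values per row, sorted, with indices."""
    order = np.argsort(x, axis=-1)   # indices in ascending value order
    if largest:
        order = order[..., ::-1]     # flip to descending for largest mode
    idx = order[..., :k]
    vals = np.take_along_axis(x, idx, axis=-1)
    return vals, idx

x = np.array([[3.0, 1.0, 2.0, 5.0]])
print(sorted_topk_ref(x, 2, largest=False))  # (array([[1., 2.]]), array([[1, 2]]))
```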
@@ -193,6 +193,102 @@ class XPUTestTopKV2Op(XPUOpTestWrapper):
             self.largest = True
             self.input_data_shape = (10, 10, 5)
+
+    class TestTopkSmallestOp1(TestTopkOp):
+        def init_args(self):
+            self.k = 3
+            self.axis = 1
+            self.largest = False
+            # too many values for fp16 will lead to failure in random_unique_float function
+            if self.dtype == np.float16:
+                self.input_data_shape = (100, 55)
+            else:
+                self.input_data_shape = (100, 155)
+
+    class TestTopkSmallestOp2(TestTopkOp):
+        def init_args(self):
+            self.k = 3
+            self.axis = 1
+            self.largest = False
+            self.input_data_shape = (10, 10, 5)
+
+    class TestTopkSmallestOp3(TestTopkOp):
+        def init_args(self):
+            self.k = 5
+            self.axis = 1
+            self.largest = False
+            self.input_data_shape = (10, 10, 5)
+
+    class TestTopkSmallestOp4(TestTopkOp):
+        def init_args(self):
+            self.k = 1
+            self.axis = 1
+            self.largest = False
+            self.input_data_shape = (10, 10, 5)
+
+    class TestTopkSmallestOp5(TestTopkOp):
+        def init_args(self):
+            self.k = 3
+            self.axis = 2
+            self.largest = False
+            self.input_data_shape = (10, 10, 5)
+
+    class TestTopkSmallestOp6(TestTopkOp):
+        def init_args(self):
+            self.k = 5
+            self.axis = 1
+            self.largest = False
+            # too many values for fp16 will lead to failure in random_unique_float function
+            if self.dtype == np.float16:
+                self.input_data_shape = (8, 32, 32)
+            else:
+                self.input_data_shape = (8, 32, 64)
+
+    class TestTopkSmallestOp7(TestTopkOp):
+        def init_args(self):
+            self.k = 10
+            self.axis = 2
+            self.largest = False
+            self.input_data_shape = (8, 5, 10, 16)
+
+    class TestTopkSmallestOp8(TestTopkOp):
+        def init_args(self):
+            self.k = 1
+            self.axis = 1
+            self.largest = False
+            # too many values for fp16 will lead to failure in random_unique_float function
+            if self.dtype == np.float16:
+                self.input_data_shape = (8, 32, 32)
+            else:
+                self.input_data_shape = (8, 32, 64)
+
+    class TestTopkSmallestOp9(TestTopkOp):
+        def init_args(self):
+            self.k = 3
+            self.axis = 1
+            self.largest = False
+            self.input_data_shape = (10, 10, 5)
+
+    class TestTopkSmallestOp10(TestTopkOp):
+        def init_args(self):
+            self.k = 3
+            self.axis = 1
+            self.largest = False
+            self.input_data_shape = (10, 10, 5)
+
+    class TestTopkSmallestOp11(TestTopkOp):
+        def init_args(self):
+            self.k = 5
+            self.axis = 1
+            self.largest = False
+            self.input_data_shape = (10, 10, 5)
+
+    class TestTopkSmallestOp12(TestTopkOp):
+        def init_args(self):
+            self.k = 1
+            self.axis = 1
+            self.largest = False
+            self.input_data_shape = (10, 10, 5)
 
 
 support_types = get_xpu_op_support_types('top_k_v2')
 for stype in support_types:
......
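The fp16-specific shapes in the new tests keep the element count small because float16 has a coarse grid of representable values, so drawing many *unique* random floats is unreliable (the `random_unique_float` helper is referenced by the tests but not shown in this diff). A quick NumPy illustration of the underlying effect:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random(10_000).astype(np.float32)
print(np.unique(a).size)   # ~10000: float32 draws are almost all distinct

b = a.astype(np.float16)
print(np.unique(b).size)   # far fewer: many draws collapse to the same fp16 value
```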