Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
迷途归一
Paddle
提交
0fd6e2a1
P
Paddle
项目概览
迷途归一
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
0fd6e2a1
编写于
3月 02, 2023
作者:
W
wangshengxiang
提交者:
GitHub
3月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] add smallest mode for top_k (#51053)
上级
8ac05c09
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
100 addition
and
8 deletion
+100
-8
paddle/phi/kernels/xpu/top_k_kernel.cc
paddle/phi/kernels/xpu/top_k_kernel.cc
+4
-8
python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py
.../paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py
+96
-0
未找到文件。
paddle/phi/kernels/xpu/top_k_kernel.cc
浏览文件 @
0fd6e2a1
...
...
@@ -43,12 +43,6 @@ void TopkKernel(const Context& dev_ctx,
errors
::
External
(
"XPU API does not support unsorted topk operation currently."
" Operator will be supported in future update."
));
PADDLE_ENFORCE_EQ
(
largest
,
true
,
errors
::
External
(
"XPU API does not support smallest topk operation currently."
" Operator will be supported in future update."
));
if
(
in_dims
.
size
()
==
0
)
{
int
r
=
xpu
::
copy
<
XPUType
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
...
...
@@ -77,7 +71,8 @@ void TopkKernel(const Context& dev_ctx,
indices_int_data
,
row
,
col
,
k
);
k
,
largest
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"sorted_topk"
);
r
=
xpu
::
cast
<
int32_t
,
int64_t
>
(
dev_ctx
.
x_context
(),
...
...
@@ -140,7 +135,8 @@ void TopkKernel(const Context& dev_ctx,
trans_idx_int32_data
,
row
,
col
,
k
);
k
,
largest
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"sorted_topk"
);
r
=
xpu
::
cast
<
int32_t
,
int64_t
>
(
dev_ctx
.
x_context
(),
...
...
python/paddle/fluid/tests/unittests/xpu/test_top_k_v2_op_xpu.py
浏览文件 @
0fd6e2a1
...
...
@@ -193,6 +193,102 @@ class XPUTestTopKV2Op(XPUOpTestWrapper):
self
.
largest
=
True
self
.
input_data_shape
=
(
10
,
10
,
5
)
class
TestTopkSmallestOp1
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
3
self
.
axis
=
1
self
.
largest
=
False
# too many values for fp16 will lead to failure in random_unique_float function
if
self
.
dtype
==
np
.
float16
:
self
.
input_data_shape
=
(
100
,
55
)
else
:
self
.
input_data_shape
=
(
100
,
155
)
class
TestTopkSmallestOp2
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
3
self
.
axis
=
1
self
.
largest
=
False
self
.
input_data_shape
=
(
10
,
10
,
5
)
class
TestTopkSmallestOp3
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
5
self
.
axis
=
1
self
.
largest
=
False
self
.
input_data_shape
=
(
10
,
10
,
5
)
class
TestTopkSmallestOp4
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
1
self
.
axis
=
1
self
.
largest
=
False
self
.
input_data_shape
=
(
10
,
10
,
5
)
class
TestTopkSmallestOp5
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
3
self
.
axis
=
2
self
.
largest
=
False
self
.
input_data_shape
=
(
10
,
10
,
5
)
class
TestTopkSmallestOp6
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
5
self
.
axis
=
1
self
.
largest
=
False
# too many values for fp16 will lead to failure in random_unique_float function
if
self
.
dtype
==
np
.
float16
:
self
.
input_data_shape
=
(
8
,
32
,
32
)
else
:
self
.
input_data_shape
=
(
8
,
32
,
64
)
class
TestTopkSmallestOp7
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
10
self
.
axis
=
2
self
.
largest
=
False
self
.
input_data_shape
=
(
8
,
5
,
10
,
16
)
class
TestTopkSmallestOp8
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
1
self
.
axis
=
1
self
.
largest
=
False
# too many values for fp16 will lead to failure in random_unique_float function
if
self
.
dtype
==
np
.
float16
:
self
.
input_data_shape
=
(
8
,
32
,
32
)
else
:
self
.
input_data_shape
=
(
8
,
32
,
64
)
class
TestTopkSmallestOp9
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
3
self
.
axis
=
1
self
.
largest
=
False
self
.
input_data_shape
=
(
10
,
10
,
5
)
class
TestTopkSmallestOp10
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
3
self
.
axis
=
1
self
.
largest
=
False
self
.
input_data_shape
=
(
10
,
10
,
5
)
class
TestTopkSmallestOp11
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
5
self
.
axis
=
1
self
.
largest
=
False
self
.
input_data_shape
=
(
10
,
10
,
5
)
class
TestTopkSmallestOp12
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
1
self
.
axis
=
1
self
.
largest
=
False
self
.
input_data_shape
=
(
10
,
10
,
5
)
support_types
=
get_xpu_op_support_types
(
'top_k_v2'
)
for
stype
in
support_types
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录