Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
dea24544
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 1 年 前同步成功
通知
2292
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
dea24544
编写于
3月 31, 2022
作者:
Z
Zhang Zheng
提交者:
GitHub
3月 31, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Restrict compilation conditions of optimized topk kernel (#41153)
* Restrict compilation conditions of optimized topk kernel * fix
上级
23a69bc7
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
11 addition
and
1 deletion
+11
-1
paddle/fluid/operators/top_k_function_cuda.h
paddle/fluid/operators/top_k_function_cuda.h
+11
-1
未找到文件。
paddle/fluid/operators/top_k_function_cuda.h
浏览文件 @
dea24544
...
...
@@ -361,7 +361,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
}
/*---------------------------Radix TopK Begin------------------*/
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA)
&& CUDA_VERSION >= 9000
constexpr
int
RADIX_BITS
=
2
;
// digits are base-(2 ^ RADIX_BITS)
constexpr
int
RADIX_SIZE
=
4
;
// 2 ^ RADIX_BITS
constexpr
int
RADIX_MASK
=
(
RADIX_SIZE
-
1
);
...
...
@@ -479,15 +479,25 @@ struct RadixTypeConfig<platform::float16> {
typedef
uint32_t
RadixType
;
static
inline
__device__
RadixType
Convert
(
platform
::
float16
v
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
half
v_h
=
v
.
to_half
();
RadixType
x
=
__half_as_ushort
(
v_h
);
RadixType
mask
=
(
x
&
0x00008000
)
?
0x0000ffff
:
0x00008000
;
return
(
v_h
==
v_h
)
?
(
x
^
mask
)
:
0xffff
;
#else
assert
(
false
);
return
0u
;
#endif
}
static
inline
__device__
platform
::
float16
Deconvert
(
RadixType
v
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
RadixType
mask
=
(
v
&
0x00008000
)
?
0x00008000
:
0x0000ffff
;
return
static_cast
<
platform
::
float16
>
(
__ushort_as_half
(
v
^
mask
));
#else
assert
(
false
);
return
static_cast
<
platform
::
float16
>
(
0
);
#endif
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录