Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
84273aaa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
84273aaa
编写于
10月 24, 2022
作者:
Z
Zhang Ting
提交者:
GitHub
10月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix cumsum compilation error for GPU architecture that does not support fast FP16 (#47277)
上级
28ed27a6
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
47 addition
and
43 deletion
+47
-43
paddle/phi/kernels/gpu/cum_kernel.cu
paddle/phi/kernels/gpu/cum_kernel.cu
+47
-43
未找到文件。
paddle/phi/kernels/gpu/cum_kernel.cu
浏览文件 @
84273aaa
...
...
@@ -34,18 +34,6 @@ namespace cub = hipcub;
namespace
phi
{
template
<
typename
T
>
class
CumTypeTrait
{
public:
using
Type
=
T
;
};
template
<
>
class
CumTypeTrait
<
phi
::
dtype
::
float16
>
{
public:
using
Type
=
__half
;
};
template
<
typename
T
,
int
BLOCK_SIZE
>
__device__
void
BlockReverse
(
const
T
*
idata
,
T
*
odata
,
int
src_base
,
int
dst_base
,
int
valid_item
)
{
...
...
@@ -228,6 +216,51 @@ __global__ void BlockScanKernel(T* d_out,
}
}
template
<
typename
Context
,
typename
T
>
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
>::
type
ThrustCumsumKernel
(
const
Context
&
dev_ctx
,
const
T
*
in_data
,
T
*
out_data
,
int64_t
size
,
bool
reverse
,
bool
exclusive
)
{
#ifdef __HIPCC__
const
auto
&
policy
=
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
());
#else
const
auto
&
policy
=
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
());
#endif
if
(
reverse
)
{
thrust
::
reverse_iterator
<
thrust
::
device_ptr
<
const
T
>>
reversed_in
(
thrust
::
device_pointer_cast
(
in_data
)
+
size
);
thrust
::
reverse_iterator
<
thrust
::
device_ptr
<
T
>>
reversed_out
(
thrust
::
device_pointer_cast
(
out_data
)
+
size
);
if
(
exclusive
)
{
thrust
::
exclusive_scan
(
policy
,
reversed_in
,
reversed_in
+
size
,
reversed_out
);
}
else
{
thrust
::
inclusive_scan
(
policy
,
reversed_in
,
reversed_in
+
size
,
reversed_out
);
}
}
else
{
if
(
exclusive
)
{
thrust
::
exclusive_scan
(
policy
,
in_data
,
in_data
+
size
,
out_data
);
}
else
{
thrust
::
inclusive_scan
(
policy
,
in_data
,
in_data
+
size
,
out_data
);
}
}
return
;
}
template
<
typename
Context
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
>::
type
ThrustCumsumKernel
(
const
Context
&
dev_ctx
,
const
phi
::
dtype
::
float16
*
in_data
,
phi
::
dtype
::
float16
*
out_data
,
int64_t
size
,
bool
reverse
,
bool
exclusive
)
{}
template
<
typename
T
,
typename
Context
,
typename
Op
>
void
ScanKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
...
...
@@ -260,37 +293,8 @@ void ScanKernel(const Context& dev_ctx,
// length of the ‘axis’ dimension.
if
(
!
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
&&
std
::
is_same
<
Op
,
cub
::
Sum
>::
value
&&
size
==
out_dims
[
axis
])
{
#ifdef __HIPCC__
const
auto
&
policy
=
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
());
#else
const
auto
&
policy
=
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
());
#endif
using
CumType
=
typename
CumTypeTrait
<
T
>::
Type
;
CumType
*
out_data_ptr
=
reinterpret_cast
<
CumType
*>
(
out_data
);
const
CumType
*
in_data_ptr
=
reinterpret_cast
<
const
CumType
*>
(
in_data
);
if
(
reverse
)
{
thrust
::
reverse_iterator
<
thrust
::
device_ptr
<
const
CumType
>>
reversed_in
(
thrust
::
device_pointer_cast
(
in_data_ptr
)
+
size
);
thrust
::
reverse_iterator
<
thrust
::
device_ptr
<
CumType
>>
reversed_out
(
thrust
::
device_pointer_cast
(
out_data_ptr
)
+
size
);
if
(
exclusive
)
{
thrust
::
exclusive_scan
(
policy
,
reversed_in
,
reversed_in
+
size
,
reversed_out
);
}
else
{
thrust
::
inclusive_scan
(
policy
,
reversed_in
,
reversed_in
+
size
,
reversed_out
);
}
}
else
{
if
(
exclusive
)
{
thrust
::
exclusive_scan
(
policy
,
in_data_ptr
,
in_data_ptr
+
size
,
out_data_ptr
);
}
else
{
thrust
::
inclusive_scan
(
policy
,
in_data_ptr
,
in_data_ptr
+
size
,
out_data_ptr
);
}
}
ThrustCumsumKernel
<
Context
,
T
>
(
dev_ctx
,
in_data
,
out_data
,
size
,
reverse
,
exclusive
);
return
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录