Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c9a334e1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c9a334e1
编写于
1月 15, 2021
作者:
Z
Zhang Ting
提交者:
GitHub
1月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add VecCastCUDAKernel (#30296)
上级
13d75736
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
47 addition
and
2 deletion
+47
-2
paddle/fluid/operators/cast_op.cu
paddle/fluid/operators/cast_op.cu
+47
-2
未找到文件。
paddle/fluid/operators/cast_op.cu
浏览文件 @
c9a334e1
...
...
@@ -19,6 +19,43 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
// aligned vector generates vectorized load/store on CUDA
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
AlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
inline
int
VectorizedSize
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
constexpr
int
vec4
=
std
::
alignment_of
<
AlignedVector
<
T
,
4
>>::
value
;
// NOLINT
if
(
address
%
vec4
==
0
)
{
return
4
;
}
return
1
;
}
template
<
typename
InT
,
typename
OutT
,
int
VecSize
>
__global__
void
VecCastCUDAKernel
(
const
InT
*
in
,
const
int64_t
N
,
OutT
*
out
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
using
LoadT
=
AlignedVector
<
InT
,
VecSize
>
;
using
StoreT
=
AlignedVector
<
OutT
,
VecSize
>
;
for
(
int
i
=
idx
*
VecSize
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
*
VecSize
)
{
InT
in_vec
[
VecSize
];
LoadT
*
in_value
=
reinterpret_cast
<
LoadT
*>
(
&
in_vec
);
*
in_value
=
*
reinterpret_cast
<
const
LoadT
*>
(
&
in
[
i
]);
OutT
out_vec
[
VecSize
];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
VecSize
;
ii
++
)
{
out_vec
[
ii
]
=
static_cast
<
OutT
>
(
in_vec
[
ii
]);
}
*
(
reinterpret_cast
<
StoreT
*>
(
&
out
[
i
]))
=
*
reinterpret_cast
<
StoreT
*>
(
&
out_vec
[
0
]);
}
}
template
<
typename
InT
,
typename
OutT
>
__global__
void
CastCUDAKernel
(
const
InT
*
in
,
const
int64_t
N
,
OutT
*
out
)
{
CUDA_KERNEL_LOOP
(
index
,
N
)
{
out
[
index
]
=
static_cast
<
OutT
>
(
in
[
index
]);
}
...
...
@@ -40,8 +77,16 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
auto
*
out
=
out_
->
mutable_data
<
OutT
>
(
ctx_
.
GetPlace
());
platform
::
GpuLaunchConfig
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx_
,
size
);
CastCUDAKernel
<
InT
,
OutT
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx_
.
stream
()
>>>
(
in
,
size
,
out
);
int
vec_size
=
VectorizedSize
<
OutT
>
(
out
);
if
(
!
std
::
is_same
<
InT
,
OutT
>::
value
&&
vec_size
==
4
&&
size
%
4
==
0
)
{
VecCastCUDAKernel
<
InT
,
OutT
,
4
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx_
.
stream
()
>>>
(
in
,
size
,
out
);
}
else
{
CastCUDAKernel
<
InT
,
OutT
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
ctx_
.
stream
()
>>>
(
in
,
size
,
out
);
}
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录