Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
45c7d905
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
45c7d905
编写于
3月 10, 2021
作者:
J
JamesLim
提交者:
GitHub
3月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimization of elementwise CUDA kernel (#30801)
上级
0b3c2296
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
14 addition
and
7 deletion
+14
-7
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+14
-7
未找到文件。
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
45c7d905
...
...
@@ -99,6 +99,7 @@ inline void get_mid_dims(const framework::DDim &x_dims,
(
*
post
)
*=
x_dims
[
i
];
}
}
inline
int
GetElementwiseIndex
(
const
int
*
x_dims_array
,
const
int
max_dim
,
const
int
*
index_array
)
{
int
index_
=
0
;
...
...
@@ -202,12 +203,16 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
#if defined(__NVCC__) || defined(__HIPCC__)
template
<
typename
Functor
,
typename
T
,
typename
OutType
>
__global__
void
ElementwiseKernel
(
const
T
*
x
,
const
T
*
y
,
OutType
*
out
,
int
pre
,
int
n
,
int
post
,
int
total
,
Functor
func
)
{
__global__
void
ElementwiseKernel
(
const
T
*
__restrict__
x_data
,
const
T
*
__restrict__
y_data
,
OutType
*
__restrict__
out_data
,
int
n
,
int
post
,
const
size_t
total
,
Functor
func
)
{
int
tid
=
threadIdx
.
x
+
blockDim
.
x
*
blockIdx
.
x
;
int
idx
=
tid
/
post
%
n
;
if
(
tid
<
total
)
{
out
[
tid
]
=
func
(
x
[
tid
],
y
[
idx
]);
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(
int
i
=
tid
;
i
<
total
;
i
+=
stride
)
{
int
idx
=
i
/
post
%
n
;
out_data
[
i
]
=
func
(
x_data
[
i
],
y_data
[
idx
]);
}
}
...
...
@@ -224,14 +229,16 @@ void ComputeElementwiseCUDA(const framework::Tensor *x,
int
numel
=
pre
*
n
*
post
;
int
threads
=
256
;
int
blocks
=
(
numel
+
threads
-
1
)
/
threads
;
if
(
is_xsize_larger
)
{
ElementwiseKernel
<
Functor
,
T
,
OutType
><<<
blocks
,
threads
,
0
,
ctx
.
stream
()
>>>
(
x_data
,
y_data
,
out_data
,
pre
,
n
,
post
,
numel
,
func
);
x_data
,
y_data
,
out_data
,
n
,
post
,
numel
,
func
);
}
else
{
ElementwiseKernel
<
Functor
,
T
,
OutType
><<<
blocks
,
threads
,
0
,
ctx
.
stream
()
>>>
(
y_data
,
x_data
,
out_data
,
pre
,
n
,
post
,
numel
,
func
);
y_data
,
x_data
,
out_data
,
n
,
post
,
numel
,
func
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录