Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ae4d1ec1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
ae4d1ec1
编写于
5月 09, 2022
作者:
N
niuliling123
提交者:
GitHub
5月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modified reduce for xpu2 (#42439)
上级
8b546f1c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
8 addition
and
6 deletion
+8
-6
paddle/phi/kernels/funcs/reduce_function.h
paddle/phi/kernels/funcs/reduce_function.h
+5
-1
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
+3
-5
未找到文件。
paddle/phi/kernels/funcs/reduce_function.h
浏览文件 @
ae4d1ec1
...
...
@@ -473,7 +473,11 @@ struct ReduceConfig {
bool
not_higher
=
x_dim
[
0
]
>=
max_grid_z
;
#endif
if
(
reduce_last_dim
&&
(
reduce_rank
==
1
))
{
#ifdef PADDLE_WITH_XPU_KP
reduce_type
=
static_cast
<
int
>
(
ReduceType
::
kReduceAny
);
#else
reduce_type
=
static_cast
<
int
>
(
ReduceType
::
kReduceLastDim
);
#endif
}
else
if
(
reduce_rank
==
1
)
{
reduce_type
=
static_cast
<
int
>
(
ReduceType
::
kReduceHigherDim
);
if
(
rank
==
3
&&
not_higher
)
{
...
...
@@ -588,7 +592,7 @@ struct ReduceConfig {
void
SetBlockDim
()
{
// init
should_reduce_again
=
false
;
dim3
block_dim
;
dim3
block_dim
(
1
,
1
,
1
)
;
dim3
grid_dim
(
left_num
,
1
,
1
);
blocking_size
=
reduce_num
;
...
...
paddle/phi/kernels/primitive/compute_primitives_xpu2.h
浏览文件 @
ae4d1ec1
...
...
@@ -329,14 +329,12 @@ __device__ __forceinline__ void Reduce(T* out,
ReduceFunctor
reducer
,
bool
reduce_last_dim
)
{
if
(
Mode
==
details
::
kGlobalMode
)
{
if
(
reduce_last_dim
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NY
;
++
i
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
NX
;
++
j
)
{
out
[
i
]
=
reducer
(
out
[
i
],
in
[
i
*
NX
+
j
]);
for
(
int
i
=
0
;
i
<
NY
*
NX
;
i
++
)
{
// reduce along blockDim.x
details
::
BlockXReduce
<
T
,
ReduceFunctor
,
1
>
(
&
out
[
i
],
reducer
);
}
}
details
::
BlockXReduce
<
T
,
ReduceFunctor
,
NY
>
(
out
,
reducer
);
}
else
{
// else kLocalMode
#pragma unroll
for
(
int
i
=
0
;
i
<
NY
;
++
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录