Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ab385ca4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2301
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ab385ca4
编写于
5月 15, 2023
作者:
L
Leo Chen
提交者:
GitHub
5月 15, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix add_n kernel of large shape (#53767)
上级
268156f8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
6 addition
and
22 deletion
+6
-22
paddle/phi/kernels/gpu/add_n_kernel.cu
paddle/phi/kernels/gpu/add_n_kernel.cu
+6
-22
未找到文件。
paddle/phi/kernels/gpu/add_n_kernel.cu
浏览文件 @
ab385ca4
...
...
@@ -21,34 +21,20 @@ namespace phi {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
template
<
class
T
>
__global__
void
Sum2CUDAKernel
(
const
T
*
in_0
,
const
T
*
in_1
,
T
*
out
,
int64_t
N
)
{
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
id
<
N
)
{
out
[
id
]
=
in_0
[
id
]
+
in_1
[
id
];
id
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
template
<
class
T
>
__global__
void
SumArrayCUDAKernel
(
T
**
in
,
T
*
out
,
int64_t
N
,
size_t
in_size
,
bool
read_dst
)
{
using
MPType
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
id
<
N
)
{
MPType
total
(
read_dst
?
static_cast
<
MPType
>
(
out
[
id
])
CUDA_KERNEL_LOOP_TYPE
(
idx
,
N
,
int64_t
)
{
MPType
total
(
read_dst
?
static_cast
<
MPType
>
(
out
[
idx
])
:
static_cast
<
MPType
>
(
0
));
for
(
int
i
=
0
;
i
<
in_size
;
++
i
)
{
const
T
*
tmp
=
in
[
i
];
if
(
tmp
)
{
total
+=
static_cast
<
MPType
>
(
tmp
[
id
]);
total
+=
static_cast
<
MPType
>
(
tmp
[
id
x
]);
}
}
out
[
id
]
=
static_cast
<
T
>
(
total
);
id
+=
blockDim
.
x
*
gridDim
.
x
;
out
[
idx
]
=
static_cast
<
T
>
(
total
);
}
}
...
...
@@ -56,16 +42,14 @@ template <class T>
__global__
void
SumSelectedRowsCUDAKernel
(
T
**
sr_in_out
,
int64_t
N
,
size_t
rows
)
{
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
while
(
id
<
N
)
{
CUDA_KERNEL_LOOP_TYPE
(
idx
,
N
,
int64_t
)
{
for
(
int
i
=
0
;
i
<
2
*
rows
;
i
+=
2
)
{
const
T
*
tmp
=
sr_in_out
[
i
];
T
*
tmp_out
=
sr_in_out
[
i
+
1
];
if
(
tmp
&&
tmp_out
)
{
tmp_out
[
id
]
+=
tmp
[
id
];
tmp_out
[
id
x
]
+=
tmp
[
idx
];
}
}
id
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录