Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4b3e8d56
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4b3e8d56
编写于
6月 22, 2022
作者:
W
wawltor
提交者:
GitHub
6月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the cumsum bug for large size (#43722)
上级
561d09b9
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
5 addition
and
5 deletion
+5
-5
paddle/phi/kernels/gpu/cum_kernel.cu
paddle/phi/kernels/gpu/cum_kernel.cu
+5
-5
未找到文件。
paddle/phi/kernels/gpu/cum_kernel.cu
浏览文件 @
4b3e8d56
...
...
@@ -176,10 +176,8 @@ __global__ void BlockScanKernel(T* d_out,
}
temp_storage
;
int
bx
=
blockIdx
.
x
;
int
by
=
blockIdx
.
y
;
BlockPrefixCallbackOp
<
T
,
Op
>
prefix_op
(
Identity
<
T
,
Op
>::
value
,
op
);
T
block_aggregate
=
static_cast
<
T
>
(
0
);
// Obtain this block's segment of consecutive keys (blocked across threads)
int
item_per_block
=
BLOCK_THREADS
*
ITEMS_PER_THREAD
;
...
...
@@ -192,7 +190,7 @@ __global__ void BlockScanKernel(T* d_out,
valid_item
=
scan_size
;
}
int
offset
=
b
x
*
scan_size
+
block_offset
+
by
*
(
inner_size
*
scan_size
)
;
int
offset
=
b
lock_offset
+
bx
*
scan_size
;
T
thread_keys
[
ITEMS_PER_THREAD
];
BlockLoadT
(
temp_storage
.
load
)
...
...
@@ -307,6 +305,7 @@ void ScanKernel(const Context& dev_ctx,
int
outer_size
=
height
/
scan_size
;
int
inner_size
=
width
;
// Consider the size of shared memory, here block size is 128
dim3
scan_grid
(
outer_size
,
inner_size
);
dim3
reverse_grid
=
scan_grid
;
if
(
reverse
)
{
...
...
@@ -322,13 +321,14 @@ void ScanKernel(const Context& dev_ctx,
in_data
,
out_data
,
scan_size
,
outer_size
,
inner_size
);
}
}
int64_t
grid_size
=
outer_size
*
inner_size
;
if
(
!
transpose
&&
!
reverse
)
{
BlockScanKernel
<
T
,
128
,
4
,
Op
><<<
scan_grid
,
128
,
0
,
dev_ctx
.
stream
()
>>>
(
BlockScanKernel
<
T
,
128
,
4
,
Op
><<<
grid_size
,
128
,
0
,
dev_ctx
.
stream
()
>>>
(
out_data
,
in_data
,
outer_size
,
inner_size
,
scan_size
,
exclusive
,
op
);
}
else
{
BlockScanKernel
<
T
,
128
,
4
,
Op
>
<<<
scan_grid
,
128
,
0
,
dev_ctx
.
stream
()
>>>
(
next_out_data
,
<<<
grid_size
,
128
,
0
,
dev_ctx
.
stream
()
>>>
(
next_out_data
,
next_in_data
,
outer_size
,
inner_size
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录