Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e25b75b6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e25b75b6
编写于
4月 15, 2022
作者:
H
huangxu96
提交者:
GitHub
4月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix a bug which will casue cuda address error when the input size is very large (#41824)
As the title
上级
ea0a164b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
3 addition
and
9 deletion
+3
-9
paddle/fluid/operators/gather_scatter_kernel.cu
paddle/fluid/operators/gather_scatter_kernel.cu
+3
-9
未找到文件。
paddle/fluid/operators/gather_scatter_kernel.cu
浏览文件 @
e25b75b6
...
...
@@ -119,7 +119,7 @@ struct gpu_gather_scatter_functor {
is_scatter_like
?
self_dims
[
dim
]
:
src_dims
[
dim
];
int64_t
inner_dim_size
=
1
;
int64_t
outer_dim_size
=
1
;
for
(
int64_t
i
=
0
;
i
<
index_dims
.
size
()
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
dim
;
++
i
)
{
inner_dim_size
*=
index_dims
[
i
];
}
...
...
@@ -127,11 +127,8 @@ struct gpu_gather_scatter_functor {
outer_dim_size
*=
index_dims
[
i
];
}
int64_t
slice_size
=
1
;
for
(
int
i
=
1
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
int
block
=
512
;
int64_t
n
=
slice_size
*
index
_size
;
int64_t
n
=
inner_dim_size
*
select_dim_size
*
outer_dim
_size
;
int64_t
grid
=
(
n
+
block
-
1
)
/
block
;
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
();
...
...
@@ -215,11 +212,8 @@ void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
outer_dim_size
*=
index_dims
[
i
];
}
int64_t
slice_size
=
1
;
for
(
int
i
=
1
;
i
<
grad_dims
.
size
();
++
i
)
slice_size
*=
grad_dims
[
i
];
int
block
=
512
;
int64_t
n
=
slice_size
*
index
_size
;
int64_t
n
=
inner_dim_size
*
select_dim_size
*
outer_dim
_size
;
int64_t
grid
=
(
n
+
block
-
1
)
/
block
;
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录