Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
37f43ebc
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
37f43ebc
编写于
12月 10, 2021
作者:
L
Leo Chen
提交者:
GitHub
12月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix int32 overflow in cuda kernel loop (#38007)
上级
dabf8152
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
3 addition
and
6 deletion
+3
-6
paddle/fluid/operators/label_smooth_op.cu
paddle/fluid/operators/label_smooth_op.cu
+3
-6
未找到文件。
paddle/fluid/operators/label_smooth_op.cu
浏览文件 @
37f43ebc
...
...
@@ -21,8 +21,7 @@ template <typename T>
__global__
void
LabelSmoothRunOriginKernel
(
const
int
N
,
const
float
epsilon
,
const
int
label_dim
,
const
T
*
src
,
T
*
dst
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
CUDA_KERNEL_LOOP
(
idx
,
N
)
{
dst
[
idx
]
=
static_cast
<
T
>
(
1
-
epsilon
)
*
src
[
idx
]
+
static_cast
<
T
>
(
epsilon
/
label_dim
);
}
...
...
@@ -32,8 +31,7 @@ template <typename T>
__global__
void
LabelSmoothRunDistKernel
(
const
int
N
,
const
float
epsilon
,
const
int
dist_numel
,
const
T
*
src
,
const
T
*
dist_data
,
T
*
dst
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
CUDA_KERNEL_LOOP
(
idx
,
N
)
{
int
dist_idx
=
idx
%
dist_numel
;
dst
[
idx
]
=
static_cast
<
T
>
(
1
-
epsilon
)
*
src
[
idx
]
+
static_cast
<
T
>
(
epsilon
)
*
dist_data
[
dist_idx
];
...
...
@@ -43,8 +41,7 @@ __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
template
<
typename
T
>
__global__
void
LabelSmoothGradRunKernel
(
const
int
N
,
const
float
epsilon
,
const
T
*
src
,
T
*
dst
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
CUDA_KERNEL_LOOP
(
idx
,
N
)
{
dst
[
idx
]
=
static_cast
<
T
>
(
1
-
epsilon
)
*
src
[
idx
];
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录