Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6f1bb3d6
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6f1bb3d6
编写于
12月 28, 2021
作者:
G
Guoxia Wang
提交者:
GitHub
12月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix adamw epsilon in cuda kernel (#37746)
上级
340dfb26
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
27 addition
and
25 deletion
+27
-25
paddle/fluid/operators/optimizers/adamw_op.cu
paddle/fluid/operators/optimizers/adamw_op.cu
+27
-25
未找到文件。
paddle/fluid/operators/optimizers/adamw_op.cu
浏览文件 @
6f1bb3d6
...
...
@@ -27,25 +27,25 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
T
*
param_out
,
const
MT
*
master_param
,
MT
*
master_param_out
,
int
ndim
)
{
MT
lr
=
*
lr_
*
lr_ratio
;
MT
lr_orig
=
lr
;
MT
beta1_pow
=
beta1_pow_
;
MT
beta2_pow
=
beta2_pow_
;
lr
*=
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
)
/
(
static_cast
<
MT
>
(
1.0
)
-
beta1_pow
);
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(;
id
<
ndim
;
id
+=
gridDim
.
x
*
blockDim
.
x
)
{
MT
p
=
master_param
?
master_param
[
id
]
:
static_cast
<
MT
>
(
param
[
id
]);
MT
g
=
static_cast
<
MT
>
(
grad
[
id
]);
MT
mom1
=
moment1
[
id
];
MT
mom2
=
moment2
[
id
];
MT
mom1
=
static_cast
<
MT
>
(
moment1
[
id
]);
MT
mom2
=
static_cast
<
MT
>
(
moment2
[
id
]);
p
*=
(
static_cast
<
MT
>
(
1.0
)
-
lr
*
coeff
);
mom1
=
beta1
*
mom1
+
(
static_cast
<
MT
>
(
1.0
)
-
beta1
)
*
g
;
mom2
=
beta2
*
mom2
+
(
static_cast
<
MT
>
(
1.0
)
-
beta2
)
*
g
*
g
;
p
-=
lr_orig
*
coeff
*
p
;
p
-=
lr
*
(
mom1
/
(
sqrt
(
mom2
)
+
epsilon
*
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
)));
MT
denom
=
(
sqrt
(
mom2
)
/
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
))
+
epsilon
;
p
+=
(
mom1
/
denom
)
*
(
-
(
lr
/
(
static_cast
<
MT
>
(
1.0
)
-
beta1_pow
)));
moment1_out
[
id
]
=
mom1
;
moment2_out
[
id
]
=
mom2
;
...
...
@@ -63,13 +63,9 @@ __global__ void AdamWKernelMEM(
MT
*
moment2_out
,
const
MT
*
lr_
,
const
T
*
grad
,
const
T
*
param
,
T
*
param_out
,
const
MT
*
master_param
,
MT
*
master_param_out
,
int
ndim
)
{
MT
lr
=
*
lr_
*
lr_ratio
;
MT
lr_orig
=
lr
;
MT
beta1_pow
=
*
beta1_pow_
;
MT
beta2_pow
=
*
beta2_pow_
;
lr
*=
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
)
/
(
static_cast
<
MT
>
(
1.0
)
-
beta1_pow
);
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(;
id
<
ndim
;
id
+=
gridDim
.
x
*
blockDim
.
x
)
{
...
...
@@ -77,11 +73,15 @@ __global__ void AdamWKernelMEM(
MT
g
=
static_cast
<
MT
>
(
grad
[
id
]);
MT
mom1
=
static_cast
<
MT
>
(
moment1
[
id
]);
MT
mom2
=
static_cast
<
MT
>
(
moment2
[
id
]);
p
*=
(
static_cast
<
MT
>
(
1.0
)
-
lr
*
coeff
);
mom1
=
beta1
*
mom1
+
(
static_cast
<
MT
>
(
1.0
)
-
beta1
)
*
g
;
mom2
=
beta2
*
mom2
+
(
static_cast
<
MT
>
(
1.0
)
-
beta2
)
*
g
*
g
;
p
-=
lr_orig
*
coeff
*
p
;
p
-=
lr
*
(
mom1
/
(
sqrt
(
mom2
)
+
epsilon
*
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
)));
MT
denom
=
(
sqrt
(
mom2
)
/
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
))
+
epsilon
;
p
+=
(
mom1
/
denom
)
*
(
-
(
lr
/
(
static_cast
<
MT
>
(
1.0
)
-
beta1_pow
)));
moment1_out
[
id
]
=
mom1
;
moment2_out
[
id
]
=
mom2
;
...
...
@@ -109,10 +109,6 @@ __global__ void SparseAdamWCUDAKernelREG(
int
ndim
)
{
int
id
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
MT
lr
=
*
lr_
*
lr_ratio
;
MT
lr_orig
=
lr
;
lr
*=
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
)
/
(
static_cast
<
MT
>
(
1.0
)
-
beta1_pow
);
for
(;
id
<
ndim
;
id
+=
blockDim
.
x
*
gridDim
.
x
)
{
auto
row_idx
=
...
...
@@ -120,17 +116,23 @@ __global__ void SparseAdamWCUDAKernelREG(
if
(
lazy_mode
&&
row_idx
<
0
)
{
return
;
}
else
{
MT
mom1
=
mom1_
[
id
];
MT
mom2
=
mom2_
[
id
];
MT
mom1
=
static_cast
<
MT
>
(
mom1_
[
id
]);
MT
mom2
=
static_cast
<
MT
>
(
mom2_
[
id
]);
MT
p
=
master_param
?
master_param
[
id
]
:
static_cast
<
MT
>
(
param_
[
id
]);
MT
g
=
row_idx
>=
0
?
static_cast
<
MT
>
(
grad_
[
row_idx
*
row_numel
+
id
%
row_numel
])
:
static_cast
<
MT
>
(
0
);
p
*=
(
static_cast
<
MT
>
(
1.0
)
-
lr
*
coeff
);
mom1
=
beta1
*
mom1
+
(
static_cast
<
MT
>
(
1.0
)
-
beta1
)
*
g
;
mom2
=
beta2
*
mom2
+
(
static_cast
<
MT
>
(
1.0
)
-
beta2
)
*
g
*
g
;
p
-=
lr_orig
*
coeff
*
p
;
p
-=
lr
*
(
mom1
/
(
sqrt
(
mom2
)
+
epsilon
*
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
)));
MT
denom
=
(
sqrt
(
mom2
)
/
sqrt
(
static_cast
<
MT
>
(
1.0
)
-
beta2_pow
))
+
epsilon
;
p
+=
(
mom1
/
denom
)
*
(
-
(
lr
/
(
static_cast
<
MT
>
(
1.0
)
-
beta1_pow
)));
// Write back to global memory
mom1_out_
[
id
]
=
mom1
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录