Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
244e7546
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
244e7546
编写于
2月 06, 2023
作者:
W
wanghuancoder
提交者:
GitHub
2月 06, 2023
1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine optimizer create accumulators (#50188)
* refine optimizer create accumulators * refine
上级
eb8353a4
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
33 addition
and
0 deletion
+33
-0
python/paddle/optimizer/adadelta.py
python/paddle/optimizer/adadelta.py
+3
-0
python/paddle/optimizer/adagrad.py
python/paddle/optimizer/adagrad.py
+3
-0
python/paddle/optimizer/adam.py
python/paddle/optimizer/adam.py
+4
-0
python/paddle/optimizer/adamax.py
python/paddle/optimizer/adamax.py
+3
-0
python/paddle/optimizer/adamw.py
python/paddle/optimizer/adamw.py
+5
-0
python/paddle/optimizer/lamb.py
python/paddle/optimizer/lamb.py
+4
-0
python/paddle/optimizer/momentum.py
python/paddle/optimizer/momentum.py
+4
-0
python/paddle/optimizer/optimizer.py
python/paddle/optimizer/optimizer.py
+1
-0
python/paddle/optimizer/rmsprop.py
python/paddle/optimizer/rmsprop.py
+3
-0
python/paddle/optimizer/sgd.py
python/paddle/optimizer/sgd.py
+3
-0
未找到文件。
python/paddle/optimizer/adadelta.py
浏览文件 @
244e7546
...
@@ -145,8 +145,11 @@ class Adadelta(Optimizer):
...
@@ -145,8 +145,11 @@ class Adadelta(Optimizer):
parameters
=
parameters
.
get
(
'params'
)
parameters
=
parameters
.
get
(
'params'
)
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
self
.
_add_accumulator
(
self
.
_avg_squared_grad_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_avg_squared_grad_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_avg_squared_update_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_avg_squared_update_acc_str
,
p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
if
isinstance
(
param_and_grad
,
dict
):
if
isinstance
(
param_and_grad
,
dict
):
...
...
python/paddle/optimizer/adagrad.py
浏览文件 @
244e7546
...
@@ -139,11 +139,14 @@ class Adagrad(Optimizer):
...
@@ -139,11 +139,14 @@ class Adagrad(Optimizer):
parameters
=
self
.
_update_param_group
(
parameters
)
parameters
=
self
.
_update_param_group
(
parameters
)
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
self
.
_add_accumulator
(
self
.
_add_accumulator
(
self
.
_moment_acc_str
,
self
.
_moment_acc_str
,
p
,
p
,
fill_value
=
self
.
initial_accumulator_value
,
fill_value
=
self
.
initial_accumulator_value
,
)
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
assert
isinstance
(
block
,
framework
.
Block
)
assert
isinstance
(
block
,
framework
.
Block
)
...
...
python/paddle/optimizer/adam.py
浏览文件 @
244e7546
...
@@ -317,9 +317,12 @@ class Adam(Optimizer):
...
@@ -317,9 +317,12 @@ class Adam(Optimizer):
# Create accumulator tensors for first and second moments
# Create accumulator tensors for first and second moments
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
if
self
.
_multi_precision
and
self
.
_is_dtype_fp16_or_bf16
(
p
.
dtype
):
if
self
.
_multi_precision
and
self
.
_is_dtype_fp16_or_bf16
(
p
.
dtype
):
master_p
=
self
.
_create_master_weight
(
p
)
master_p
=
self
.
_create_master_weight
(
p
)
self
.
_add_moments_pows
(
master_p
)
self
.
_add_moments_pows
(
master_p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
continue
continue
if
(
if
(
self
.
_is_dtype_fp16_or_bf16
(
p
.
dtype
)
self
.
_is_dtype_fp16_or_bf16
(
p
.
dtype
)
...
@@ -330,6 +333,7 @@ class Adam(Optimizer):
...
@@ -330,6 +333,7 @@ class Adam(Optimizer):
"Consider using multi_precision=True option of the Adam optimizer."
"Consider using multi_precision=True option of the Adam optimizer."
)
)
self
.
_add_moments_pows
(
p
)
self
.
_add_moments_pows
(
p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
assert
isinstance
(
block
,
framework
.
Block
)
assert
isinstance
(
block
,
framework
.
Block
)
...
...
python/paddle/optimizer/adamax.py
浏览文件 @
244e7546
...
@@ -176,6 +176,8 @@ class Adamax(Optimizer):
...
@@ -176,6 +176,8 @@ class Adamax(Optimizer):
# Create accumulator tensors for first moment and infinity norm
# Create accumulator tensors for first moment and infinity norm
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
self
.
_add_accumulator
(
self
.
_moment_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_moment_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_inf_norm_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_inf_norm_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_add_accumulator
(
...
@@ -184,6 +186,7 @@ class Adamax(Optimizer):
...
@@ -184,6 +186,7 @@ class Adamax(Optimizer):
fill_value
=
self
.
_beta1
,
fill_value
=
self
.
_beta1
,
shape
=
[
1
],
shape
=
[
1
],
)
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
assert
isinstance
(
block
,
framework
.
Block
)
assert
isinstance
(
block
,
framework
.
Block
)
...
...
python/paddle/optimizer/adamw.py
浏览文件 @
244e7546
...
@@ -281,6 +281,7 @@ class AdamW(Optimizer):
...
@@ -281,6 +281,7 @@ class AdamW(Optimizer):
self
.
_use_multi_tensor
=
None
self
.
_use_multi_tensor
=
None
self
.
regularization
=
None
self
.
regularization
=
None
self
.
_auxiliary_vars
=
{}
self
.
_auxiliary_vars
=
{}
self
.
_already_create_accumulater
=
set
()
def
_set_auxiliary_var
(
self
,
key
,
val
):
def
_set_auxiliary_var
(
self
,
key
,
val
):
self
.
_auxiliary_vars
[
key
]
=
val
self
.
_auxiliary_vars
[
key
]
=
val
...
@@ -422,9 +423,12 @@ class AdamW(Optimizer):
...
@@ -422,9 +423,12 @@ class AdamW(Optimizer):
# Create accumulator tensors for first and second moments
# Create accumulator tensors for first and second moments
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
if
self
.
_multi_precision
and
self
.
_is_dtype_fp16_or_bf16
(
p
.
dtype
):
if
self
.
_multi_precision
and
self
.
_is_dtype_fp16_or_bf16
(
p
.
dtype
):
master_p
=
self
.
_create_master_weight
(
p
)
master_p
=
self
.
_create_master_weight
(
p
)
self
.
_add_moments_pows
(
master_p
)
self
.
_add_moments_pows
(
master_p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
continue
continue
if
(
if
(
self
.
_is_dtype_fp16_or_bf16
(
p
.
dtype
)
self
.
_is_dtype_fp16_or_bf16
(
p
.
dtype
)
...
@@ -435,6 +439,7 @@ class AdamW(Optimizer):
...
@@ -435,6 +439,7 @@ class AdamW(Optimizer):
"Consider using multi_precision=True option of the Adam optimizer."
"Consider using multi_precision=True option of the Adam optimizer."
)
)
self
.
_add_moments_pows
(
p
)
self
.
_add_moments_pows
(
p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
assert
isinstance
(
block
,
framework
.
Block
)
assert
isinstance
(
block
,
framework
.
Block
)
...
...
python/paddle/optimizer/lamb.py
浏览文件 @
244e7546
...
@@ -190,11 +190,15 @@ class Lamb(Optimizer):
...
@@ -190,11 +190,15 @@ class Lamb(Optimizer):
# Create accumulator tensors for first and second moments
# Create accumulator tensors for first and second moments
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
if
self
.
_multi_precision
and
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
self
.
_multi_precision
and
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
master_p
=
self
.
_create_master_weight
(
p
)
master_p
=
self
.
_create_master_weight
(
p
)
self
.
_add_moments_pows
(
master_p
)
self
.
_add_moments_pows
(
master_p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
else
:
else
:
self
.
_add_moments_pows
(
p
)
self
.
_add_moments_pows
(
p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
def
_get_accumulator
(
self
,
name
,
param
):
def
_get_accumulator
(
self
,
name
,
param
):
"""Utility function to fetch an accumulator for a parameter
"""Utility function to fetch an accumulator for a parameter
...
...
python/paddle/optimizer/momentum.py
浏览文件 @
244e7546
...
@@ -270,9 +270,12 @@ class Momentum(Optimizer):
...
@@ -270,9 +270,12 @@ class Momentum(Optimizer):
parameters
=
self
.
_update_param_group
(
parameters
)
parameters
=
self
.
_update_param_group
(
parameters
)
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
if
self
.
_multi_precision
and
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
self
.
_multi_precision
and
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
master_p
=
self
.
_create_master_weight
(
p
)
master_p
=
self
.
_create_master_weight
(
p
)
self
.
_add_accumulator
(
self
.
_velocity_acc_str
,
master_p
)
self
.
_add_accumulator
(
self
.
_velocity_acc_str
,
master_p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
continue
continue
if
(
if
(
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
...
@@ -283,6 +286,7 @@ class Momentum(Optimizer):
...
@@ -283,6 +286,7 @@ class Momentum(Optimizer):
"Consider using multi_precision=True option of the Momentum optimizer."
"Consider using multi_precision=True option of the Momentum optimizer."
)
)
self
.
_add_accumulator
(
self
.
_velocity_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_velocity_acc_str
,
p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
def
_create_regularization_of_grad
(
self
,
param
,
grad
,
regularization
=
None
):
def
_create_regularization_of_grad
(
self
,
param
,
grad
,
regularization
=
None
):
"""Create and add backward regularization Operators
"""Create and add backward regularization Operators
...
...
python/paddle/optimizer/optimizer.py
浏览文件 @
244e7546
...
@@ -275,6 +275,7 @@ class Optimizer:
...
@@ -275,6 +275,7 @@ class Optimizer:
self
.
_param_dict
=
self
.
_create_multi_tensor_dict
()
self
.
_param_dict
=
self
.
_create_multi_tensor_dict
()
self
.
_auxiliary_vars
=
{}
self
.
_auxiliary_vars
=
{}
self
.
_already_create_accumulater
=
set
()
def
_set_auxiliary_var
(
self
,
key
,
val
):
def
_set_auxiliary_var
(
self
,
key
,
val
):
self
.
_auxiliary_vars
[
key
]
=
val
self
.
_auxiliary_vars
[
key
]
=
val
...
...
python/paddle/optimizer/rmsprop.py
浏览文件 @
244e7546
...
@@ -199,9 +199,12 @@ class RMSProp(Optimizer):
...
@@ -199,9 +199,12 @@ class RMSProp(Optimizer):
parameters
=
parameters
.
get
(
'params'
)
parameters
=
parameters
.
get
(
'params'
)
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
self
.
_add_accumulator
(
self
.
_momentum_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_momentum_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_mean_square_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_mean_square_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_mean_grad_acc_str
,
p
)
self
.
_add_accumulator
(
self
.
_mean_grad_acc_str
,
p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
if
not
isinstance
(
block
,
framework
.
Block
):
if
not
isinstance
(
block
,
framework
.
Block
):
...
...
python/paddle/optimizer/sgd.py
浏览文件 @
244e7546
...
@@ -129,8 +129,11 @@ class SGD(Optimizer):
...
@@ -129,8 +129,11 @@ class SGD(Optimizer):
# Create accumulator tensors for first and second moments
# Create accumulator tensors for first and second moments
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
name
in
self
.
_already_create_accumulater
:
continue
if
self
.
_multi_precision
and
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
self
.
_multi_precision
and
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
master_p
=
self
.
_create_master_weight
(
p
)
master_p
=
self
.
_create_master_weight
(
p
)
self
.
_already_create_accumulater
.
add
(
p
.
name
)
continue
continue
if
(
if
(
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
...
...
saxon_zh
@saxon_zh
mentioned in commit
8a503522
·
2月 25, 2023
mentioned in commit
8a503522
mentioned in commit 8a50352216156c8cd723ed2fc482b611e552915c
开关提交列表
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录