Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6ab0a6a8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6ab0a6a8
编写于
7月 27, 2021
作者:
W
WangXi
提交者:
GitHub
7月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid parallel] pipeline support adamw and LRScheduler (#34402)
上级
ede001f9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
45 addition
and
5 deletion
+45
-5
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+10
-0
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+8
-0
python/paddle/fluid/tests/unittests/pipeline_mnist.py
python/paddle/fluid/tests/unittests/pipeline_mnist.py
+3
-3
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+20
-1
python/paddle/optimizer/adamw.py
python/paddle/optimizer/adamw.py
+4
-1
未找到文件。
python/paddle/fluid/executor.py
浏览文件 @
6ab0a6a8
...
...
@@ -1664,6 +1664,16 @@ class Executor(object):
print_period
,
fetch_handler
,
use_program_cache
)
from
paddle.optimizer.lr
import
LRScheduler
if
hasattr
(
program
,
'lr_sheduler'
):
lr_sheduler
=
program
.
lr_sheduler
assert
isinstance
(
lr_sheduler
,
LRScheduler
),
"must be LRScheduler"
lr_value
=
lr_sheduler
()
lr_var
=
program
.
global_block
().
vars
[
lr_sheduler
.
_var_name
]
data
=
np
.
array
([
lr_value
]).
astype
(
convert_dtype
(
lr_var
.
dtype
))
tensor
=
core
.
get_variable_tensor
(
scope
,
lr_sheduler
.
_var_name
)
tensor
.
set
(
data
,
self
.
place
)
self
.
_default_executor
.
run_from_dataset
(
trainer_instance
)
if
not
use_program_cache
:
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
6ab0a6a8
...
...
@@ -4634,6 +4634,9 @@ class PipelineOptimizer(object):
op
.
type
==
'elementwise_div'
):
device
=
f
"
{
self
.
_device
}
:all"
op
.
_set_attr
(
self
.
_op_device_key
,
device
)
elif
self
.
_is_weight_decay_op
(
op
)
and
op
.
type
==
'scale'
:
# set AdamW decay_coeff to device:all
op
.
_set_attr
(
self
.
_op_device_key
,
f
"
{
self
.
_device
}
:all"
)
elif
op
.
type
==
"alloc_float_status"
:
op
.
_set_attr
(
self
.
_op_device_key
,
f
"
{
self
.
_device
}
:all"
)
else
:
...
...
@@ -5267,6 +5270,11 @@ class PipelineOptimizer(object):
return
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/regularization"
)
def
_is_weight_decay_op
(
self
,
op
):
# in AdamW namescope is /optimizer_*/weight decay/
return
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
'weight decay'
in
op
.
desc
.
attr
(
"op_namescope"
)
def
_get_input_output_info
(
self
,
block
):
'''
Get info of op input and output.
...
...
python/paddle/fluid/tests/unittests/pipeline_mnist.py
浏览文件 @
6ab0a6a8
...
...
@@ -116,10 +116,10 @@ class TestDistMnist2x2(TestDistRunnerBase):
steps_per_pass
=
10
bd
=
[
steps_per_pass
*
p
for
p
in
passes
]
lr
=
[
base_lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
bd
)
+
1
)]
lr_val
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
bd
,
values
=
lr
)
opt
=
fluid
.
optimizer
.
Momentum
(
lr_val
=
paddle
.
optimizer
.
lr
.
PiecewiseDecay
(
boundaries
=
bd
,
values
=
lr
)
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
lr_val
,
momentum
=
0.9
,
grad_clip
=
fluid
.
clip
.
GradientClipByGlobalNorm
(
clip_norm
=
1.0
))
acc_steps
=
2
# accumulated steps for pipeline
...
...
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
6ab0a6a8
...
...
@@ -96,6 +96,15 @@ class TestDistRunnerBase(object):
current_endpoint
=
current_endpoint
)
return
t
@
staticmethod
def
get_lr_scheduler
(
program
):
lr_sheduler
=
None
if
hasattr
(
program
,
'lr_sheduler'
):
from
paddle.optimizer.lr
import
LRScheduler
lr_sheduler
=
program
.
lr_sheduler
assert
isinstance
(
lr_sheduler
,
LRScheduler
),
"must be LRScheduler"
return
lr_sheduler
def
run_pserver
(
self
,
args
):
self
.
lr
=
args
.
lr
self
.
get_model
(
batch_size
=
args
.
batch_size
)
...
...
@@ -139,11 +148,17 @@ class TestDistRunnerBase(object):
data_loader
.
start
()
print_to_err
(
type
(
self
).
__name__
,
"begin to train on trainer"
)
out_losses
=
[]
main_program
=
fluid
.
default_main_program
()
lr_sheduler
=
self
.
get_lr_scheduler
(
main_program
)
for
i
in
six
.
moves
.
xrange
(
RUN_STEP
):
loss
=
exe
.
run
(
fluid
.
default_main_program
()
,
fetch_list
=
[
avg_cost
])
loss
=
exe
.
run
(
main_program
,
fetch_list
=
[
avg_cost
])
loss
=
loss
[
0
]
if
loss
else
None
out_losses
.
append
(
loss
)
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
if
lr_sheduler
is
not
None
:
lr_sheduler
.
step
()
data_loader
.
reset
()
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
...
...
@@ -494,6 +509,7 @@ class TestDistRunnerBase(object):
else
:
return
origin_batch
lr_scheduler
=
self
.
get_lr_scheduler
(
trainer_prog
)
print_to_err
(
type
(
self
).
__name__
,
"begin to train on trainer"
)
out_losses
=
[]
for
i
in
six
.
moves
.
xrange
(
RUN_STEP
):
...
...
@@ -502,6 +518,9 @@ class TestDistRunnerBase(object):
feed
=
feeder
.
feed
(
get_data
()))
out_losses
.
append
(
loss
[
0
])
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
if
lr_scheduler
is
not
None
:
lr_scheduler
.
step
()
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
print_to_out
(
out_losses
)
...
...
python/paddle/optimizer/adamw.py
浏览文件 @
6ab0a6a8
...
...
@@ -160,6 +160,7 @@ class AdamW(Adam):
self
.
_apply_decay_param_fun
=
apply_decay_param_fun
self
.
_coeff
=
coeff
self
.
_lr_to_coeff
=
dict
()
super
(
AdamW
,
self
).
__init__
(
learning_rate
=
learning_rate
,
parameters
=
parameters
,
...
...
@@ -211,7 +212,9 @@ class AdamW(Adam):
# we do this in _create_optimization_pass
decay_coeff
=
self
.
_lr_to_coeff
.
get
(
learning_rate
,
None
)
if
decay_coeff
is
None
:
decay_coeff
=
1.0
-
learning_rate
*
self
.
_coeff
# NOTE(wangxi): for pipeline to set device:all
with
paddle
.
static
.
device_guard
(
None
):
decay_coeff
=
1.0
-
learning_rate
*
self
.
_coeff
self
.
_lr_to_coeff
[
learning_rate
]
=
decay_coeff
find_master
=
(
self
.
_multi_precision
and
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录