Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
bb6bd223
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bb6bd223
编写于
8月 18, 2022
作者:
Z
zhaoyingli
提交者:
GitHub
8月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AutoParallel] support ClipGradByGlobalNorm (#45205)
* add clip_grad * fix comments * add unittest * update logger
上级
d257acc6
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
334 addition
and
90 deletion
+334
-90
python/paddle/distributed/auto_parallel/completion.py
python/paddle/distributed/auto_parallel/completion.py
+90
-24
python/paddle/distributed/auto_parallel/dist_context.py
python/paddle/distributed/auto_parallel/dist_context.py
+48
-39
python/paddle/distributed/auto_parallel/engine.py
python/paddle/distributed/auto_parallel/engine.py
+32
-18
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
+6
-3
python/paddle/distributed/auto_parallel/reshard.py
python/paddle/distributed/auto_parallel/reshard.py
+8
-3
python/paddle/distributed/auto_parallel/utils.py
python/paddle/distributed/auto_parallel/utils.py
+5
-0
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
.../paddle/fluid/tests/unittests/auto_parallel/engine_api.py
+1
-3
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py
.../fluid/tests/unittests/auto_parallel/test_dist_context.py
+27
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py
.../fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py
+116
-0
未找到文件。
python/paddle/distributed/auto_parallel/completion.py
浏览文件 @
bb6bd223
...
@@ -19,7 +19,7 @@ import time
...
@@ -19,7 +19,7 @@ import time
from
paddle.fluid
import
core
from
paddle.fluid
import
core
from
paddle.fluid
import
framework
from
paddle.fluid
import
framework
from
.utils
import
print_program_with_dist_attr
from
.utils
import
print_program_with_dist_attr
,
_is_gradient_clip_op
from
.operators
import
find_compatible_distributed_operator_impls
from
.operators
import
find_compatible_distributed_operator_impls
from
.dist_context
import
get_default_distributed_context
,
_node_id
from
.dist_context
import
get_default_distributed_context
,
_node_id
from
.dist_tensor
import
DistributedTensor
from
.dist_tensor
import
DistributedTensor
...
@@ -1319,26 +1319,70 @@ class Completer:
...
@@ -1319,26 +1319,70 @@ class Completer:
# TODO to add attribute for moment var
# TODO to add attribute for moment var
op
=
ops
[
idx
]
op
=
ops
[
idx
]
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Optimize
):
if
int
(
op
.
attr
(
'op_role'
))
==
int
(
OpRole
.
Optimize
):
if
op
.
type
==
"clip_by_norm"
:
# TODO:
param_grad
=
vars
[
op
.
input
(
"X"
)[
0
]]
# 1. move `generate_optimizer` before `partitioner`
param_grad_dist_attr
=
self
.
_dist_context
.
get_tensor_dist_attr_for_program
(
# 2. implement grad_clip completion by `dist_op`
param_grad
)
# 3. allreduce dist_gloabl_norm (mp-group) and no_dist_global_norm (pp-group, sharding-group)
assert
param_grad_dist_attr
is
not
None
if
_is_gradient_clip_op
(
op
):
ref_process_mesh
=
param_grad_dist_attr
.
process_mesh
if
op
.
type
in
[
ref_dims_mapping
=
param_grad_dist_attr
.
dims_mapping
"sum"
,
"sqrt"
,
"fill_constant"
,
"elementwise_max"
,
"elementwise_div"
out
=
vars
[
op
.
output
(
"Out"
)[
0
]]
]:
out_dist_attr
=
TensorDistributedAttribute
()
op_dist_attr
=
OperatorDistributedAttribute
()
out_dist_attr
.
process_mesh
=
ref_process_mesh
op_dist_attr
.
process_mesh
=
world_ranks
out_dist_attr
.
dims_mapping
=
ref_dims_mapping
for
in_name
in
op
.
input_arg_names
:
self
.
_dist_context
.
set_tensor_dist_attr_for_program
(
in_var
=
vars
[
in_name
]
out
,
out_dist_attr
)
in_dist_attr
=
self
.
_dist_context
.
get_tensor_dist_attr_for_program
(
in_var
)
op_dist_attr
.
set_input_dist_attr
(
in_name
,
in_dist_attr
)
for
out_name
in
op
.
output_arg_names
:
out_var
=
vars
[
out_name
]
out_dist_attr
=
TensorDistributedAttribute
()
out_dist_attr
.
process_mesh
=
world_ranks
out_dist_attr
.
dims_mapping
=
[
-
1
for
_
in
range
(
len
(
out_var
.
shape
))
]
self
.
_dist_context
.
set_tensor_dist_attr_for_program
(
out_var
,
out_dist_attr
)
op_dist_attr
.
set_output_dist_attr
(
out_name
,
out_dist_attr
)
remove_no_need_in_op
(
op
,
self
.
_dist_context
)
else
:
in_var
=
vars
[
op
.
input
(
"X"
)[
0
]]
in_dist_attr
=
self
.
_dist_context
.
get_tensor_dist_attr_for_program
(
in_var
)
assert
in_dist_attr
is
not
None
ref_process_mesh
=
in_dist_attr
.
process_mesh
ref_dims_mapping
=
in_dist_attr
.
dims_mapping
if
op
.
type
==
"cast"
and
ops
[
idx
+
1
].
type
==
"elementwise_mul"
:
ref_var
=
vars
[
ops
[
idx
+
1
].
input
(
"X"
)[
0
]]
ref_dist_attr
=
self
.
_dist_context
.
get_tensor_dist_attr_for_program
(
ref_var
)
assert
ref_dist_attr
is
not
None
ref_process_mesh
=
ref_dist_attr
.
process_mesh
out_var
=
vars
[
op
.
output
(
"Out"
)[
0
]]
out_dist_attr
=
TensorDistributedAttribute
()
out_dist_attr
.
process_mesh
=
ref_process_mesh
if
out_var
.
shape
==
in_var
.
shape
:
out_dist_attr
.
dims_mapping
=
ref_dims_mapping
else
:
assert
len
(
out_var
.
shape
)
==
1
and
out_var
.
shape
[
0
]
==
1
out_dist_attr
.
dims_mapping
=
[
-
1
]
self
.
_dist_context
.
set_tensor_dist_attr_for_program
(
out_var
,
out_dist_attr
)
op_dist_attr
=
OperatorDistributedAttribute
()
op_dist_attr
.
process_mesh
=
ref_process_mesh
op_dist_attr
.
set_input_dist_attr
(
in_var
.
name
,
in_dist_attr
)
op_dist_attr
.
set_output_dist_attr
(
out_var
.
name
,
out_dist_attr
)
op_dist_attr
=
OperatorDistributedAttribute
()
op_dist_attr
.
process_mesh
=
ref_process_mesh
op_dist_attr
.
set_input_dist_attr
(
param_grad
.
name
,
param_grad_dist_attr
)
op_dist_attr
.
set_output_dist_attr
(
out
.
name
,
out_dist_attr
)
self
.
_dist_context
.
set_op_dist_attr_for_program
(
self
.
_dist_context
.
set_op_dist_attr_for_program
(
op
,
op_dist_attr
)
op
,
op_dist_attr
)
...
@@ -1383,11 +1427,17 @@ class Completer:
...
@@ -1383,11 +1427,17 @@ class Completer:
for
input_name
in
op
.
desc
.
input_names
():
for
input_name
in
op
.
desc
.
input_names
():
if
input_name
in
[
if
input_name
in
[
'Param'
,
'Grad'
,
'LearningRate'
,
"SkipUpdate"
,
'Param'
,
"Beta1Tensor"
,
"Beta2Tensor"
,
"EpsilonTensor"
,
'Grad'
,
"MasterParam"
'LearningRate'
,
"SkipUpdate"
,
"Beta1Tensor"
,
"Beta2Tensor"
,
"EpsilonTensor"
,
]:
]:
continue
continue
if
len
(
op
.
desc
.
input
(
input_name
))
==
0
:
continue
assert
len
(
op
.
desc
.
input
(
input_name
))
==
1
assert
len
(
op
.
desc
.
input
(
input_name
))
==
1
input_var
=
vars
[
op
.
desc
.
input
(
input_name
)[
0
]]
input_var
=
vars
[
op
.
desc
.
input
(
input_name
)[
0
]]
...
@@ -1400,7 +1450,6 @@ class Completer:
...
@@ -1400,7 +1450,6 @@ class Completer:
op_dist_attr
.
set_output_dims_mapping
(
op_dist_attr
.
set_output_dims_mapping
(
input_var
.
name
,
[
-
1
])
input_var
.
name
,
[
-
1
])
else
:
else
:
assert
"Moment"
in
input_name
or
"Velocity"
in
input_name
input_var_attr
.
dims_mapping
=
ref_dims_mapping
input_var_attr
.
dims_mapping
=
ref_dims_mapping
op_dist_attr
.
set_input_dims_mapping
(
op_dist_attr
.
set_input_dims_mapping
(
input_var
.
name
,
ref_dims_mapping
)
input_var
.
name
,
ref_dims_mapping
)
...
@@ -1481,3 +1530,20 @@ class Completer:
...
@@ -1481,3 +1530,20 @@ class Completer:
break
break
else
:
else
:
dist_op
.
dist_attr
=
backup_op_dist_attr
dist_op
.
dist_attr
=
backup_op_dist_attr
def
remove_no_need_in_op
(
op
,
dist_context
):
if
op
.
type
==
"fill_constant"
:
return
filter_vars
=
[]
main_block
=
op
.
block
rank_id
=
dist_context
.
dist_op_context
.
rank_id
for
varname
in
op
.
input
(
"X"
):
if
rank_id
in
dist_context
.
get_tensor_dist_attr_for_program
(
main_block
.
var
(
varname
)).
process_mesh
.
processes
:
filter_vars
.
append
(
varname
)
if
not
filter_vars
:
return
op
.
desc
.
set_input
(
'X'
,
filter_vars
)
python/paddle/distributed/auto_parallel/dist_context.py
浏览文件 @
bb6bd223
...
@@ -68,7 +68,6 @@ class DistributedContext:
...
@@ -68,7 +68,6 @@ class DistributedContext:
self
.
_original_serial_loss
=
serial_loss
self
.
_original_serial_loss
=
serial_loss
self
.
_original_serial_feed_vars
=
feed_vars
self
.
_original_serial_feed_vars
=
feed_vars
self
.
_original_serial_fetch_vars
=
fetch_vars
self
.
_original_serial_fetch_vars
=
fetch_vars
self
.
_original_serial_optimizer
=
serial_optimizer
# Data members related to programs (changed)
# Data members related to programs (changed)
self
.
_serial_main_program
=
None
self
.
_serial_main_program
=
None
...
@@ -77,6 +76,7 @@ class DistributedContext:
...
@@ -77,6 +76,7 @@ class DistributedContext:
self
.
_serial_optimizer
=
None
self
.
_serial_optimizer
=
None
self
.
_serial_feed_vars
=
{}
self
.
_serial_feed_vars
=
{}
self
.
_serial_fetch_vars
=
{}
self
.
_serial_fetch_vars
=
{}
self
.
_lr_optimizer
=
None
# record the optimzier holding lr_scheduler
# Data members related to the program
# Data members related to the program
self
.
_dist_tensors_for_program
=
{}
self
.
_dist_tensors_for_program
=
{}
...
@@ -126,7 +126,7 @@ class DistributedContext:
...
@@ -126,7 +126,7 @@ class DistributedContext:
self
.
_data_parallel
=
False
self
.
_data_parallel
=
False
# flag whether using `to_static`
# flag whether using `to_static`
self
.
_dygraph_mode
=
Tru
e
self
.
_dygraph_mode
=
Fals
e
@
property
@
property
def
serial_main_program
(
self
):
def
serial_main_program
(
self
):
...
@@ -235,31 +235,20 @@ class DistributedContext:
...
@@ -235,31 +235,20 @@ class DistributedContext:
if
dist
:
if
dist
:
self
.
_backup_dist_info
(
dist_mode
)
self
.
_backup_dist_info
(
dist_mode
)
def
_restore_serial_info
(
self
,
mode
=
"to_backup"
):
def
_restore_serial_loss
(
self
):
if
mode
==
"to_backup"
:
self
.
_serial_main_program
=
self
.
_backup_serial_main_program_stack
.
pop
(
)
self
.
_serial_startup_program
=
self
.
_backup_serial_startup_program_stack
.
pop
(
)
elif
mode
==
"to_original"
:
assert
self
.
_original_serial_main_program
is
not
None
assert
self
.
_original_serial_startup_program
is
not
None
self
.
_serial_main_program
=
self
.
_original_serial_main_program
.
clone
(
)
self
.
_serial_startup_program
=
self
.
_original_serial_startup_program
.
clone
(
)
self
.
_serial_optimizer
=
self
.
_original_serial_optimizer
if
self
.
_original_serial_loss
:
if
self
.
_original_serial_loss
:
if
isinstance
(
self
.
_original_serial_loss
,
list
):
if
isinstance
(
self
.
_original_serial_loss
,
list
):
assert
len
(
self
.
_original_serial_loss
)
==
1
if
len
(
self
.
_original_serial_loss
)
==
1
:
loss
=
self
.
_original_serial_loss
[
0
]
loss
=
self
.
_original_serial_loss
[
0
]
block_idx
=
loss
.
block
.
idx
block_idx
=
loss
.
block
.
idx
var_name
=
loss
.
name
var_name
=
loss
.
name
var
=
self
.
_serial_main_program
.
blocks
[
var
=
self
.
_serial_main_program
.
blocks
[
block_idx
].
_var_recursive
(
var_name
)
block_idx
].
_var_recursive
(
var_name
)
self
.
_serial_loss
=
var
self
.
_serial_loss
=
var
elif
len
(
self
.
_original_serial_loss
)
==
0
:
self
.
_serial_loss
=
[]
else
:
raise
ValueError
(
"multi loss vars are not supported."
)
else
:
else
:
block_idx
=
self
.
_original_serial_loss
.
block
.
idx
block_idx
=
self
.
_original_serial_loss
.
block
.
idx
var_name
=
self
.
_original_serial_loss
.
name
var_name
=
self
.
_original_serial_loss
.
name
...
@@ -267,6 +256,7 @@ class DistributedContext:
...
@@ -267,6 +256,7 @@ class DistributedContext:
block_idx
].
_var_recursive
(
var_name
)
block_idx
].
_var_recursive
(
var_name
)
self
.
_serial_loss
=
var
self
.
_serial_loss
=
var
def
_restore_serial_feed_vars
(
self
):
for
key
,
var_list
in
self
.
_original_serial_feed_vars
.
items
():
for
key
,
var_list
in
self
.
_original_serial_feed_vars
.
items
():
new_var_list
=
[]
new_var_list
=
[]
for
var
in
var_list
:
for
var
in
var_list
:
...
@@ -277,6 +267,7 @@ class DistributedContext:
...
@@ -277,6 +267,7 @@ class DistributedContext:
new_var_list
.
append
(
var
)
new_var_list
.
append
(
var
)
self
.
_serial_feed_vars
[
key
]
=
new_var_list
self
.
_serial_feed_vars
[
key
]
=
new_var_list
def
_restore_serial_fetch_vars
(
self
):
for
key
,
var_list
in
self
.
_original_serial_fetch_vars
.
items
():
for
key
,
var_list
in
self
.
_original_serial_fetch_vars
.
items
():
new_var_list
=
[]
new_var_list
=
[]
for
var
in
var_list
:
for
var
in
var_list
:
...
@@ -287,6 +278,24 @@ class DistributedContext:
...
@@ -287,6 +278,24 @@ class DistributedContext:
new_var_list
.
append
(
var
)
new_var_list
.
append
(
var
)
self
.
_serial_fetch_vars
[
key
]
=
new_var_list
self
.
_serial_fetch_vars
[
key
]
=
new_var_list
def
_restore_serial_info
(
self
,
mode
=
"to_backup"
):
if
mode
==
"to_backup"
:
self
.
_serial_main_program
=
self
.
_backup_serial_main_program_stack
.
pop
(
)
self
.
_serial_startup_program
=
self
.
_backup_serial_startup_program_stack
.
pop
(
)
elif
mode
==
"to_original"
:
assert
self
.
_original_serial_main_program
is
not
None
assert
self
.
_original_serial_startup_program
is
not
None
self
.
_serial_main_program
=
self
.
_original_serial_main_program
.
clone
(
)
self
.
_serial_startup_program
=
self
.
_original_serial_startup_program
.
clone
(
)
self
.
_restore_serial_loss
()
self
.
_restore_serial_feed_vars
()
self
.
_restore_serial_fetch_vars
()
self
.
_serial_optimizer
=
self
.
_original_serial_optimizer
self
.
_pass_context
=
self
.
_backup_pass_context_stack
.
pop
()
self
.
_pass_context
=
self
.
_backup_pass_context_stack
.
pop
()
self
.
_block_state
=
self
.
_backup_block_state_stack
.
pop
()
self
.
_block_state
=
self
.
_backup_block_state_stack
.
pop
()
...
@@ -353,25 +362,21 @@ class DistributedContext:
...
@@ -353,25 +362,21 @@ class DistributedContext:
def
initialize
(
self
,
with_graph
=
True
):
def
initialize
(
self
,
with_graph
=
True
):
if
not
self
.
_is_initialized
:
if
not
self
.
_is_initialized
:
if
not
self
.
_serial_main_program
:
if
not
self
.
_serial_main_program
:
self
.
_serial_main_program
=
self
.
_original_serial_main_program
if
self
.
_original_serial_main_program
:
self
.
_serial_main_program
=
self
.
_original_serial_main_program
.
clone
(
)
if
not
self
.
_serial_startup_program
:
if
not
self
.
_serial_startup_program
:
self
.
_serial_startup_program
=
self
.
_original_serial_startup_program
if
self
.
_original_serial_startup_program
:
self
.
_serial_startup_program
=
self
.
_original_serial_startup_program
.
clone
(
)
if
not
self
.
_serial_loss
:
if
not
self
.
_serial_loss
:
if
isinstance
(
self
.
_original_serial_loss
,
list
):
self
.
_restore_serial_loss
()
if
len
(
self
.
_original_serial_loss
)
==
1
:
self
.
_serial_loss
=
self
.
_original_serial_loss
[
0
]
elif
len
(
self
.
_original_serial_loss
)
==
0
:
self
.
_serial_loss
=
self
.
_original_serial_loss
else
:
raise
ValueError
(
"multi loss vars are not supported."
)
else
:
self
.
_serial_loss
=
self
.
_original_serial_loss
if
not
self
.
_serial_optimizer
:
if
not
self
.
_serial_optimizer
:
self
.
_serial_optimizer
=
self
.
_original_serial_optimizer
self
.
_serial_optimizer
=
self
.
_original_serial_optimizer
if
not
self
.
_serial_feed_vars
:
if
not
self
.
_serial_feed_vars
:
self
.
_
serial_feed_vars
=
self
.
_original_serial_feed_vars
self
.
_
restore_serial_feed_vars
()
if
not
self
.
_serial_fetch_vars
:
if
not
self
.
_serial_fetch_vars
:
self
.
_
serial_fetch_vars
=
self
.
_original_serial_fetch_vars
self
.
_
restore_serial_fetch_vars
()
self
.
_init_dist_attr_for_program
()
self
.
_init_dist_attr_for_program
()
# Backup the original distributed information for later restore
# Backup the original distributed information for later restore
...
@@ -856,7 +861,11 @@ class DistributedContext:
...
@@ -856,7 +861,11 @@ class DistributedContext:
"_serial_main_program"
,
"_serial_startup_program"
,
"_serial_graph"
,
\
"_serial_main_program"
,
"_serial_startup_program"
,
"_serial_graph"
,
\
"_dist_main_programs"
,
"_dist_startup_programs"
,
\
"_dist_main_programs"
,
"_dist_startup_programs"
,
\
"_serial_ordered_nodes"
,
"_serial_ordered_tensor_nodes"
,
\
"_serial_ordered_nodes"
,
"_serial_ordered_tensor_nodes"
,
\
"_serial_ordered_op_nodes"
]:
"_serial_ordered_op_nodes"
,
"_original_serial_loss"
,
\
"_original_serial_feed_vars"
,
"_original_serial_fetch_vars"
,
\
"_serial_loss"
,
"_serial_feed_vars"
,
"_serial_fetch_vars"
,
"_lr_optimizer"
,
\
"_backup_serial_main_program_stack"
,
"_backup_serial_startup_program_stack"
,
\
"_pass_context"
]:
setattr
(
result
,
k
,
v
)
setattr
(
result
,
k
,
v
)
else
:
else
:
setattr
(
result
,
k
,
copy
.
deepcopy
(
v
,
memo
))
setattr
(
result
,
k
,
copy
.
deepcopy
(
v
,
memo
))
...
...
python/paddle/distributed/auto_parallel/engine.py
浏览文件 @
bb6bd223
...
@@ -16,7 +16,6 @@ import time
...
@@ -16,7 +16,6 @@ import time
import
copy
import
copy
import
logging
import
logging
from
collections
import
defaultdict
from
collections
import
defaultdict
import
socket
import
paddle
import
paddle
import
paddle.utils
as
utils
import
paddle.utils
as
utils
...
@@ -35,7 +34,6 @@ from paddle.fluid.framework import Operator, Parameter, _non_static_mode
...
@@ -35,7 +34,6 @@ from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from
paddle.fluid.framework
import
_current_expected_place
as
_get_device
from
paddle.fluid.framework
import
_current_expected_place
as
_get_device
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
from
paddle.distributed
import
fleet
from
paddle.distributed
import
fleet
from
paddle.distributed.utils
import
get_logger
from
paddle.distributed.passes
import
new_pass
,
PassContext
from
paddle.distributed.passes
import
new_pass
,
PassContext
from
.hepler
import
ProgramHelper
from
.hepler
import
ProgramHelper
...
@@ -76,7 +74,18 @@ class Engine:
...
@@ -76,7 +74,18 @@ class Engine:
self
.
_cur_rank
=
paddle
.
distributed
.
get_rank
()
self
.
_cur_rank
=
paddle
.
distributed
.
get_rank
()
self
.
_nranks
=
paddle
.
distributed
.
get_world_size
()
self
.
_nranks
=
paddle
.
distributed
.
get_world_size
()
self
.
_saver
=
DistributedSaver
()
self
.
_saver
=
DistributedSaver
()
self
.
_logger
=
get_logger
(
logging
.
INFO
)
# TODO: add logger module
self
.
_logger
=
logging
.
getLogger
()
self
.
_logger
.
propagate
=
False
if
not
self
.
_logger
.
handlers
:
self
.
_logger
.
setLevel
(
logging
.
INFO
)
log_handler
=
logging
.
StreamHandler
()
log_format
=
logging
.
Formatter
(
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler
.
setFormatter
(
log_format
)
self
.
_logger
.
addHandler
(
log_handler
)
self
.
_orig_main_prog
=
static
.
default_main_program
()
self
.
_orig_main_prog
=
static
.
default_main_program
()
self
.
_orig_startup_prog
=
static
.
default_startup_program
()
self
.
_orig_startup_prog
=
static
.
default_startup_program
()
...
@@ -307,7 +316,7 @@ class Engine:
...
@@ -307,7 +316,7 @@ class Engine:
mode
].
dist_startup_programs
mode
].
dist_startup_programs
self
.
_feed_vars
[
mode
]
=
self
.
_dist_contexts
[
mode
].
serial_feed_vars
self
.
_feed_vars
[
mode
]
=
self
.
_dist_contexts
[
mode
].
serial_feed_vars
self
.
_fetch_vars
[
mode
]
=
self
.
_dist_contexts
[
mode
].
serial_fetch_vars
self
.
_fetch_vars
[
mode
]
=
self
.
_dist_contexts
[
mode
].
serial_fetch_vars
self
.
_
optimizer
=
self
.
_dist_contexts
[
mode
].
serial
_optimizer
self
.
_
lr_optimizer
=
self
.
_dist_contexts
[
mode
].
_lr
_optimizer
if
self
.
_nranks
>
1
:
if
self
.
_nranks
>
1
:
# Traverse different rank programs and traverse each op of them,
# Traverse different rank programs and traverse each op of them,
...
@@ -429,25 +438,27 @@ class Engine:
...
@@ -429,25 +438,27 @@ class Engine:
lr_scheduler
=
self
.
get_lr_scheduler
(
self
.
main_program
)
lr_scheduler
=
self
.
get_lr_scheduler
(
self
.
main_program
)
for
epoch
in
range
(
epochs
):
for
epoch
in
range
(
epochs
):
train_logs
=
{
"epoch"
:
epoch
}
train_logs
=
{
"epoch
: {:d}
"
:
epoch
}
for
step
,
_
in
enumerate
(
train_dataloader
):
for
step
,
_
in
enumerate
(
train_dataloader
):
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
use_program_cache
=
use_cache
,
use_program_cache
=
use_cache
,
return_numpy
=
return_numpy
)
return_numpy
=
return_numpy
)
train_logs
[
"step: {:d} "
]
=
step
if
lr_scheduler
is
not
None
:
if
lr_scheduler
is
not
None
:
lr_scheduler
.
step
()
lr_scheduler
.
step
()
train_logs
[
"lr"
]
=
self
.
_optimizer
.
get_lr
()
train_logs
[
"lr: {:5e} "
]
=
self
.
_lr_optimizer
.
get_lr
()
train_logs
[
"step"
]
=
step
# inner fetches
# inner fetches
if
fetch_loss
:
if
fetch_loss
:
train_logs
[
"
train_loss
"
]
=
outs
[
0
][
0
]
train_logs
[
"
loss: {:9f}
"
]
=
outs
[
0
][
0
]
# user fetches
# user fetches
user_outs
=
outs
[
len
(
fetch_loss
):]
user_outs
=
outs
[
len
(
fetch_loss
):]
user_fetch_list
=
fetch_list
[
len
(
fetch_loss
):]
user_fetch_list
=
fetch_list
[
len
(
fetch_loss
):]
for
i
,
out
in
enumerate
(
user_outs
):
for
i
,
out
in
enumerate
(
user_outs
):
train_logs
[
"train_"
+
fetch_map
[
user_fetch_list
[
i
]]]
=
out
train_logs
[
fetch_map
[
user_fetch_list
[
i
]]
+
": {}"
]
=
out
self
.
_logger
.
info
(
train_logs
)
# logger
string
=
'[train] '
+
''
.
join
(
list
(
train_logs
.
keys
()))
self
.
_logger
.
info
(
string
.
format
(
*
list
(
train_logs
.
values
())))
def
evaluate
(
self
,
def
evaluate
(
self
,
eval_data
,
eval_data
,
...
@@ -473,14 +484,14 @@ class Engine:
...
@@ -473,14 +484,14 @@ class Engine:
fetch_list
,
fetch_map
=
self
.
_fetch_map
(
inner_fetch
,
usr_fetch
)
fetch_list
,
fetch_map
=
self
.
_fetch_map
(
inner_fetch
,
usr_fetch
)
for
step
,
_
in
enumerate
(
eval_dataloader
):
for
step
,
_
in
enumerate
(
eval_dataloader
):
eval_logs
=
{
"step"
:
step
}
eval_logs
=
{
"step
: {:d}
"
:
step
}
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
use_program_cache
=
use_cache
,
use_program_cache
=
use_cache
,
return_numpy
=
return_numpy
)
return_numpy
=
return_numpy
)
# inner fetches
# inner fetches
if
fetch_loss
:
if
fetch_loss
:
eval_logs
[
"
eval_loss
"
]
=
outs
[
0
][
0
]
eval_logs
[
"
loss: {:9f}
"
]
=
outs
[
0
][
0
]
# Metric
# Metric
if
fetch_metrics
:
if
fetch_metrics
:
metric_out
=
outs
[
len
(
fetch_loss
):
len
(
inner_fetch
)]
metric_out
=
outs
[
len
(
fetch_loss
):
len
(
inner_fetch
)]
...
@@ -488,14 +499,15 @@ class Engine:
...
@@ -488,14 +499,15 @@ class Engine:
metric
.
update
(
*
metric_out
)
metric
.
update
(
*
metric_out
)
results
=
metric
.
accumulate
()
results
=
metric
.
accumulate
()
for
i
,
res
in
enumerate
(
to_list
(
results
)):
for
i
,
res
in
enumerate
(
to_list
(
results
)):
eval_logs
[
"eval_"
+
metric
.
name
()[
i
]
]
=
res
eval_logs
[
metric
.
name
()[
i
]
+
": {:9f} "
]
=
res
# usr fetches
# usr fetches
usr_outs
=
outs
[
len
(
inner_fetch
):]
usr_outs
=
outs
[
len
(
inner_fetch
):]
usr_fetch_list
=
fetch_list
[
len
(
inner_fetch
):]
usr_fetch_list
=
fetch_list
[
len
(
inner_fetch
):]
for
i
,
out
in
enumerate
(
usr_outs
):
for
i
,
out
in
enumerate
(
usr_outs
):
eval_logs
[
"eval_"
+
fetch_map
[
usr_fetch_list
[
i
]]
]
=
out
eval_logs
[
fetch_map
[
usr_fetch_list
[
i
]]
+
": {}"
]
=
out
# logger
# logger
self
.
_logger
.
info
(
eval_logs
)
string
=
'[eval] '
+
''
.
join
(
list
(
eval_logs
.
keys
()))
self
.
_logger
.
info
(
string
.
format
(
*
list
(
eval_logs
.
values
())))
def
predict
(
self
,
def
predict
(
self
,
test_data
,
test_data
,
...
@@ -520,15 +532,17 @@ class Engine:
...
@@ -520,15 +532,17 @@ class Engine:
outputs
=
[]
outputs
=
[]
for
step
,
_
in
enumerate
(
test_dataloader
):
for
step
,
_
in
enumerate
(
test_dataloader
):
predict_logs
=
{
"step"
:
step
}
predict_logs
=
{
"step
: {:d}
"
:
step
}
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
outs
=
self
.
_executor
.
run
(
self
.
main_program
,
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
use_program_cache
=
use_cache
,
use_program_cache
=
use_cache
,
return_numpy
=
return_numpy
)
return_numpy
=
return_numpy
)
outputs
.
append
(
outs
[:
len
(
fetch_outputs
)])
outputs
.
append
(
outs
[:
len
(
fetch_outputs
)])
for
i
,
out
in
enumerate
(
outs
):
for
i
,
out
in
enumerate
(
outs
):
predict_logs
[
"pred_"
+
fetch_map
[
fetch_list
[
i
]]]
=
out
predict_logs
[
fetch_map
[
fetch_list
[
i
]]
+
": {}"
]
=
out
self
.
_logger
.
info
(
predict_logs
)
# logger
string
=
'[pred] '
+
''
.
join
(
list
(
predict_logs
.
keys
()))
self
.
_logger
.
info
(
string
.
format
(
*
list
(
predict_logs
.
values
())))
return
outputs
return
outputs
...
...
python/paddle/distributed/auto_parallel/parallelizer_v2.py
浏览文件 @
bb6bd223
...
@@ -20,7 +20,7 @@ from collections import defaultdict
...
@@ -20,7 +20,7 @@ from collections import defaultdict
import
paddle
import
paddle
from
paddle.fluid
import
program_guard
from
paddle.fluid
import
program_guard
from
paddle.fluid.backward
import
append_backward
from
paddle.fluid.backward
import
append_backward
from
paddle.fluid.framework
import
_non_static_mode
from
paddle.fluid.framework
import
_non_static_mode
,
unique_name
from
paddle.distributed.passes
import
new_pass
from
paddle.distributed.passes
import
new_pass
from
paddle.distributed.utils
import
get_logger
from
paddle.distributed.utils
import
get_logger
...
@@ -143,15 +143,18 @@ class Parallelizer:
...
@@ -143,15 +143,18 @@ class Parallelizer:
def
_generate_optimizer
(
self
,
main_program
,
startup_program
,
optimizer
,
def
_generate_optimizer
(
self
,
main_program
,
startup_program
,
optimizer
,
params_grads
):
params_grads
):
# NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
# but optimizer will be called repeatedly in re-launch, so optimizer need to be copied.
if
self
.
_dist_context
.
_dygraph_mode
:
if
self
.
_dist_context
.
_dygraph_mode
:
paddle
.
disable_static
()
paddle
.
disable_static
()
optimizer
=
copy
.
deepcopy
(
optimizer
)
optimizer
=
copy
.
deepcopy
(
optimizer
)
paddle
.
enable_static
()
paddle
.
enable_static
()
else
:
else
:
optimizer
=
copy
.
deepcopy
(
optimizer
)
optimizer
=
copy
.
deepcopy
(
optimizer
)
self
.
_dist_context
.
_
serial
_optimizer
=
optimizer
self
.
_dist_context
.
_
lr
_optimizer
=
optimizer
with
program_guard
(
main_program
,
startup_program
):
with
program_guard
(
main_program
,
startup_program
):
optimizer_ops
=
optimizer
.
apply_gradients
(
params_grads
)
with
unique_name
.
guard
(
"opt_"
):
optimizer_ops
=
optimizer
.
apply_gradients
(
params_grads
)
self
.
_completer
.
complete_update_annotation
(
main_program
)
self
.
_completer
.
complete_update_annotation
(
main_program
)
return
optimizer_ops
return
optimizer_ops
...
...
python/paddle/distributed/auto_parallel/reshard.py
浏览文件 @
bb6bd223
...
@@ -30,10 +30,13 @@ from .cost import build_comm_desc, CommContext
...
@@ -30,10 +30,13 @@ from .cost import build_comm_desc, CommContext
from
.cost
import
AllgatherOpCost
,
SendOpCost
from
.cost
import
AllgatherOpCost
,
SendOpCost
from
.cost
import
SliceOpCost
,
SplitOpCost
,
ConcatOpCost
from
.cost
import
SliceOpCost
,
SplitOpCost
,
ConcatOpCost
from
.cluster
import
Cluster
from
.cluster
import
Cluster
from
.utils
import
print_program_with_dist_attr
from
.utils
import
print_program_with_dist_attr
,
_is_gradient_clip_op
# NOTE: If op in _g_special_ops, it will not be resharded.
# NOTE: If op in _g_special_ops
or _g_gradient_clip_ops
, it will not be resharded.
_g_special_ops
=
[
'check_finite_and_unscale'
,
'update_loss_scaling'
]
_g_special_ops
=
[
'check_finite_and_unscale'
,
'update_loss_scaling'
]
_g_gradient_clip_ops
=
[
"sum"
,
"sqrt"
,
"fill_constant"
,
"elementwise_max"
,
"elementwise_div"
]
def
get_var_with_recursion
(
var_name
,
block
,
program
):
def
get_var_with_recursion
(
var_name
,
block
,
program
):
...
@@ -1076,9 +1079,11 @@ class Resharder:
...
@@ -1076,9 +1079,11 @@ class Resharder:
return
True
return
True
def
is_special_op
(
self
,
op
):
def
is_special_op
(
self
,
op
):
global
_g_special_ops
global
_g_special_ops
,
_g_gradient_clip_ops
if
op
.
type
in
_g_special_ops
:
if
op
.
type
in
_g_special_ops
:
return
True
return
True
if
_is_gradient_clip_op
(
op
)
and
op
.
type
in
_g_gradient_clip_ops
:
return
True
return
False
return
False
def
is_condition_replicative
(
self
,
op
):
def
is_condition_replicative
(
self
,
op
):
...
...
python/paddle/distributed/auto_parallel/utils.py
浏览文件 @
bb6bd223
...
@@ -1131,6 +1131,11 @@ def is_loss_grad_op(op):
...
@@ -1131,6 +1131,11 @@ def is_loss_grad_op(op):
return
op_role
&
int
(
OpRole
.
Backward
)
and
op_role
&
int
(
OpRole
.
Loss
)
return
op_role
&
int
(
OpRole
.
Backward
)
and
op_role
&
int
(
OpRole
.
Loss
)
def
_is_gradient_clip_op
(
op
):
return
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/gradient_clip"
)
def
is_prim_op
(
op
):
def
is_prim_op
(
op
):
return
op
.
type
.
endswith
(
"_p"
)
return
op
.
type
.
endswith
(
"_p"
)
...
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
bb6bd223
...
@@ -64,4 +64,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
...
@@ -64,4 +64,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_cluster_v2 MODULES test_cluster_v2
)
py_test_modules
(
test_cluster_v2 MODULES test_cluster_v2
)
py_test_modules
(
test_process_mesh_v2 MODULES test_process_mesh_v2
)
py_test_modules
(
test_process_mesh_v2 MODULES test_process_mesh_v2
)
py_test_modules
(
test_dist_attr_v2 MODULES test_dist_attr_v2
)
py_test_modules
(
test_dist_attr_v2 MODULES test_dist_attr_v2
)
py_test_modules
(
test_lr_grad_clip MODULES test_lr_grad_clip
)
endif
()
endif
()
python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
浏览文件 @
bb6bd223
...
@@ -108,9 +108,7 @@ def train(fetch):
...
@@ -108,9 +108,7 @@ def train(fetch):
dropout_ratio
=
0.1
,
dropout_ratio
=
0.1
,
initializer_range
=
0.02
)
initializer_range
=
0.02
)
loss
=
paddle
.
nn
.
CrossEntropyLoss
()
loss
=
paddle
.
nn
.
CrossEntropyLoss
()
scheduler
=
paddle
.
optimizer
.
lr
.
CosineAnnealingDecay
(
learning_rate
=
0.00001
,
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.00001
,
T_max
=
10
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
scheduler
,
beta1
=
0.9
,
beta1
=
0.9
,
beta2
=
0.999
,
beta2
=
0.999
,
epsilon
=
1e-08
,
epsilon
=
1e-08
,
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py
浏览文件 @
bb6bd223
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
import
unittest
import
unittest
import
os
import
os
import
json
import
json
import
copy
import
paddle
import
paddle
import
numpy
as
np
import
numpy
as
np
...
@@ -194,6 +195,32 @@ class TestDistributedContext(unittest.TestCase):
...
@@ -194,6 +195,32 @@ class TestDistributedContext(unittest.TestCase):
dist_context
.
_backup
(
serial
=
True
,
dist
=
True
)
dist_context
.
_backup
(
serial
=
True
,
dist
=
True
)
dist_context
.
_restore
(
serial
=
True
,
dist
=
True
,
dist_mode
=
"to_nothing"
)
dist_context
.
_restore
(
serial
=
True
,
dist
=
True
,
dist_mode
=
"to_nothing"
)
def
test_deepcopy
(
self
):
train_program
,
start_program
,
dataloader
,
loss
,
optimizer
,
feed_vars
,
fetch_vars
=
get_program
(
)
dist_context
=
DistributedContext
(
train_program
,
start_program
,
optimizer
,
loss
,
feed_vars
,
fetch_vars
)
dist_context
.
initialize
()
copy_dist_context
=
copy
.
deepcopy
(
dist_context
)
copy_list
=
[
"_original_serial_main_program"
,
"_original_serial_startup_program"
,
\
"_serial_main_program"
,
"_serial_startup_program"
,
"_serial_graph"
,
\
"_dist_main_programs"
,
"_dist_startup_programs"
,
\
"_serial_ordered_nodes"
,
"_serial_ordered_tensor_nodes"
,
\
"_serial_ordered_op_nodes"
,
"_original_serial_loss"
,
\
"_original_serial_feed_vars"
,
"_original_serial_fetch_vars"
,
\
"_serial_loss"
,
"_serial_feed_vars"
,
"_serial_fetch_vars"
,
"_lr_optimizer"
,
\
"_backup_serial_main_program_stack"
,
"_backup_serial_startup_program_stack"
,
\
"_pass_context"
]
for
i
in
range
(
len
(
copy_list
)):
copy_obj
=
"copy_dist_context."
+
copy_list
[
i
]
obj
=
"dist_context."
+
copy_list
[
i
]
assert
id
(
eval
(
copy_obj
))
==
id
(
eval
(
obj
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py
0 → 100644
浏览文件 @
bb6bd223
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
os
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
import
paddle.distributed.auto_parallel
as
auto
import
paddle.distributed.fleet
as
fleet
from
paddle.io
import
Dataset
from
paddle.static
import
InputSpec
from
paddle.fluid.framework
import
_non_static_mode
from
paddle.distributed.auto_parallel.engine
import
Engine
from
paddle.distributed.auto_parallel.hepler
import
ProgramHelper
from
test_to_static
import
MLPLayer
,
MyDataset
paddle
.
enable_static
()
class
TestEngineBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
4
self
.
batch_num
=
5
self
.
hidden_size
=
1024
self
.
init_model
()
self
.
init_optimizer
()
self
.
init_dataset
()
self
.
init_engine
()
def
init_model
(
self
):
self
.
mlp
=
MLPLayer
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
4
*
self
.
hidden_size
,
dropout_ratio
=
0.1
,
initializer_range
=
0.02
)
self
.
loss
=
paddle
.
nn
.
CrossEntropyLoss
()
def
init_optimizer
(
self
):
self
.
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.00001
,
parameters
=
self
.
mlp
.
parameters
())
def
init_dataset
(
self
):
self
.
dataset
=
MyDataset
(
self
.
batch_num
*
self
.
batch_size
)
def
init_engine
(
self
):
inputs
=
InputSpec
([
self
.
batch_size
,
self
.
hidden_size
],
'float32'
,
'x'
)
labels
=
InputSpec
([
self
.
batch_size
],
'int64'
,
'label'
)
self
.
engine
=
Engine
(
model
=
self
.
mlp
,
inputs_spec
=
inputs
,
labels_spec
=
labels
)
self
.
engine
.
prepare
(
optimizer
=
self
.
optimizer
,
loss
=
self
.
loss
,
metrics
=
paddle
.
metric
.
Accuracy
())
class
TestLRScheduler
(
TestEngineBase
):
def
init_optimizer
(
self
):
scheduler
=
paddle
.
optimizer
.
lr
.
CosineAnnealingDecay
(
learning_rate
=
0.00001
,
T_max
=
10
)
self
.
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
scheduler
)
def
test_lr_scheduler
(
self
):
self
.
init_engine
()
lr
=
self
.
engine
.
_optimizer
.
_learning_rate
assert
isinstance
(
lr
,
paddle
.
optimizer
.
lr
.
LRScheduler
)
self
.
engine
.
fit
(
self
.
dataset
,
batch_size
=
self
.
batch_size
)
class
TestGradClip
(
TestEngineBase
):
def
init_optimizer
(
self
):
clip
=
paddle
.
nn
.
ClipGradByGlobalNorm
(
clip_norm
=
1.0
)
self
.
optimizer
=
paddle
.
optimizer
.
SGD
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
def
test_grad_clip
(
self
):
clip
=
self
.
engine
.
_optimizer
.
_grad_clip
assert
isinstance
(
clip
,
paddle
.
nn
.
ClipGradByGlobalNorm
)
self
.
engine
.
fit
(
self
.
dataset
,
batch_size
=
self
.
batch_size
)
self
.
check_program
()
def
check_program
(
self
):
ops
=
self
.
engine
.
main_program
.
global_block
().
ops
has_grad_clip
=
False
for
op
in
ops
:
if
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/gradient_clip"
):
has_grad_clip
=
True
break
assert
has_grad_clip
is
True
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录