PaddlePaddle / Paddle
Commit db2c71a4 (unverified)
Authored by zhaoyingli, committed via GitHub on Jul 07, 2022
Parent: 7e3833a7

[AutoParallel] fix 'op_role' for gradient merge & recompute (#44138)

* fix op_role
* fix engine
* update op_role
Showing 8 changed files with 182 additions and 184 deletions (+182 −184)
python/paddle/distributed/auto_parallel/engine.py                                    +87  −118
python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py   +3   −3
python/paddle/distributed/auto_parallel/operators/dist_embedding.py                  +19  −9
python/paddle/distributed/auto_parallel/operators/dist_matmul.py                     +39  −15
python/paddle/distributed/auto_parallel/partitioner.py                               +5   −3
python/paddle/distributed/auto_parallel/utils.py                                     +9   −6
python/paddle/distributed/passes/auto_parallel_amp.py                                +2   −2
python/paddle/distributed/passes/auto_parallel_gradient_merge.py                     +18  −28
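
The whole point of the commit is the 'op_role' attribute that Paddle stamps on every operator to record which stage it belongs to (forward, backward, optimize, LR scheduling, loss); several hunks below move ops from the Backward to the Optimize role or forward the role of the source op. The following is a minimal, self-contained sketch of how such a bitmask role is usually queried; the flag values mirror my reading of paddle/fluid/framework/op_proto_maker.h and are assumptions for illustration, not code from this commit.

from enum import IntFlag


class OpRole(IntFlag):
    # Assumed values, mirroring op_proto_maker.h; not taken from this diff.
    Forward = 0x0000
    Backward = 0x0001
    Optimize = 0x0002
    RPC = 0x0004
    Dist = 0x0008
    LRSched = 0x0010
    Loss = 0x0100


def is_optimize_role(op_role):
    # An op belongs to the optimize stage if the Optimize bit is set,
    # even when combined with another flag such as LRSched.
    return bool(op_role & OpRole.Optimize)


def is_loss_role(op_role):
    # The loss op carries exactly Forward | Loss.
    return op_role == (OpRole.Forward | OpRole.Loss)


assert is_optimize_role(OpRole.Optimize | OpRole.LRSched)
assert not is_optimize_role(OpRole.Backward)
assert is_loss_role(OpRole.Forward | OpRole.Loss)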
python/paddle/distributed/auto_parallel/engine.py

@@ -18,7 +18,6 @@ from collections import defaultdict
 import paddle
 import paddle.utils as utils
-import paddle.distributed.auto_parallel as auto
 from paddle import fluid, static
 from paddle.io import Dataset

@@ -72,7 +71,6 @@ class Engine:
         self._saver = DistributedSaver()
         self._logger = get_logger(logging.INFO)
 
-        self._default_strategy = None
         self._orig_main_prog = static.default_main_program()
         self._orig_startup_prog = static.default_startup_program()
         self._orig_dist_context = get_default_distributed_context()

@@ -117,9 +115,11 @@ class Engine:
         self._planned_mode = None
         self._modes = ['train', 'eval', 'predict']
 
-        self._build()
-        # Do auto parallel process
+        # Build program and do auto parallel process
         for mode in self._modes:
+            # Build forward program
+            self._build(mode)
+
+        for mode in self._modes:
             # Do the planning process
             self._plan(mode)

@@ -129,56 +129,49 @@ class Engine:
             # Init comm and startup program
             self._initialize(mode)
 
-    def _build(self):
-        for mode in self._modes:
-            serial_main_prog = self._serial_main_progs.get(mode, None)
-            if serial_main_prog is not None:
-                return
-
-            losses = []
-            metrics = []
-            serial_main_prog = self._orig_main_prog.clone()
-            serial_startup_prog = self._orig_startup_prog.clone()
-            with static.program_guard(serial_main_prog, serial_startup_prog), \
-                utils.unique_name.guard():
-                inputs_spec = self.inputs_spec
-                labels_spec = self.labels_spec if self.labels_spec else []
-                inputs = [s._create_feed_layer() for s in inputs_spec]
-                labels = [s._create_feed_layer() for s in labels_spec]
-                outputs = to_list(self.model(*inputs))
-                if mode != "predict" and self._loss:
-                    losses = to_list(self._loss(*(outputs + labels)))
-
-                if mode != "predict":
-                    for metric in self._metrics:
-                        metrics.extend(
-                            to_list(metric.compute(*(outputs + labels))))
-
-            default_ctx = get_default_distributed_context()
-            if not default_ctx.has_annotation or self._default_strategy:
-                # We build the world process group because the data parallel
-                # needs all ranks by default.
-                new_process_group(list(range(self._nranks)))
-                default_ctx.data_parallel = True
-
-            # self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
-            feed_vars = {"inputs": inputs, "labels": labels}
-
-            # self._fetch_vars[mode] = {
-            #     "outputs": flatten(outputs),
-            #     "loss": losses,
-            #     "metrics": metrics
-            # }
-            fetch_vars = {
-                "outputs": flatten(outputs),
-                "loss": losses,
-                "metrics": metrics
-            }
-
-            self._dist_contexts[mode] = DistributedContext(
-                serial_main_prog, serial_startup_prog, self._optimizer, losses,
-                feed_vars, fetch_vars, self.cluster, self.strategy)
-            self._dist_contexts[mode].gradient_scale = self._gradient_scale
+    def _build(self, mode):
+        serial_main_prog = self._serial_main_progs.get(mode, None)
+        if serial_main_prog is not None:
+            return
+
+        losses = []
+        metrics = []
+        serial_main_prog = self._orig_main_prog.clone()
+        serial_startup_prog = self._orig_startup_prog.clone()
+        with static.program_guard(serial_main_prog, serial_startup_prog), \
+            utils.unique_name.guard():
+            inputs_spec = self.inputs_spec
+            labels_spec = self.labels_spec if self.labels_spec else []
+            inputs = [s._create_feed_layer() for s in inputs_spec]
+            labels = [s._create_feed_layer() for s in labels_spec]
+            outputs = to_list(self.model(*inputs))
+            if mode != "predict" and self._loss:
+                losses = to_list(self._loss(*(outputs + labels)))
+
+            if mode != "predict":
+                for metric in self._metrics:
+                    metrics.extend(
+                        to_list(metric.compute(*(outputs + labels))))
+
+        default_ctx = get_default_distributed_context()
+        if not default_ctx.has_annotation:
+            # We build the world process group because the data parallel
+            # needs all ranks by default.
+            new_process_group(list(range(self._nranks)))
+            default_ctx.data_parallel = True
+
+        feed_vars = {"inputs": inputs, "labels": labels}
+
+        fetch_vars = {
+            "outputs": flatten(outputs),
+            "loss": losses,
+            "metrics": metrics
+        }
+
+        self._dist_contexts[mode] = DistributedContext(
+            serial_main_prog, serial_startup_prog, self._optimizer, losses,
+            feed_vars, fetch_vars, self.cluster, self.strategy)
+        self._dist_contexts[mode].gradient_scale = self._gradient_scale
 
     def _plan(self, mode):
         if self._planned_mode is None:

@@ -240,7 +233,6 @@ class Engine:
                 continue
             process_group.instantiate()
 
         # initialize
         self._place = _get_device()
         if isinstance(self._place, fluid.CUDAPlace):
             self._place = fluid.CUDAPlace(ParallelEnv().dev_id)

@@ -273,8 +265,8 @@ class Engine:
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch)
 
-        usr_fetch = self._to_map_fetch(fetches)
-        fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
+        usr_fetch = self._validate_fetches(fetches)
+        fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
         fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
 
         for epoch in range(epochs):

@@ -292,8 +284,7 @@ class Engine:
                     user_outs = outs[len(fetch_loss):]
                     user_fetch_list = fetch_list[len(fetch_loss):]
                     for i, out in enumerate(user_outs):
-                        train_logs["train_" + fetch_map[user_fetch_list[i]]] = out[0]
+                        train_logs["train_" + fetch_map[user_fetch_list[i]]] = out
                 self._logger.info(train_logs)
 
     def evaluate(self,

@@ -307,9 +298,9 @@ class Engine:
             "eval model is not ready, please call `engine.prepare()` first."
         eval_dataloader = self._create_dataloader(eval_data, batch_size)
 
-        usr_fetch = self._to_map_fetch(fetches)
-        fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
-        fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"])
+        usr_fetch = self._validate_fetches(fetches)
+        fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
+        fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
         inner_fetch = dict(fetch_loss, **fetch_metrics)
         fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)

@@ -321,7 +312,7 @@ class Engine:
                 return_numpy=return_numpy)
             # inner fetches
             if fetch_loss:
-                eval_logs["eval_loss"] = outs[0]
+                eval_logs["eval_loss"] = outs[0][0]
             # Metric
             if fetch_metrics:
                 metric_out = outs[len(fetch_loss):len(inner_fetch)]

@@ -331,9 +322,9 @@ class Engine:
                 for i, res in enumerate(to_list(results)):
                     eval_logs["eval_" + metric.name()[i]] = res
             # usr fetches
-            usr_out = outs[len(inner_fetch):]
+            usr_outs = outs[len(inner_fetch):]
             usr_fetch_list = fetch_list[len(inner_fetch):]
-            for i, out in enumerate(usr_out):
+            for i, out in enumerate(usr_outs):
                 eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
             # logger
             self._logger.info(eval_logs)

@@ -349,8 +340,8 @@ class Engine:
             "predict model is not ready, please call `engine.prepare()` first."
         test_dataloader = self._create_dataloader(test_data, batch_size)
 
-        usr_fetch = self._to_map_fetch(fetches)
-        fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"])
+        usr_fetch = self._validate_fetches(fetches)
+        fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
         fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
 
         outputs = []

@@ -362,42 +353,11 @@ class Engine:
                 return_numpy=return_numpy)
             outputs.append(outs[:len(fetch_outputs)])
             for i, out in enumerate(outs):
-                predict_logs["pred_" + fetch_map[fetch_list[i]]] = out[0]
+                predict_logs["pred_" + fetch_map[fetch_list[i]]] = out
             self._logger.info(predict_logs)
 
         return outputs
 
-    def _local_var(self, var):
-        var_name = _to_name_str(var)
-        return var_name in self.main_program.global_block().vars
-
-    def _to_map_fetch(self, fetches):
-        if not fetches:
-            return {}
-        if isinstance(fetches, dict):
-            fetch_var_names = list(map(_to_name_str, fetches.values()))
-            usr_fetches = dict(zip(fetch_var_names, list(fetches.keys())))
-        elif isinstance(fetches, list):
-            fetch_var_names = list(map(_to_name_str, fetches))
-            usr_fetches = dict(zip(fetch_var_names, fetch_var_names))
-        return dict(
-            filter(lambda x: self._local_var(x[0]), usr_fetches.items()))
-
-    def _inner_fetch(self, fetch_vars):
-        fetch_list = list(
-            map(lambda x: x.name, list(filter(self._local_var, fetch_vars))))
-        inner_fetches = dict(zip(fetch_list, fetch_list))
-        return inner_fetches
-
-    def _fetch_map(self, inner_fetch, usr_fetch):
-        # replace inner fetch name if usr set for it
-        for iname in inner_fetch:
-            if iname in usr_fetch:
-                inner_fetch[iname] = usr_fetch[iname]
-                usr_fetch.pop(iname)
-        fetches = dict(inner_fetch, **usr_fetch)
-        return list(fetches.keys()), fetches
-
     def _create_dataloader(self,
                            dataset,
                            batch_size,

@@ -468,26 +428,35 @@ class Engine:
                             .format(i, spec))
         return specs
 
-    def _set_data_parallel(self, var):
-        if self._nranks == 1:
-            self._default_strategy = 'serial'
-            auto.shard_tensor(var,
-                              dist_attr={
                                  "process_mesh": [0],
-                                  "dims_mapping":
-                                  [-1 for _ in range(len(var.shape))]
-                              })
-        else:
-            self._default_strategy = 'dp'
-            auto.shard_tensor(var,
-                              dist_attr={
-                                  "process_mesh": list(range(self._nranks)),
-                                  "dims_mapping": [0] +
-                                  [-1 for _ in range(len(var.shape) - 1)]
-                              })
-        return var
+    def _is_local_var(self, var):
+        var_name = _to_name_str(var)
+        return var_name in self.main_program.global_block().vars
+
+    def _validate_fetches(self, fetches):
+        # 1. Check user-defined fetches type
+        # 2. Prepare fetches_dict like {user_defined_name: var_name}
+        if not fetches:
+            return {}
+        if isinstance(fetches, dict):
+            fetch_var_names = list(map(_to_name_str, fetches.values()))
+            fetches_dict = dict(zip(fetch_var_names, list(fetches.keys())))
+        elif isinstance(fetches, list):
+            fetch_var_names = list(map(_to_name_str, fetches))
+            fetches_dict = dict(zip(fetch_var_names, fetch_var_names))
+        else:
+            raise TypeError("'fetches' only support 'dict' and 'list', "
+                            "but got '{}'".format(str(type(fetches))))
+        return dict(
+            filter(lambda x: self._is_local_var(x[0]), fetches_dict.items()))
+
+    def _fetch_map(self, inner_fetch, usr_fetch):
+        # replace inner fetch name if usr set for it
+        for iname in inner_fetch:
+            if iname in usr_fetch:
+                inner_fetch[iname] = usr_fetch[iname]
+                usr_fetch.pop(iname)
+        fetches = dict(inner_fetch, **usr_fetch)
+        return list(fetches.keys()), fetches
 
     def _get_data_parallel_info(self, var, dist_context):
         # get data parallel world size and current data parallel rank
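
The _validate_fetches and _fetch_map pair in the hunk above maps user-supplied fetches onto variable names and then merges them with the engine's own fetch list, preferring the user's display name when both refer to the same variable. A standalone sketch of that dictionary logic, using plain strings in place of Paddle variables (the variable names below are made up for the example):

def validate_fetches(fetches):
    # Normalize user fetches into {var_name: user_defined_name},
    # mirroring Engine._validate_fetches above but on plain strings.
    if not fetches:
        return {}
    if isinstance(fetches, dict):
        return {var_name: user_name for user_name, var_name in fetches.items()}
    if isinstance(fetches, list):
        return {var_name: var_name for var_name in fetches}
    raise TypeError("'fetches' only support 'dict' and 'list', "
                    "but got '{}'".format(type(fetches)))


def fetch_map(inner_fetch, usr_fetch):
    # Replace an inner fetch's display name if the user set one for it.
    for var_name in inner_fetch:
        if var_name in usr_fetch:
            inner_fetch[var_name] = usr_fetch.pop(var_name)
    fetches = dict(inner_fetch, **usr_fetch)
    return list(fetches.keys()), fetches


inner = validate_fetches(["loss_var_0"])
user = validate_fetches({"my_loss": "loss_var_0", "probs": "softmax_var_3"})
fetch_list, name_map = fetch_map(inner, user)
assert fetch_list == ["loss_var_0", "softmax_var_3"]
assert name_map == {"loss_var_0": "my_loss", "softmax_var_3": "probs"}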
python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py

@@ -137,7 +137,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
                 attrs={
                     "in_dtype": inf_var.dtype,
                     "out_dtype": inf_var_int32.dtype,
-                    OP_ROLE_KEY: OpRole.Backward
+                    OP_ROLE_KEY: OpRole.Optimize
                 })
             allreduce_op = main_block.append_op(
                 type='c_allreduce_max',
                 inputs={'X': inf_var_int32},

@@ -145,7 +145,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
                 attrs={
                     'ring_id': group.id,
                     'use_calc_stream': True,
-                    OP_ROLE_KEY: OpRole.Backward
+                    OP_ROLE_KEY: OpRole.Optimize
                 })
             cast_op2 = main_block.append_op(
                 type='cast',
                 inputs={'X': inf_var_int32},

@@ -153,7 +153,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
                 attrs={
                     "in_dtype": inf_var_int32.dtype,
                     "out_dtype": inf_var.dtype,
-                    OP_ROLE_KEY: OpRole.Backward
+                    OP_ROLE_KEY: OpRole.Optimize
                 })
 
             main_block._sync_with_cpp()
python/paddle/distributed/auto_parallel/operators/dist_embedding.py

@@ -222,7 +222,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                 'W': [Weight_var]
             },
             outputs={'Out': [intermediate_var_0]},
-            attrs={"start_index": relative_idx})
+            attrs={
+                "start_index": relative_idx,
+                OP_ROLE_KEY: src_op.attr('op_role')
+            })
 
         if intermediate_var_0.shape != ref_shape:
             intermediate_var_0.desc.set_shape(ref_shape)

@@ -235,6 +238,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                 'ring_id': group.id,
                 'use_calc_stream': True,
                 'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
 
         if Out_var.shape != ref_shape:
             Out_var.desc.set_shape(ref_shape)

@@ -442,6 +446,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
             dp_group = new_process_group(group_ranks)
 
         if need_gradient_allreduce:
+            added_ops = []
             W_Grad_var = main_block.var(kwargs['W@GRAD'][0])
             allreduce_op = main_block.append_op(type='c_allreduce_sum',
                                                 inputs={'X': [W_Grad_var]},

@@ -451,19 +456,24 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                                                     'use_calc_stream': True,
                                                     OP_ROLE_KEY: OpRole.Backward
                                                 })
-            scale_op = main_block.append_op(type='scale',
-                                            inputs={'X': W_Grad_var},
-                                            outputs={'Out': W_Grad_var},
-                                            attrs={
-                                                'scale': 1.0 / dp_degree,
-                                                OP_ROLE_KEY: OpRole.Backward
-                                            })
+            added_ops.append(allreduce_op)
+
+            if ctx.gradient_scale:
+                scale_op = main_block.append_op(type='scale',
+                                                inputs={'X': W_Grad_var},
+                                                outputs={'Out': W_Grad_var},
+                                                attrs={
+                                                    'scale': 1.0 / dp_degree,
+                                                    OP_ROLE_KEY: OpRole.Backward
+                                                })
+                added_ops.append(scale_op)
 
             main_block._sync_with_cpp()
             dims_mapping = ctx.get_tensor_dist_attr_for_program(
                 W_Grad_var).dims_mapping
             process_mesh = dist_attr.process_mesh
-            for op in [allreduce_op, scale_op]:
+            for op in added_ops:
                 op_attr = OperatorDistributedAttribute()
                 op_attr.process_mesh = process_mesh
                 op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
python/paddle/distributed/auto_parallel/operators/dist_matmul.py

@@ -405,6 +405,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         dp_group = new_process_group(group_ranks)
 
     if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
+        added_ops = []
         Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
         allreduce_op = main_block.append_op(type='c_allreduce_sum',
                                             inputs={'X': [Y_Grad_var]},

@@ -414,19 +415,24 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
                                                 'use_calc_stream': True,
                                                 OP_ROLE_KEY: OpRole.Backward
                                             })
-        scale_op = main_block.append_op(type='scale',
-                                        inputs={'X': Y_Grad_var},
-                                        outputs={'Out': Y_Grad_var},
-                                        attrs={
-                                            'scale': 1.0 / dp_degree,
-                                            OP_ROLE_KEY: OpRole.Backward
-                                        })
+        added_ops.append(allreduce_op)
+
+        if ctx.gradient_scale:
+            scale_op = main_block.append_op(type='scale',
+                                            inputs={'X': Y_Grad_var},
+                                            outputs={'Out': Y_Grad_var},
+                                            attrs={
+                                                'scale': 1.0 / dp_degree,
+                                                OP_ROLE_KEY: OpRole.Backward
+                                            })
+            added_ops.append(scale_op)
 
         main_block._sync_with_cpp()
         dims_mapping = ctx.get_tensor_dist_attr_for_program(
             Y_Grad_var).dims_mapping
         process_mesh = dist_attr.process_mesh
-        for op in [allreduce_op, scale_op]:
+        for op in added_ops:
            op_attr = OperatorDistributedAttribute()
            op_attr.process_mesh = process_mesh
            op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping)

@@ -617,6 +623,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                 'ring_id': group.id,
                 'use_calc_stream': True,
                 'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
 
         if intermediate_var_0.shape != ref_shape_x:
             intermediate_var_0.desc.set_shape(ref_shape_x)

@@ -629,6 +636,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
             'transpose_X': False,
             'transpose_Y': False,
             'alpha': 1,
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
         matmul_op = main_block.append_op(type='matmul',

@@ -814,6 +822,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
             'transpose_X': False,
             'transpose_Y': False,
             'alpha': 1,
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': X_var, 'Y': Weight_var}

@@ -853,7 +862,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
             attrs={
                 'ring_id': group.id,
                 'use_calc_stream': True,
-                'use_model_parallel': True
+                'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
 
         if Out_var.shape != ref_shape:
             Out_var.desc.set_shape(ref_shape)

@@ -1137,6 +1147,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                 'ring_id': group.id,
                 'use_calc_stream': True,
                 'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role'),
             })
 
         if intermediate_var_0.shape != ref_shape_x:
             intermediate_var_0.desc.set_shape(ref_shape_x)

@@ -1145,7 +1156,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                     ['float16', 'float32', 'float64'], 'linear')
         check_dtype(intermediate_var_0.dtype, 'dtype',
                     ['float16', 'float32', 'float64'], 'linear')
-        attrs = {'trans_x': False, 'trans_y': False}
+        attrs = {
+            'trans_x': False,
+            'trans_y': False,
+            OP_ROLE_KEY: src_op.attr('op_role')
+        }
         inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
         matmul_v2_op = main_block.append_op(type='matmul_v2',
                                             inputs=inputs,

@@ -1322,7 +1337,11 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                     'linear')
         check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                     'linear')
-        attrs = {'trans_x': False, 'trans_y': False}
+        attrs = {
+            'trans_x': False,
+            'trans_y': False,
+            OP_ROLE_KEY: src_op.attr('op_role')
+        }
         inputs = {'X': X_var, 'Y': Weight_var}
 
         # infer out var shape with op dist attr

@@ -1361,7 +1380,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
             attrs={
                 'ring_id': group.id,
                 'use_calc_stream': True,
-                'use_model_parallel': True
+                'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
 
         if Out_var.shape != ref_shape:
             Out_var.desc.set_shape(ref_shape)

@@ -1646,6 +1666,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                 'ring_id': group.id,
                 'use_calc_stream': True,
                 'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
 
         if intermediate_var_0.shape != ref_shape_x:
             intermediate_var_0.desc.set_shape(ref_shape_x)

@@ -1657,7 +1678,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
         # attrs = {'trans_x': False, 'trans_y': False}
         attrs = {
             "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
-            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
+            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
         mul_op = main_block.append_op(type='mul',

@@ -1838,7 +1860,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
         # attrs = {'trans_x': False, 'trans_y': False}
         attrs = {
             "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
-            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
+            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': X_var, 'Y': Weight_var}

@@ -1878,7 +1901,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
             attrs={
                 'ring_id': group.id,
                 'use_calc_stream': True,
-                'use_model_parallel': True
+                'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
 
         if Out_var.shape != ref_shape:
             Out_var.desc.set_shape(ref_shape)
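
In both dist_embedding.py and dist_matmul.py the data-parallel gradient path now collects the inserted c_allreduce_sum op, and the scale op only when ctx.gradient_scale is set, into added_ops before stamping distributed attributes on them. Numerically the pair is plain gradient averaging: sum the local gradients over the dp_degree ranks, then multiply by 1 / dp_degree. A single-process sketch of that arithmetic (the "ranks" here are just rows of a list, not real processes):

def allreduce_sum(local_grads):
    # Stand-in for c_allreduce_sum: every rank ends up with the element-wise
    # sum of all local gradients.
    total = [sum(vals) for vals in zip(*local_grads)]
    return [list(total) for _ in local_grads]


def apply_gradient_scale(grads, dp_degree, gradient_scale=True):
    # Stand-in for the conditional 'scale' op: divide by dp_degree only when
    # gradient scaling is enabled on the distributed context.
    if not gradient_scale:
        return grads
    return [g / dp_degree for g in grads]


local_grads = [[1.0, 2.0], [3.0, 4.0]]      # two data-parallel ranks, one parameter
summed = allreduce_sum(local_grads)[0]      # [4.0, 6.0] on every rank
averaged = apply_gradient_scale(summed, dp_degree=len(local_grads))
assert averaged == [2.0, 3.0]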
python/paddle/distributed/auto_parallel/partitioner.py

@@ -264,10 +264,12 @@ class Partitioner(object):
                     self._dist_context, **kinputs, **koutputs,
                     **{"grad_var_to_var": grad_var_to_var})
             elif is_optimize_op(op):
+                # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS
                 kinputs, koutputs = dist_op_context.prepare_context(op)
-                dist_op_impl = get_distributed_operator_impl_container(
-                    "default").get_impl(0)
-                dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
+                dist_op_opt_impl = _get_dist_op_backward_implement(
+                    op, self._dist_context, forward_op_id2forward_op)
+                dist_op_opt_impl.backward(self._dist_context, **kinputs,
+                                          **koutputs)
             else:
                 raise NotImplementedError(
                     "partitioner only support forward and backward, optimize ops, but got {}"
python/paddle/distributed/auto_parallel/utils.py

@@ -1065,7 +1065,7 @@ def set_grad_var_shape(program, dist_context):
             "softmax", "cross_entropy2", "dropout", "tanh",
             ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
             "elementwise_add_grad_grad", "shape", "sqrt",
-            "fused_softmax_mask_upper_triangle_grad"
+            "fused_softmax_mask_upper_triangle"
         ]
         if op.type in need_set_shape_list:
             for forward_op in block.ops:

@@ -1096,11 +1096,9 @@ OpRole = core.op_proto_and_checker_maker.OpRole


 def is_forward_op(op):
-    ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward)
-    ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss)
     op_role = int(op.attr('op_role'))
-    return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1
-                                             or op_role == ref_role2)
+    return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.Forward)
+                                             or op_role == int(OpRole.Loss))


 def is_backward_op(op):

@@ -1113,9 +1111,14 @@ def is_optimize_op(op):
         int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)


+def is_lr_sched_op(op):
+    return OP_ROLE_KEY in op.attr_names and \
+        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize.LRSched)
+
+
 def is_loss_op(op):
     return OP_ROLE_KEY in op.attr_names and \
-        int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) |
-                                             int(core.op_proto_and_checker_maker.OpRole.Loss))
+        int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss))


 def is_prim_op(op):
python/paddle/distributed/passes/auto_parallel_amp.py

@@ -452,7 +452,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
-    attrs = {'op_role': OpRole.Backward}
+    attrs = {'op_role': OpRole.Optimize}
     new_op = main_block.append_op(type='check_finite_and_unscale',
                                   inputs=inputs,
                                   outputs=outputs,

@@ -732,7 +732,7 @@ class AMPPass(PassBase):
             'incr_ratio': self.get_attr("incr_ratio"),
             'decr_ratio': self.get_attr("decr_ratio"),
             'stop_update': self.get_attr("stop_update"),
-            'op_role': OpRole.Backward
+            'op_role': OpRole.Optimize
         }
 
         new_op = main_block.append_op(type='update_loss_scaling',
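
The two attrs changed above belong to the check_finite_and_unscale and update_loss_scaling ops that the AMP pass appends after backward; the commit re-tags them as Optimize-stage ops. Functionally the pair unscales the gradients and adapts the dynamic loss-scaling factor. A rough pure-Python sketch of that behaviour; the incr_every_n_steps/incr_ratio/decr_ratio defaults and the exact update rule are assumptions for illustration, not read out of the pass:

import math


def check_finite_and_unscale(grads, loss_scaling):
    # Divide gradients by the loss scale and report whether any value is inf/nan.
    found_inf = any(not math.isfinite(g) for g in grads)
    unscaled = [g / loss_scaling for g in grads]
    return unscaled, found_inf


def update_loss_scaling(loss_scaling, found_inf, good_steps,
                        incr_every_n_steps=1000, incr_ratio=2.0,
                        decr_ratio=0.5):
    # Shrink the scale immediately on overflow; grow it after enough clean steps.
    if found_inf:
        return loss_scaling * decr_ratio, 0
    good_steps += 1
    if good_steps >= incr_every_n_steps:
        return loss_scaling * incr_ratio, 0
    return loss_scaling, good_steps


grads, found_inf = check_finite_and_unscale([2048.0, float("inf")], loss_scaling=1024.0)
scale, good_steps = update_loss_scaling(1024.0, found_inf, good_steps=10)
assert found_inf and scale == 512.0 and good_steps == 0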
python/paddle/distributed/passes/auto_parallel_gradient_merge.py

@@ -21,20 +21,13 @@ from paddle.framework import core
 from paddle.fluid import layers
 from paddle.fluid.framework import program_guard, device_guard
 from .pass_base import PassBase, PassType, register_pass
-from paddle.distributed.fleet.meta_optimizers.common import OpRole
-from paddle.distributed.auto_parallel.utils import set_var_dist_attr
+from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY
 from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
 from paddle.distributed.auto_parallel.process_group import get_world_process_group
 
 world_process_group = get_world_process_group()
 
 
-def _is_the_optimizer_op(op):
-    OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
-    return OP_ROLE_KEY in op.attr_names and \
-        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
-
-
 def _remove_and_get_optimizer_op(main_program, dist_context):
     # 1 create tmp block
     # 2 mv optimizer op from global program to tmp block

@@ -43,9 +36,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
     temp_block = main_program._create_block()
     removed_op_idx = []
     optimize_ops_desc = []
-    skip_ops = ["increment", "elementwise_mod", "equal"]
     for idx, op in enumerate(main_block.ops):
-        if _is_the_optimizer_op(op) and op.type not in skip_ops:
+        if is_optimize_op(op):
             # append optimizer op to tmp block
             new_op_desc = temp_block.desc.append_op()
             new_op_desc.copy_from(op.desc)

@@ -57,7 +49,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
             dist_context.del_dist_op_for_program(op)
 
     for idx in removed_op_idx[::-1]:
-        main_block._remove_op(idx)
+        main_block._remove_op(idx, sync=False)
+    main_block._sync_with_cpp()
 
     return optimize_ops_desc

@@ -109,7 +102,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
                                      outputs={'Out': [step_var]},
                                      attrs={
                                          'step': float(1.0),
-                                         'op_role': OpRole.Optimize
+                                         OP_ROLE_KEY: OpRole.Backward
                                      })
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
         increment_op, world_process_group.ranks, [-1], dist_context)

@@ -123,7 +116,8 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
                                            attrs={
                                                'axis': -1,
                                                'use_mkldnn': False,
-                                               'op_role': OpRole.Optimize
+                                               OP_ROLE_KEY: OpRole.Backward
                                            })
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
         elementwise_mod_op, world_process_group.ranks, [-1], dist_context)

@@ -134,7 +128,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
                                          'Y': zero_var
                                      },
                                      outputs={'Out': cond_var},
-                                     attrs={'op_role': OpRole.Optimize})
+                                     attrs={OP_ROLE_KEY: OpRole.Backward})
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
         equal_op, world_process_group.ranks, [-1], dist_context)

@@ -143,7 +137,6 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
 def _append_gradient_merge_backward_op(
         main_program, startup_program, params_grads: List[Tuple[Any, Any]],
-        cond_var_name: str,
         dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
     main_block = main_program.global_block()
     startup_block = startup_program.global_block()

@@ -201,7 +194,7 @@ def _append_gradient_merge_backward_op(
                            attrs={
                                'axis': -1,
                                'use_mkldnn': False,
-                               'op_role': OpRole.Optimize
+                               OP_ROLE_KEY: OpRole.Backward
                            })
         new_params_to_grads.append([param, gradient_merge_var])
         grad_to_gradient_merge[grad.name] = gradient_merge_var.name

@@ -233,8 +226,7 @@ def _create_cond_block_and_update_optimizer(
                            'bias': 0.0,
                            'bias_after_scale': False
                        })
-            new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
-                                  OpRole.Optimize)
+            new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
 
         # append optimizer ops
         for op_desc in optimize_ops_desc:

@@ -272,29 +264,27 @@ def _create_cond_block_and_update_optimizer(
                                  dtype=new_grad.dtype,
                                  value=0.0,
                                  out=new_grad)
-                new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
-                                      op_maker.OpRole.Optimize)
+                new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize)
 
     layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
     cond_op = main_program.global_block().ops[-1]
-    cond_op._set_attr('op_role', OpRole.Optimize)
+    cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
 
 
 def parse_program(main_program, startup_program, params_grads, k_steps, avg,
                   dist_context):
-    # 1 create gradient_merge_cond
-    cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
-
-    # 2 remove optimizer_op from main_program
+    # 1 remove optimizer_op from main_program
     optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context)
 
     # back to block 0
     main_program._rollback()
 
-    # 3 append gradient merge backward op to main_program
+    # 2 append gradient merge backward op to main_program
     new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op(
-        main_program, startup_program, params_grads, cond_var.name,
-        dist_context)
+        main_program, startup_program, params_grads, dist_context)
+
+    # 3 create gradient_merge_cond
+    cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
 
     # 4 create ConditionalBlock and append gradient merge optimizer ops
     _create_cond_block_and_update_optimizer(main_program, cond_var,
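
parse_program now creates the gradient-merge condition variable after the accumulation ops are appended, and the conditional block runs the real optimizer only once every k_steps iterations. Stripped of Paddle programs, blocks and passes, the control flow being built is ordinary k-step gradient accumulation; a minimal sketch with a hypothetical SGD update, for illustration only:

def gradient_merge_step(step, k_steps, param, grad_acc, grad, lr=0.1, avg=True):
    # Backward part: always accumulate the fresh gradient into the merge buffer.
    grad_acc += grad
    step += 1

    # Optimize part: run the optimizer and reset the buffer only every k_steps.
    if step % k_steps == 0:
        merged = grad_acc / k_steps if avg else grad_acc
        param -= lr * merged          # hypothetical SGD update
        grad_acc = 0.0
    return step, param, grad_acc


step, param, acc = 0, 1.0, 0.0
for grad in [0.2, 0.4, 0.6, 0.8]:
    step, param, acc = gradient_merge_step(step, k_steps=2, param=param,
                                           grad_acc=acc, grad=grad)

# Two optimizer updates were applied, with averaged gradients 0.3 and 0.7.
assert abs(param - (1.0 - 0.1 * 0.3 - 0.1 * 0.7)) < 1e-12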