机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit db2c71a4 (unverified)
Authored Jul 07, 2022 by zhaoyingli; committed by GitHub on Jul 07, 2022.
[AutoParallel] fix 'op_role' for gradient merge & recompute (#44138)
* fix op_role
* fix engine
* update op_role
Parent commit: 7e3833a7
Showing 8 changed files with 182 additions and 184 deletions (+182, -184):
python/paddle/distributed/auto_parallel/engine.py  +87 -118
python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py  +3 -3
python/paddle/distributed/auto_parallel/operators/dist_embedding.py  +19 -9
python/paddle/distributed/auto_parallel/operators/dist_matmul.py  +39 -15
python/paddle/distributed/auto_parallel/partitioner.py  +5 -3
python/paddle/distributed/auto_parallel/utils.py  +9 -6
python/paddle/distributed/passes/auto_parallel_amp.py  +2 -2
python/paddle/distributed/passes/auto_parallel_gradient_merge.py  +18 -28
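The common thread in the operator and pass diffs below is how the 'op_role' attribute is assigned: ops inserted by the auto-parallel machinery now inherit the role of the op they are derived from (src_op.attr('op_role')), or are pinned to the stage they actually belong to (Optimize for AMP's check/update ops, Backward for gradient-merge bookkeeping), instead of being hard-coded. A minimal, self-contained sketch of the inheritance pattern, using plain-Python stand-ins rather than Paddle's real Operator/Block API (OP_ROLE_KEY, the role strings, and the Mock* classes are illustrative assumptions, not Paddle code):

OP_ROLE_KEY = 'op_role'   # stand-in for Paddle's real attribute key


class MockOp:
    def __init__(self, op_type, attrs=None):
        self.type = op_type
        self._attrs = dict(attrs or {})

    def attr(self, name):
        return self._attrs.get(name)


class MockBlock:
    def __init__(self):
        self.ops = []

    def append_op(self, op_type, attrs):
        op = MockOp(op_type, attrs)
        self.ops.append(op)
        return op


def insert_collective_after(block, src_op):
    # Before this commit the inserted op was tagged as Backward no matter
    # where it sat; afterwards it inherits the source op's role so that
    # role-based passes (1F1B scheduling, gradient merge, recompute) keep a
    # consistent view of the program.
    return block.append_op('c_allreduce_sum',
                           attrs={OP_ROLE_KEY: src_op.attr(OP_ROLE_KEY)})


block = MockBlock()
fwd_matmul = MockOp('matmul_v2', {OP_ROLE_KEY: 'forward'})
assert insert_collective_after(block, fwd_matmul).attr(OP_ROLE_KEY) == 'forward'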
python/paddle/distributed/auto_parallel/engine.py
@@ -18,7 +18,6 @@ from collections import defaultdict
 import paddle
 import paddle.utils as utils
-import paddle.distributed.auto_parallel as auto
 from paddle import fluid, static
 from paddle.io import Dataset

@@ -72,7 +71,6 @@ class Engine:
         self._saver = DistributedSaver()
         self._logger = get_logger(logging.INFO)
-        self._default_strategy = None
         self._orig_main_prog = static.default_main_program()
         self._orig_startup_prog = static.default_startup_program()
         self._orig_dist_context = get_default_distributed_context()

@@ -117,9 +115,11 @@ class Engine:
         self._planned_mode = None
         self._modes = ['train', 'eval', 'predict']

-        self._build()
-        # Do auto parallel process
+        # Build program and do auto parallel process
         for mode in self._modes:
+            # Build forward program
+            self._build(mode)
+        for mode in self._modes:
+            # Do the planning process
             self._plan(mode)
@@ -129,56 +129,49 @@ class Engine:
             # Init comm and startup program
             self._initialize(mode)

-    def _build(self):
-        for mode in self._modes:
-            serial_main_prog = self._serial_main_progs.get(mode, None)
-            if serial_main_prog is not None:
-                return
-
-            losses = []
-            metrics = []
-            serial_main_prog = self._orig_main_prog.clone()
-            serial_startup_prog = self._orig_startup_prog.clone()
-            with static.program_guard(serial_main_prog, serial_startup_prog), \
-                utils.unique_name.guard():
-                inputs_spec = self.inputs_spec
-                labels_spec = self.labels_spec if self.labels_spec else []
-                inputs = [s._create_feed_layer() for s in inputs_spec]
-                labels = [s._create_feed_layer() for s in labels_spec]
-                outputs = to_list(self.model(*inputs))
-                if mode != "predict" and self._loss:
-                    losses = to_list(self._loss(*(outputs + labels)))
-
-                if mode != "predict":
-                    for metric in self._metrics:
-                        metrics.extend(
-                            to_list(metric.compute(*(outputs + labels))))
-
-            default_ctx = get_default_distributed_context()
-            if not default_ctx.has_annotation or self._default_strategy:
-                # We build the world process group because the data parallel
-                # needs all ranks by default.
-                new_process_group(list(range(self._nranks)))
-                default_ctx.data_parallel = True
-
-            # self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
-            feed_vars = {"inputs": inputs, "labels": labels}
-
-            # self._fetch_vars[mode] = {
-            #     "outputs": flatten(outputs),
-            #     "loss": losses,
-            #     "metrics": metrics
-            # }
-            fetch_vars = {
-                "outputs": flatten(outputs),
-                "loss": losses,
-                "metrics": metrics
-            }
-
-            self._dist_contexts[mode] = DistributedContext(
-                serial_main_prog, serial_startup_prog, self._optimizer,
-                losses, feed_vars, fetch_vars, self.cluster, self.strategy)
-            self._dist_contexts[mode].gradient_scale = self._gradient_scale
+    def _build(self, mode):
+        serial_main_prog = self._serial_main_progs.get(mode, None)
+        if serial_main_prog is not None:
+            return
+
+        losses = []
+        metrics = []
+        serial_main_prog = self._orig_main_prog.clone()
+        serial_startup_prog = self._orig_startup_prog.clone()
+        with static.program_guard(serial_main_prog, serial_startup_prog), \
+            utils.unique_name.guard():
+            inputs_spec = self.inputs_spec
+            labels_spec = self.labels_spec if self.labels_spec else []
+            inputs = [s._create_feed_layer() for s in inputs_spec]
+            labels = [s._create_feed_layer() for s in labels_spec]
+            outputs = to_list(self.model(*inputs))
+            if mode != "predict" and self._loss:
+                losses = to_list(self._loss(*(outputs + labels)))
+
+            if mode != "predict":
+                for metric in self._metrics:
+                    metrics.extend(
+                        to_list(metric.compute(*(outputs + labels))))
+
+        default_ctx = get_default_distributed_context()
+        if not default_ctx.has_annotation:
+            # We build the world process group because the data parallel
+            # needs all ranks by default.
+            new_process_group(list(range(self._nranks)))
+            default_ctx.data_parallel = True
+
+        feed_vars = {"inputs": inputs, "labels": labels}
+        fetch_vars = {
+            "outputs": flatten(outputs),
+            "loss": losses,
+            "metrics": metrics
+        }
+
+        self._dist_contexts[mode] = DistributedContext(
+            serial_main_prog, serial_startup_prog, self._optimizer, losses,
+            feed_vars, fetch_vars, self.cluster, self.strategy)
+        self._dist_contexts[mode].gradient_scale = self._gradient_scale

     def _plan(self, mode):
         if self._planned_mode is None:
@@ -240,7 +233,6 @@ class Engine:
                 continue
             process_group.instantiate()

         # initialize
         self._place = _get_device()
         if isinstance(self._place, fluid.CUDAPlace):
             self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
@@ -273,8 +265,8 @@ class Engine:
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch)

-        usr_fetch = self._to_map_fetch(fetches)
-        fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
+        usr_fetch = self._validate_fetches(fetches)
+        fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
         fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)

         for epoch in range(epochs):
@@ -292,8 +284,7 @@ class Engine:
                 user_outs = outs[len(fetch_loss):]
                 user_fetch_list = fetch_list[len(fetch_loss):]
                 for i, out in enumerate(user_outs):
-                    train_logs["train_" + fetch_map[user_fetch_list[i]]] = out[0]
+                    train_logs["train_" + fetch_map[user_fetch_list[i]]] = out
                 self._logger.info(train_logs)

     def evaluate(self,
@@ -307,9 +298,9 @@ class Engine:
             "eval model is not ready, please call `engine.prepare()` first."

         eval_dataloader = self._create_dataloader(eval_data, batch_size)

-        usr_fetch = self._to_map_fetch(fetches)
-        fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
-        fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"])
+        usr_fetch = self._validate_fetches(fetches)
+        fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
+        fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
         inner_fetch = dict(fetch_loss, **fetch_metrics)
         fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
@@ -321,7 +312,7 @@ class Engine:
                                  return_numpy=return_numpy)
             # inner fetches
             if fetch_loss:
-                eval_logs["eval_loss"] = outs[0]
+                eval_logs["eval_loss"] = outs[0][0]
             # Metric
             if fetch_metrics:
                 metric_out = outs[len(fetch_loss):len(inner_fetch)]
@@ -331,9 +322,9 @@ class Engine:
                 for i, res in enumerate(to_list(results)):
                     eval_logs["eval_" + metric.name()[i]] = res
             # usr fetches
-            usr_out = outs[len(inner_fetch):]
+            usr_outs = outs[len(inner_fetch):]
             usr_fetch_list = fetch_list[len(inner_fetch):]
-            for i, out in enumerate(usr_out):
+            for i, out in enumerate(usr_outs):
                 eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
             # logger
             self._logger.info(eval_logs)
@@ -349,8 +340,8 @@ class Engine:
             "predict model is not ready, please call `engine.prepare()` first."

         test_dataloader = self._create_dataloader(test_data, batch_size)

-        usr_fetch = self._to_map_fetch(fetches)
-        fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"])
+        usr_fetch = self._validate_fetches(fetches)
+        fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
         fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)

         outputs = []
@@ -362,42 +353,11 @@ class Engine:
                              return_numpy=return_numpy)
             outputs.append(outs[:len(fetch_outputs)])
             for i, out in enumerate(outs):
-                predict_logs["pred_" + fetch_map[fetch_list[i]]] = out[0]
+                predict_logs["pred_" + fetch_map[fetch_list[i]]] = out
             self._logger.info(predict_logs)

         return outputs

-    def _local_var(self, var):
-        var_name = _to_name_str(var)
-        return var_name in self.main_program.global_block().vars
-
-    def _to_map_fetch(self, fetches):
-        if not fetches:
-            return {}
-        if isinstance(fetches, dict):
-            fetch_var_names = list(map(_to_name_str, fetches.values()))
-            usr_fetches = dict(zip(fetch_var_names, list(fetches.keys())))
-        elif isinstance(fetches, list):
-            fetch_var_names = list(map(_to_name_str, fetches))
-            usr_fetches = dict(zip(fetch_var_names, fetch_var_names))
-        return dict(
-            filter(lambda x: self._local_var(x[0]), usr_fetches.items()))
-
-    def _inner_fetch(self, fetch_vars):
-        fetch_list = list(
-            map(lambda x: x.name, list(filter(self._local_var, fetch_vars))))
-        inner_fetches = dict(zip(fetch_list, fetch_list))
-        return inner_fetches
-
-    def _fetch_map(self, inner_fetch, usr_fetch):
-        # replace inner fetch name if usr set for it
-        for iname in inner_fetch:
-            if iname in usr_fetch:
-                inner_fetch[iname] = usr_fetch[iname]
-                usr_fetch.pop(iname)
-        fetches = dict(inner_fetch, **usr_fetch)
-        return list(fetches.keys()), fetches
-
     def _create_dataloader(self,
                            dataset,
                            batch_size,
@@ -468,26 +428,35 @@ class Engine:
                         .format(i, spec))
         return specs

-    def _set_data_parallel(self, var):
-        if self._nranks == 1:
-            self._default_strategy = 'serial'
-            auto.shard_tensor(var,
-                              dist_attr={
-                                  "process_mesh": [0],
-                                  "dims_mapping":
-                                  [-1 for _ in range(len(var.shape))]
-                              })
+    def _is_local_var(self, var):
+        var_name = _to_name_str(var)
+        return var_name in self.main_program.global_block().vars
+
+    def _validate_fetches(self, fetches):
+        # 1. Check user-defined fetches type
+        # 2. Prepare fetches_dict like {user_defined_name: var_name}
+        if not fetches:
+            return {}
+        if isinstance(fetches, dict):
+            fetch_var_names = list(map(_to_name_str, fetches.values()))
+            fetches_dict = dict(zip(fetch_var_names, list(fetches.keys())))
+        elif isinstance(fetches, list):
+            fetch_var_names = list(map(_to_name_str, fetches))
+            fetches_dict = dict(zip(fetch_var_names, fetch_var_names))
         else:
-            self._default_strategy = 'dp'
-            auto.shard_tensor(var,
-                              dist_attr={
-                                  "process_mesh": list(range(self._nranks)),
-                                  "dims_mapping":
-                                  [0] + [-1 for _ in range(len(var.shape) - 1)]
-                              })
-        return var
+            raise TypeError("'fetches' only support 'dict' and 'list', "
+                            "but got '{}'".format(str(type(fetches))))
+        return dict(
+            filter(lambda x: self._is_local_var(x[0]), fetches_dict.items()))
+
+    def _fetch_map(self, inner_fetch, usr_fetch):
+        # replace inner fetch name if usr set for it
+        for iname in inner_fetch:
+            if iname in usr_fetch:
+                inner_fetch[iname] = usr_fetch[iname]
+                usr_fetch.pop(iname)
+        fetches = dict(inner_fetch, **usr_fetch)
+        return list(fetches.keys()), fetches

     def _get_data_parallel_info(self, var, dist_context):
         # get data parallel world size and current data parallel rank
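The new _validate_fetches/_fetch_map pair in engine.py is plain dict/zip/filter logic: normalize the user's fetches into {var_name: display_name}, drop anything that is not a variable of the local program, and let user-supplied names override the inner fetch names. A standalone sketch of the same logic, under the assumption that plain strings stand in for Paddle variables and _to_name_str is omitted:

def validate_fetches(fetches, local_var_names):
    # Normalize user fetches into {var_name: user_defined_name}, dropping
    # anything that is not a variable of the local (partitioned) program.
    if not fetches:
        return {}
    if isinstance(fetches, dict):
        fetches_dict = dict(zip(list(fetches.values()), list(fetches.keys())))
    elif isinstance(fetches, list):
        fetches_dict = dict(zip(fetches, fetches))
    else:
        raise TypeError("'fetches' only support 'dict' and 'list', "
                        "but got '{}'".format(str(type(fetches))))
    return {name: alias for name, alias in fetches_dict.items()
            if name in local_var_names}


def fetch_map(inner_fetch, usr_fetch):
    # Replace an inner fetch's display name if the user supplied one for it.
    for iname in inner_fetch:
        if iname in usr_fetch:
            inner_fetch[iname] = usr_fetch.pop(iname)
    fetches = dict(inner_fetch, **usr_fetch)
    return list(fetches.keys()), fetches


local_vars = {'loss_0', 'acc_0'}
usr = validate_fetches({'my_loss': 'loss_0', 'missing': 'not_local'}, local_vars)
inner = validate_fetches(['loss_0'], local_vars)
fetch_list, mapping = fetch_map(inner, usr)
assert fetch_list == ['loss_0'] and mapping['loss_0'] == 'my_loss'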
python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
@@ -137,7 +137,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
             attrs={
                 "in_dtype": inf_var.dtype,
                 "out_dtype": inf_var_int32.dtype,
-                OP_ROLE_KEY: OpRole.Backward
+                OP_ROLE_KEY: OpRole.Optimize
             })
         allreduce_op = main_block.append_op(type='c_allreduce_max',
                                             inputs={'X': inf_var_int32},

@@ -145,7 +145,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
             attrs={
                 'ring_id': group.id,
                 'use_calc_stream': True,
-                OP_ROLE_KEY: OpRole.Backward
+                OP_ROLE_KEY: OpRole.Optimize
             })
         cast_op2 = main_block.append_op(type='cast',
                                         inputs={'X': inf_var_int32},

@@ -153,7 +153,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
             attrs={
                 "in_dtype": inf_var_int32.dtype,
                 "out_dtype": inf_var.dtype,
-                OP_ROLE_KEY: OpRole.Backward
+                OP_ROLE_KEY: OpRole.Optimize
             })
         main_block._sync_with_cpp()
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -222,7 +222,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                 'W': [Weight_var]
             },
             outputs={'Out': [intermediate_var_0]},
-            attrs={"start_index": relative_idx})
+            attrs={
+                "start_index": relative_idx,
+                OP_ROLE_KEY: src_op.attr('op_role')
+            })

         if intermediate_var_0.shape != ref_shape:
             intermediate_var_0.desc.set_shape(ref_shape)

@@ -235,6 +238,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                 'ring_id': group.id,
                 'use_calc_stream': True,
                 'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })

         if Out_var.shape != ref_shape:
             Out_var.desc.set_shape(ref_shape)

@@ -442,6 +446,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         dp_group = new_process_group(group_ranks)

         if need_gradient_allreduce:
+            added_ops = []
             W_Grad_var = main_block.var(kwargs['W@GRAD'][0])
             allreduce_op = main_block.append_op(type='c_allreduce_sum',
                                                 inputs={'X': [W_Grad_var]},

@@ -451,19 +456,24 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                                                     'use_calc_stream': True,
                                                     OP_ROLE_KEY: OpRole.Backward
                                                 })
-            scale_op = main_block.append_op(type='scale',
-                                            inputs={'X': W_Grad_var},
-                                            outputs={'Out': W_Grad_var},
-                                            attrs={
-                                                'scale': 1.0 / dp_degree,
-                                                OP_ROLE_KEY: OpRole.Backward
-                                            })
+            added_ops.append(allreduce_op)
+
+            if ctx.gradient_scale:
+                scale_op = main_block.append_op(type='scale',
+                                                inputs={'X': W_Grad_var},
+                                                outputs={'Out': W_Grad_var},
+                                                attrs={
+                                                    'scale': 1.0 / dp_degree,
+                                                    OP_ROLE_KEY: OpRole.Backward
+                                                })
+                added_ops.append(scale_op)

             main_block._sync_with_cpp()
             dims_mapping = ctx.get_tensor_dist_attr_for_program(
                 W_Grad_var).dims_mapping
             process_mesh = dist_attr.process_mesh
-            for op in [allreduce_op, scale_op]:
+            for op in added_ops:
                 op_attr = OperatorDistributedAttribute()
                 op_attr.process_mesh = process_mesh
                 op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
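In both dist_embedding.py and dist_matmul.py the data-parallel gradient path now records the inserted ops in added_ops and only appends the scale op (1.0 / dp_degree) when ctx.gradient_scale is set, so distributed attributes are assigned to exactly the ops that were actually added. Numerically the pair of ops is "allreduce-sum across data-parallel ranks, then optionally divide by the number of ranks"; a small pure-Python sketch of that behaviour, with plain lists standing in for tensors:

def allreduce_then_scale(rank_grads, gradient_scale=True):
    # c_allreduce_sum: every rank ends up with the elementwise sum.
    summed = [sum(vals) for vals in zip(*rank_grads)]
    if not gradient_scale:
        return summed
    # scale by 1.0 / dp_degree turns the sum into a mean.
    dp_degree = len(rank_grads)
    return [g / dp_degree for g in summed]


grads_per_rank = [[0.5, 1.0], [1.5, 3.0]]      # dp_degree == 2
assert allreduce_then_scale(grads_per_rank) == [1.0, 2.0]
assert allreduce_then_scale(grads_per_rank, gradient_scale=False) == [2.0, 4.0]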
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -405,6 +405,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         dp_group = new_process_group(group_ranks)

         if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
+            added_ops = []
             Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
             allreduce_op = main_block.append_op(type='c_allreduce_sum',
                                                 inputs={'X': [Y_Grad_var]},

@@ -414,19 +415,24 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
                                                     'use_calc_stream': True,
                                                     OP_ROLE_KEY: OpRole.Backward
                                                 })
-            scale_op = main_block.append_op(type='scale',
-                                            inputs={'X': Y_Grad_var},
-                                            outputs={'Out': Y_Grad_var},
-                                            attrs={
-                                                'scale': 1.0 / dp_degree,
-                                                OP_ROLE_KEY: OpRole.Backward
-                                            })
+            added_ops.append(allreduce_op)
+
+            if ctx.gradient_scale:
+                scale_op = main_block.append_op(type='scale',
+                                                inputs={'X': Y_Grad_var},
+                                                outputs={'Out': Y_Grad_var},
+                                                attrs={
+                                                    'scale': 1.0 / dp_degree,
+                                                    OP_ROLE_KEY: OpRole.Backward
+                                                })
+                added_ops.append(scale_op)

             main_block._sync_with_cpp()
             dims_mapping = ctx.get_tensor_dist_attr_for_program(
                 Y_Grad_var).dims_mapping
             process_mesh = dist_attr.process_mesh
-            for op in [allreduce_op, scale_op]:
+            for op in added_ops:
                 op_attr = OperatorDistributedAttribute()
                 op_attr.process_mesh = process_mesh
                 op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping)

@@ -617,6 +623,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                 'ring_id': group.id,
                 'use_calc_stream': True,
                 'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
         if intermediate_var_0.shape != ref_shape_x:
             intermediate_var_0.desc.set_shape(ref_shape_x)
@@ -629,6 +636,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         attrs = {
             'transpose_X': False,
             'transpose_Y': False,
             'alpha': 1,
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
         matmul_op = main_block.append_op(type='matmul',
@@ -814,6 +822,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         attrs = {
             'transpose_X': False,
             'transpose_Y': False,
             'alpha': 1,
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': X_var, 'Y': Weight_var}

@@ -853,7 +862,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
             attrs={
                 'ring_id': group.id,
                 'use_calc_stream': True,
-                'use_model_parallel': True
+                'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })

         if Out_var.shape != ref_shape:
             Out_var.desc.set_shape(ref_shape)

@@ -1137,6 +1147,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                 'ring_id': group.id,
                 'use_calc_stream': True,
                 'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role'),
             })
         if intermediate_var_0.shape != ref_shape_x:
             intermediate_var_0.desc.set_shape(ref_shape_x)

@@ -1145,7 +1156,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                     ['float16', 'float32', 'float64'], 'linear')
         check_dtype(intermediate_var_0.dtype, 'dtype',
                     ['float16', 'float32', 'float64'], 'linear')
-        attrs = {'trans_x': False, 'trans_y': False}
+        attrs = {
+            'trans_x': False,
+            'trans_y': False,
+            OP_ROLE_KEY: src_op.attr('op_role')
+        }
         inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
         matmul_v2_op = main_block.append_op(type='matmul_v2',
                                             inputs=inputs,

@@ -1322,7 +1337,11 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                     'linear')
         check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
                     'linear')
-        attrs = {'trans_x': False, 'trans_y': False}
+        attrs = {
+            'trans_x': False,
+            'trans_y': False,
+            OP_ROLE_KEY: src_op.attr('op_role')
+        }
         inputs = {'X': X_var, 'Y': Weight_var}

         # infer out var shape with op dist attr

@@ -1361,7 +1380,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
             attrs={
                 'ring_id': group.id,
                 'use_calc_stream': True,
-                'use_model_parallel': True
+                'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
         if Out_var.shape != ref_shape:
             Out_var.desc.set_shape(ref_shape)

@@ -1646,6 +1666,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                 'ring_id': group.id,
                 'use_calc_stream': True,
                 'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
         if intermediate_var_0.shape != ref_shape_x:
             intermediate_var_0.desc.set_shape(ref_shape_x)

@@ -1657,7 +1678,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
         # attrs = {'trans_x': False, 'trans_y': False}
         attrs = {
             "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
-            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
+            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
         mul_op = main_block.append_op(type='mul',

@@ -1838,7 +1860,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
         # attrs = {'trans_x': False, 'trans_y': False}
         attrs = {
             "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
-            "y_num_col_dims": src_op.desc.attr("y_num_col_dims")
+            "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
+            OP_ROLE_KEY: src_op.attr('op_role')
         }
         inputs = {'X': X_var, 'Y': Weight_var}

@@ -1878,7 +1901,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
             attrs={
                 'ring_id': group.id,
                 'use_calc_stream': True,
-                'use_model_parallel': True
+                'use_model_parallel': True,
+                OP_ROLE_KEY: src_op.attr('op_role')
             })
         if Out_var.shape != ref_shape:
             Out_var.desc.set_shape(ref_shape)
python/paddle/distributed/auto_parallel/partitioner.py
@@ -264,10 +264,12 @@ class Partitioner(object):
                     self._dist_context, **kinputs, **koutputs,
                     **{"grad_var_to_var": grad_var_to_var})
             elif is_optimize_op(op):
+                # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS
                 kinputs, koutputs = dist_op_context.prepare_context(op)
-                dist_op_impl = get_distributed_operator_impl_container(
-                    "default").get_impl(0)
-                dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
+                dist_op_opt_impl = _get_dist_op_backward_implement(
+                    op, self._dist_context, forward_op_id2forward_op)
+                dist_op_opt_impl.backward(self._dist_context, **kinputs,
+                                          **koutputs)
             else:
                 raise NotImplementedError(
                     "partitioner only support forward and backward, optimize ops, but got {}"
python/paddle/distributed/auto_parallel/utils.py
@@ -1065,7 +1065,7 @@ def set_grad_var_shape(program, dist_context):
         "softmax", "cross_entropy2", "dropout", "tanh",
         ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
         "elementwise_add_grad_grad", "shape", "sqrt",
-        "fused_softmax_mask_upper_triangle_grad"
+        "fused_softmax_mask_upper_triangle"
     ]

     if op.type in need_set_shape_list:
         for forward_op in block.ops:

@@ -1096,11 +1096,9 @@ OpRole = core.op_proto_and_checker_maker.OpRole
 def is_forward_op(op):
-    ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward)
-    ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss)
     op_role = int(op.attr('op_role'))
-    return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1
-                                             or op_role == ref_role2)
+    return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.Forward)
+                                             or op_role == int(OpRole.Loss))


 def is_backward_op(op):

@@ -1113,9 +1111,14 @@ def is_optimize_op(op):
         int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)


+def is_lr_sched_op(op):
+    return OP_ROLE_KEY in op.attr_names and \
+        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize.LRSched)
+
+
 def is_loss_op(op):
     return OP_ROLE_KEY in op.attr_names and \
-        int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) |
-                                             int(core.op_proto_and_checker_maker.OpRole.Loss))
+        int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss))


 def is_prim_op(op):
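The predicates above classify ops by the integer stored under OP_ROLE_KEY: Forward and Loss are matched by equality, while Optimize and LRSched are tested with a bitwise AND because those roles can be combined with other flags. A standalone sketch with assumed flag values (Paddle's real OpRole enum lives in core.op_proto_and_checker_maker and its numeric values may differ):

# Assumed, illustrative flag values only.
FORWARD = 0x0000
BACKWARD = 0x0001
OPTIMIZE = 0x0002
LRSCHED = OPTIMIZE | 0x0010   # LR scheduling is an optimize-stage sub-role
LOSS = 0x0100


def is_forward_op(op_role):
    # Forward ops and the loss op both count as "forward".
    return op_role == FORWARD or op_role == LOSS


def is_loss_op(op_role):
    return op_role == (FORWARD | LOSS)


def is_optimize_op(op_role):
    # Bitwise test: any role with the Optimize bit set (including LRSched).
    return bool(op_role & OPTIMIZE)


def is_lr_sched_op(op_role):
    return bool(op_role & LRSCHED)


assert is_forward_op(FORWARD) and not is_forward_op(BACKWARD)
assert is_loss_op(FORWARD | LOSS)
assert is_optimize_op(OPTIMIZE) and is_optimize_op(LRSCHED)
assert is_lr_sched_op(LRSCHED)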
python/paddle/distributed/passes/auto_parallel_amp.py
@@ -452,7 +452,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
-    attrs = {'op_role': OpRole.Backward}
+    attrs = {'op_role': OpRole.Optimize}
     new_op = main_block.append_op(type='check_finite_and_unscale',
                                   inputs=inputs,
                                   outputs=outputs,

@@ -732,7 +732,7 @@ class AMPPass(PassBase):
             'incr_ratio': self.get_attr("incr_ratio"),
             'decr_ratio': self.get_attr("decr_ratio"),
             'stop_update': self.get_attr("stop_update"),
-            'op_role': OpRole.Backward
+            'op_role': OpRole.Optimize
         }
         new_op = main_block.append_op(type='update_loss_scaling',
python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -21,20 +21,13 @@ from paddle.framework import core
 from paddle.fluid import layers
 from paddle.fluid.framework import program_guard, device_guard
 from .pass_base import PassBase, PassType, register_pass
-from paddle.distributed.fleet.meta_optimizers.common import OpRole
-from paddle.distributed.auto_parallel.utils import set_var_dist_attr
+from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY
 from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
 from paddle.distributed.auto_parallel.process_group import get_world_process_group

 world_process_group = get_world_process_group()


-def _is_the_optimizer_op(op):
-    OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
-    return OP_ROLE_KEY in op.attr_names and \
-        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
-
-
 def _remove_and_get_optimizer_op(main_program, dist_context):
     # 1 create tmp block
     # 2 mv optimizer op from global program to tmp block

@@ -43,9 +36,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
     temp_block = main_program._create_block()
     removed_op_idx = []
     optimize_ops_desc = []
-    skip_ops = ["increment", "elementwise_mod", "equal"]
     for idx, op in enumerate(main_block.ops):
-        if _is_the_optimizer_op(op) and op.type not in skip_ops:
+        if is_optimize_op(op):
             # append optimizer op to tmp block
             new_op_desc = temp_block.desc.append_op()
             new_op_desc.copy_from(op.desc)

@@ -57,7 +49,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
             dist_context.del_dist_op_for_program(op)

     for idx in removed_op_idx[::-1]:
-        main_block._remove_op(idx)
+        main_block._remove_op(idx, sync=False)
+    main_block._sync_with_cpp()

     return optimize_ops_desc
@@ -109,7 +102,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
                                      outputs={'Out': [step_var]},
                                      attrs={
                                          'step': float(1.0),
-                                         'op_role': OpRole.Optimize
+                                         OP_ROLE_KEY: OpRole.Backward
                                      })
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
         increment_op, world_process_group.ranks, [-1], dist_context)

@@ -123,7 +116,8 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
                                            attrs={
                                                'axis': -1,
                                                'use_mkldnn': False,
-                                               'op_role': OpRole.Optimize
+                                               OP_ROLE_KEY: OpRole.Backward
                                            })
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
         elementwise_mod_op, world_process_group.ranks, [-1], dist_context)

@@ -134,7 +128,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
                                        'Y': zero_var
                                    },
                                    outputs={'Out': cond_var},
-                                   attrs={'op_role': OpRole.Optimize})
+                                   attrs={OP_ROLE_KEY: OpRole.Backward})
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
         equal_op, world_process_group.ranks, [-1], dist_context)

@@ -143,7 +137,6 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
 def _append_gradient_merge_backward_op(
         main_program, startup_program, params_grads: List[Tuple[Any, Any]],
-        cond_var_name: str,
         dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
     main_block = main_program.global_block()
     startup_block = startup_program.global_block()

@@ -201,7 +194,7 @@ def _append_gradient_merge_backward_op(
                              attrs={
                                  'axis': -1,
                                  'use_mkldnn': False,
-                                 'op_role': OpRole.Optimize
+                                 OP_ROLE_KEY: OpRole.Backward
                              })
         new_params_to_grads.append([param, gradient_merge_var])
         grad_to_gradient_merge[grad.name] = gradient_merge_var.name

@@ -233,8 +226,7 @@ def _create_cond_block_and_update_optimizer(
                             'bias': 0.0,
                             'bias_after_scale': False
                         })
-            new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
-                                  OpRole.Optimize)
+            new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)

         # append optimizer ops
         for op_desc in optimize_ops_desc:

@@ -272,29 +264,27 @@ def _create_cond_block_and_update_optimizer(
                              dtype=new_grad.dtype,
                              value=0.0,
                              out=new_grad)
-            new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
-                                  op_maker.OpRole.Optimize)
+            new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize)

     layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
     cond_op = main_program.global_block().ops[-1]
-    cond_op._set_attr('op_role', OpRole.Optimize)
+    cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize)


 def parse_program(main_program, startup_program, params_grads, k_steps, avg,
                   dist_context):
-    # 1 create gradient_merge_cond
-    cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
-
-    # 2 remove optimizer_op from main_program
+    # 1 remove optimizer_op from main_program
     optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context)

     # back to block 0
     main_program._rollback()

-    # 3 append gradient merge backward op to main_program
+    # 2 append gradient merge backward op to main_program
     new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op(
-        main_program, startup_program, params_grads, cond_var.name,
-        dist_context)
+        main_program, startup_program, params_grads, dist_context)
+
+    # 3 create gradient_merge_cond
+    cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)

     # 4 create ConditionalBlock and append gradient merge optimizer ops
     _create_cond_block_and_update_optimizer(main_program, cond_var,
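After this change parse_program removes the optimizer ops first, appends the Backward-tagged gradient-accumulation ops, and only then builds the k-step condition variable and the Optimize-tagged conditional block that runs the optimizer. Behaviourally that is "accumulate every step, apply and reset every k steps"; a simplified pure-Python sketch of that schedule for a single scalar parameter (k_steps, avg and the SGD update are stand-ins for the real pass):

def gradient_merge_training(grad_stream, k_steps, lr=0.1, avg=True):
    # grad_stream: per-step gradients for a single scalar parameter.
    param = 0.0
    gradient_merge_acc = 0.0           # the merged-gradient buffer
    for step, grad in enumerate(grad_stream, start=1):
        gradient_merge_acc += grad     # backward-role accumulation op
        if step % k_steps == 0:        # the step-counter / equal cond var
            merged = gradient_merge_acc / k_steps if avg else gradient_merge_acc
            param -= lr * merged       # optimizer ops inside the cond block
            gradient_merge_acc = 0.0   # reset the buffer after applying
    return param


# Applying every 2 steps with averaging equals SGD on the 2-step mean gradient.
assert abs(gradient_merge_training([1.0, 3.0, 2.0, 2.0], k_steps=2) - (-0.4)) < 1e-12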