Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6b86e966
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6b86e966
编写于
5月 05, 2021
作者:
L
lilong12
提交者:
GitHub
5月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the bug in pipeline for dygraph mode (#32716) (#32728)
* update, test=develop
上级
4593597d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
231 addition
and
155 deletion
+231
-155
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
...tributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+0
-1
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+201
-141
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
.../paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+30
-13
未找到文件。
python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
浏览文件 @
6b86e966
...
...
@@ -108,7 +108,6 @@ class PipelineLayer(Layer):
# construct layer
self
.
run_function
=
[]
self
.
_build_layer
()
self
.
to
(
paddle
.
CUDAPlace
(
self
.
device_id
))
def
_segment_network
(
self
,
seg_method
):
logger
.
info
(
"start segment network.."
)
...
...
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
6b86e966
...
...
@@ -22,15 +22,11 @@ from numpy import prod
import
paddle
import
paddle.fluid
as
fluid
from
.meta_parallel_base
import
MetaParallelBase
from
.pp_utils.utils
import
get_tensor_bytes
from
.pp_utils.utils
import
get_tensor_bytes
,
is_float_tensor
from
.pp_utils
import
utils
from
.parallel_layers.pp_layers
import
PipelineLayer
FLOAT_TYPES
=
[
paddle
.
float16
,
paddle
.
float32
,
paddle
.
float64
,
]
from
..utils.hybrid_parallel_util
import
*
from
..utils.log_util
import
logger
class
PipelineParallel
(
MetaParallelBase
):
...
...
@@ -46,20 +42,18 @@ class PipelineParallel(MetaParallelBase):
'inputs'
:
[],
'labels'
:
[],
'outputs'
:
[],
'backward_tensors'
:
[],
}
self
.
recv_cache
=
None
self
.
grad_tensors
=
None
self
.
meta_buffer
=
None
self
.
send_meta
=
True
self
.
first_gradient_send
=
True
self
.
current_loss
=
paddle
.
to_tensor
(
0.0
)
self
.
total_loss
=
None
def
_prepare_for_model
(
self
):
self
.
use_amp
=
self
.
_strategy
.
amp
self
.
init_loss_scaling
=
self
.
_strategy
.
amp_configs
[
'init_loss_scaling'
]
self
.
micro_batch_size
=
self
.
_strategy
.
pipeline_configs
[
'micro_batch_size'
]
self
.
accumulate_steps
=
self
.
_strategy
.
pipeline_configs
[
...
...
@@ -69,9 +63,17 @@ class PipelineParallel(MetaParallelBase):
self
.
stage_id
=
self
.
_hcg
.
get_stage_id
()
self
.
prev_stage_id
=
self
.
stage_id
-
1
self
.
next_stage_id
=
self
.
stage_id
+
1
self
.
_layers
=
PipelineLayer
(
layers
=
self
.
_layers
,
num_stages
=
self
.
num_stages
)
#TODO: init process group
self
.
pp_group
=
self
.
_hcg
.
get_pipe_parallel_group
()
logger
.
info
(
"Pipeline Info -- num_stages: {}, stage_id: {}"
.
format
(
self
.
num_stages
,
self
.
stage_id
))
if
self
.
use_model_parallel
:
logger
.
info
(
"start broadcast mp parameters"
)
broadcast_mp_parameters
(
self
.
_layers
,
self
.
_hcg
)
if
self
.
use_data_parallel
:
logger
.
info
(
"start broadcast mp parameters"
)
broadcast_dp_parameters
(
self
.
_layers
,
self
.
_hcg
)
def
_allocate_caches
(
self
,
num_caches
):
if
self
.
num_caches
>=
num_caches
:
...
...
@@ -82,19 +84,19 @@ class PipelineParallel(MetaParallelBase):
for
key
in
self
.
caches
:
self
.
caches
[
key
].
extend
([
None
]
*
num
)
def
train_batch
(
self
,
data
_iter
,
optimizer
):
def
train_batch
(
self
,
data
,
optimizer
):
self
.
optimizer
=
optimizer
assert
fluid
.
framework
.
_dygraph_tracer
().
_has_grad
,
(
'Please enable the generation of gradients.'
)
if
self
.
stage_id
==
0
or
self
.
stage_id
==
self
.
num_stages
-
1
:
assert
data
_iter
,
(
assert
data
,
(
"For the first and the last stage, the data_iter must be set."
)
else
:
assert
data
_iter
is
None
,
(
assert
data
is
None
,
(
"For pipe stages other than the first and the last one, "
"the data_iter must be None."
)
self
.
data
_iter
=
data_iter
self
.
data
=
data
self
.
_layers
.
train
()
self
.
total_loss
=
None
...
...
@@ -104,39 +106,24 @@ class PipelineParallel(MetaParallelBase):
return
self
.
total_loss
def
_train
(
self
,
minibatch_cmds
):
self
.
_allocate_caches
(
self
.
num_stages
)
for
microbatch_cmds
in
minibatch_cmds
:
for
cmd
in
microbatch_cmds
:
if
type
(
cmd
)
not
in
self
.
_COMMAND_MAP
:
#FIXME:
continue
self
.
_allocate_caches
(
self
.
accumulate_steps
)
for
micro_cmds
in
minibatch_cmds
:
for
cmd
in
micro_cmds
:
assert
type
(
cmd
)
in
self
.
_COMMAND_MAP
,
"unknow cmd: {}"
.
format
(
type
(
cmd
))
self
.
_apply_cmd
=
MethodType
(
self
.
_COMMAND_MAP
[
type
(
cmd
)],
self
)
self
.
_apply_cmd
(
**
cmd
.
kwargs
)
def
_allreduce_grads
(
self
):
self
.
_modifying_grad
=
True
assert
self
.
use_data_parallel
<=
1
,
(
"Do not support data parallel "
"with pipeline parallel now."
)
self
.
_modifying_grad
=
False
def
_get_data
(
self
):
if
self
.
use_model_parallel
:
mp_rank
=
self
.
_hcg
.
get_model_parallel_rank
()
else
:
mp_rank
=
0
data
=
None
# mp rank 0 loads the data and broadcat it to others.
if
mp_rank
==
0
:
data
=
next
(
self
.
data_iter
)
if
self
.
use_model_parallel
:
data
=
paddle
.
distributed
.
broadcast
(
data
,
group
=
self
.
_hcg
.
get_model_parallel_group
())
return
data
if
not
self
.
use_data_parallel
:
return
fused_allreduce_gradients
(
list
(
self
.
_layers
.
parameters
()),
self
.
_hcg
)
def
_forward
(
self
,
cache_id
):
# load data
self
.
_load_micro_batch
(
cache_id
)
if
self
.
stage_id
!=
0
:
self
.
_recv_activations
(
cache_id
)
if
isinstance
(
self
.
caches
[
'inputs'
][
cache_id
],
tuple
):
inputs
=
tuple
(
t
.
clone
()
for
t
in
self
.
caches
[
'inputs'
][
cache_id
])
else
:
...
...
@@ -144,9 +131,13 @@ class PipelineParallel(MetaParallelBase):
self
.
_clear_grads
(
inputs
)
outputs
=
self
.
_layers
.
forward
(
inputs
)
self
.
caches
[
'outputs'
][
cache_id
]
=
outputs
if
self
.
stage_id
==
self
.
num_stages
-
1
:
if
self
.
_layers
.
_loss_fn
is
not
None
:
labels
=
self
.
caches
[
'labels'
][
cache_id
]
outputs
=
self
.
_layers
.
_loss_fn
(
outputs
,
labels
)
if
self
.
stage_id
==
self
.
num_stages
-
1
:
self
.
current_loss
=
outputs
if
isinstance
(
self
.
current_loss
,
paddle
.
Tensor
):
...
...
@@ -160,18 +151,28 @@ class PipelineParallel(MetaParallelBase):
]
for
idx
,
v
in
enumerate
(
self
.
current_loss
):
self
.
total_loss
[
idx
]
+=
v
.
detach
()
if
self
.
use_data_parallel
:
self
.
current_loss
=
self
.
current_loss
/
self
.
_hcg
.
get_data_parallel_world_size
(
)
if
self
.
accumulate_steps
>
1
:
self
.
current_loss
=
self
.
current_loss
/
self
.
accumulate_steps
self
.
caches
[
'outputs'
][
cache_id
]
=
self
.
current_loss
.
clone
()
else
:
self
.
_send_activations
(
cache_id
)
def
_backward
(
self
,
cache_id
):
assert
self
.
optimizer
is
not
None
if
self
.
stage_id
==
self
.
num_stages
-
1
:
paddle
.
autograd
.
backward
(
self
.
current_loss
)
paddle
.
autograd
.
backward
(
self
.
caches
[
'outputs'
][
cache_id
])
self
.
_send_gradients
(
cache_id
)
return
self
.
_recv_gradients
(
cache_id
)
outputs
=
self
.
caches
[
'outputs'
][
cache_id
]
grad_tensors
=
self
.
grad_tensors
if
isinstance
(
outputs
,
tuple
):
out_tensors
=
[
t
for
t
in
outputs
if
t
.
dtype
in
FLOAT_TYPES
]
out_tensors
=
[
t
for
t
in
outputs
if
is_float_tensor
(
t
)
]
assert
len
(
out_tensors
)
==
len
(
grad_tensors
)
paddle
.
autograd
.
backward
(
tensors
=
out_tensors
,
grad_tensors
=
grad_tensors
)
...
...
@@ -179,41 +180,76 @@ class PipelineParallel(MetaParallelBase):
paddle
.
autograd
.
backward
(
tensors
=
[
outputs
],
grad_tensors
=
[
grad_tensors
])
self
.
caches
[
'outputs'
][
cache_id
]
=
None
grad_tensors
=
None
if
self
.
stage_id
!=
0
:
self
.
_send_gradients
(
cache_id
)
self
.
caches
[
'outputs'
][
cache_id
]
=
None
#self.caches['backward_tensors'][cache_id] = None
def
_get_data
(
self
):
if
self
.
use_model_parallel
:
mp_rank
=
self
.
_hcg
.
get_model_parallel_rank
()
else
:
mp_rank
=
0
# mp rank 0 loads the data and broadcat it to others.
data
=
self
.
data
if
self
.
use_model_parallel
and
(
self
.
stage_id
==
0
or
self
.
stage_id
==
self
.
num_stages
-
1
):
assert
isinstance
(
data
,
(
tuple
,
paddle
.
Tensor
))
if
isinstance
(
data
,
paddle
.
Tensor
):
paddle
.
distributed
.
broadcast
(
data
,
src
=
self
.
_hcg
.
get_model_parallel_group_src_rank
(),
group
=
self
.
_hcg
.
get_model_parallel_group
())
else
:
data
=
[]
for
d
in
self
.
data
:
assert
isinstance
(
d
,
paddle
.
Tensor
)
paddle
.
distributed
.
broadcast
(
d
,
src
=
self
.
_hcg
.
get_model_parallel_group_src_rank
(),
group
=
self
.
_hcg
.
get_model_parallel_group
())
data
.
append
(
d
)
data
=
tuple
(
data
)
return
data
def
_load_micro_batch
(
self
,
cache_id
):
inputs
=
self
.
_get_data
()
if
self
.
stage_id
==
0
:
data
=
None
if
isinstance
(
inputs
[
0
],
paddle
.
Tensor
):
#if isinstance(inputs[0], paddle.Tensor):
if
len
(
inputs
)
==
1
:
assert
isinstance
(
inputs
[
0
],
paddle
.
Tensor
)
data
=
inputs
[
0
].
clone
().
detach
()
data
.
stop_gradient
=
data
.
dtype
==
paddle
.
float32
#data.stop_gradient = not is_float_tensor(data)
data
.
stop_gradient
=
True
else
:
assert
isinstance
(
inputs
[
0
],
tuple
)
# Assume list or tuple
assert
isinstance
(
inputs
,
tuple
)
data
=
[]
for
d
in
inputs
[
0
]
:
for
d
in
inputs
:
assert
isinstance
(
d
,
paddle
.
Tensor
)
d
=
d
.
clone
().
detach
()
d
.
stop_gradient
=
d
.
dtype
==
paddle
.
float32
loaded
.
append
(
d
)
i
=
d
.
clone
().
detach
()
#i.stop_gradient = not is_float_tensor(i)
i
.
stop_gradient
=
True
data
.
append
(
i
)
data
=
tuple
(
data
)
self
.
caches
[
'inputs'
][
cache_id
]
=
data
if
self
.
stage_id
==
self
.
num_stages
-
1
:
label
=
None
if
isinstance
(
inputs
[
1
],
paddle
.
Tensor
):
label
=
inputs
[
1
]
elif
isinstance
(
data
[
1
],
tuple
):
label
=
[]
for
l
in
inputs
[
1
]:
assert
isinstance
(
l
,
paddle
.
Tensor
)
l
=
l
.
detach
()
label
.
append
(
l
)
label
=
tuple
(
label
)
self
.
caches
[
'labels'
][
cache_id
]
=
label
labels
=
None
#if isinstance(inputs[1], paddle.Tensor):
if
len
(
inputs
)
==
1
:
assert
isinstance
(
inputs
[
0
],
paddle
.
Tensor
)
labels
=
inputs
[
0
]
elif
isinstance
(
inputs
,
tuple
):
labels
=
[]
for
label
in
inputs
:
assert
isinstance
(
label
,
paddle
.
Tensor
)
label
=
label
.
detach
()
labels
.
append
(
label
)
labels
=
tuple
(
labels
)
self
.
caches
[
'labels'
][
cache_id
]
=
labels
def
_send_meta
(
self
,
data
,
peer
):
"""
...
...
@@ -225,54 +261,67 @@ class PipelineParallel(MetaParallelBase):
"""
if
isinstance
(
data
,
paddle
.
Tensor
):
tensor_type
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
send
(
tensor_type
,
peer
)
paddle
.
distributed
.
send
(
tensor_type
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
dims
=
paddle
.
to_tensor
(
len
(
data
.
shape
))
paddle
.
distributed
.
send
(
dims
,
peer
)
paddle
.
distributed
.
send
(
dims
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
shape
=
paddle
.
to_tensor
(
data
.
shape
)
paddle
.
distributed
.
send
(
shape
,
peer
)
paddle
.
distributed
.
send
(
shape
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
elif
isinstance
(
data
,
tuple
):
tensor_type
=
paddle
.
to_tensor
([
1
])
paddle
.
distributed
.
send
(
tensor_type
,
peer
)
paddle
.
distributed
.
send
(
tensor_type
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
nums
=
paddle
.
to_tensor
(
len
(
data
))
paddle
.
distributed
.
send
(
nums
,
peer
)
paddle
.
distributed
.
send
(
nums
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
for
idx
,
d
in
enumerate
(
data
):
assert
isinstance
(
d
,
paddle
.
Tensor
)
dims
=
paddle
.
to_tensor
(
len
(
d
.
shape
))
paddle
.
distributed
.
send
(
dims
,
peer
)
paddle
.
distributed
.
send
(
dims
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
shape
=
paddle
.
to_tensor
(
d
.
shape
)
paddle
.
distributed
.
send
(
shape
,
peer
)
paddle
.
distributed
.
send
(
shape
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
def
_recv_meta
(
self
,
peer
):
tensor_type
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
tensor_type
,
peer
)
paddle
.
distributed
.
recv
(
tensor_type
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
tensor_type
=
tensor_type
.
numpy
()[
0
]
if
tensor_type
==
0
:
dims
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
dims
,
peer
)
paddle
.
distributed
.
recv
(
dims
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
dims
=
dims
.
numpy
()[
0
]
shape
=
paddle
.
to_tensor
([
0
]
*
dims
)
paddle
.
distributed
.
recv
(
shape
,
peer
)
paddle
.
distributed
.
recv
(
shape
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
shape
=
shape
.
numpy
().
tolist
()
return
self
.
_allocate_buffer
(
shape
,
dtype
=
"float32"
,
num_caches
=
1
)[
0
]
elif
tensor_type
==
1
:
num
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
num
,
peer
)
paddle
.
distributed
.
recv
(
num
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
num
=
num
.
numpy
()[
0
]
shapes
=
[]
for
i
in
range
(
num
):
dims
=
paddle
.
to_tensor
([
0
])
paddle
.
distributed
.
recv
(
dims
,
peer
)
paddle
.
distributed
.
recv
(
dims
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
dims
=
dims
.
numpy
()[
0
]
shape
=
paddle
.
to_tensor
([
0
]
*
dims
)
paddle
.
distributed
.
recv
(
shape
,
peer
)
paddle
.
distributed
.
recv
(
shape
,
peer
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
shapes
.
append
(
shape
.
numpy
().
tolist
())
dtypes
=
[
"float32"
]
*
len
(
shapes
)
caches
=
self
.
_allocate_buffers
(
shapes
,
dtypes
,
num_
buffer
s
=
1
)[
0
]
buffers
=
tuple
(
buffer
s
)
return
buffer
s
caches
=
self
.
_allocate_buffers
(
shapes
,
dtypes
,
num_
cache
s
=
1
)[
0
]
caches
=
tuple
(
cache
s
)
return
cache
s
def
_send_activations
(
self
,
cache_id
):
outputs
=
self
.
caches
[
'outputs'
][
cache_id
]
...
...
@@ -282,10 +331,18 @@ class PipelineParallel(MetaParallelBase):
self
.
_send_meta
(
outputs
,
self
.
next_stage_id
)
if
isinstance
(
outputs
,
paddle
.
Tensor
):
paddle
.
distributed
.
send
(
outputs
,
self
.
next_stage_id
)
paddle
.
distributed
.
send
(
outputs
,
self
.
next_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
elif
isinstance
(
outputs
,
tuple
):
for
output
in
outputs
:
paddle
.
distributed
.
send
(
output
,
self
.
next_stage_id
)
paddle
.
distributed
.
send
(
output
,
self
.
next_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
def
_send_gradients
(
self
,
cache_id
):
inputs
=
self
.
caches
[
'inputs'
][
cache_id
]
...
...
@@ -293,15 +350,22 @@ class PipelineParallel(MetaParallelBase):
if
isinstance
(
inputs
,
paddle
.
Tensor
):
assert
inputs
.
grad
is
not
None
paddle
.
distributed
.
send
(
paddle
.
to_tensor
(
inputs
.
grad
),
self
.
prev_stage_id
)
paddle
.
to_tensor
(
inputs
.
grad
),
self
.
prev_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
else
:
for
idx
,
d
in
enumerate
(
inputs
):
# Skip tensors that will not produce a grad
if
not
d
.
dtype
in
FLOAT_TYPES
:
if
not
is_float_tensor
(
d
)
:
assert
d
.
grad
is
None
continue
assert
d
.
grad
is
not
None
paddle
.
distributed
.
send
(
d
.
grad
,
self
.
prev_stage_id
)
paddle
.
distributed
.
send
(
d
.
grad
,
self
.
prev_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
self
.
caches
[
'inputs'
][
cache_id
]
=
None
def
_recv_activations
(
self
,
cache_id
):
...
...
@@ -312,22 +376,30 @@ class PipelineParallel(MetaParallelBase):
self
.
recv_cache
=
self
.
_recv_meta
(
self
.
prev_stage_id
)
if
isinstance
(
self
.
recv_cache
,
paddle
.
Tensor
):
paddle
.
distributed
.
recv
(
self
.
recv_cache
,
self
.
prev_stage_id
)
paddle
.
distributed
.
recv
(
self
.
recv_cache
,
self
.
prev_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
inputs
=
self
.
recv_cache
.
clone
().
detach
()
inputs
.
stop_gradient
=
inputs
.
dtype
not
in
FLOAT_TYPES
inputs
.
stop_gradient
=
not
is_float_tensor
(
inputs
)
else
:
assert
isinstance
(
self
.
recv_cache
,
tuple
)
inputs
=
[
None
]
*
len
(
self
.
recv_cache
)
for
idx
,
d
in
enumerate
(
self
.
recv_cache
):
assert
isinstance
(
d
,
paddle
.
Tensor
)
paddle
.
distributed
.
recv
(
d
,
self
.
prev_stage_id
)
paddle
.
distributed
.
recv
(
d
,
self
.
prev_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
inputs
[
idx
]
=
d
.
clone
().
detach
()
inputs
=
tuple
(
inputs
)
for
d
in
inputs
:
d
.
stop_gradient
=
d
.
dtype
not
in
FLOAT_TYPES
d
.
stop_gradient
=
not
is_float_tensor
(
d
)
self
.
caches
[
'inputs'
][
cache_id
]
=
inputs
...
...
@@ -336,29 +408,35 @@ class PipelineParallel(MetaParallelBase):
if
self
.
grad_tensors
is
None
:
if
isinstance
(
outputs
,
paddle
.
Tensor
):
s
=
list
(
outputs
.
shape
)
dtype
=
'float
32'
dtype
=
'float
16'
if
self
.
use_amp
else
"float32"
self
.
grad_tensors
=
self
.
_allocate_buffer
(
s
,
dtype
,
num_buffers
=
1
)[
0
]
else
:
sizes
=
[
list
(
d
.
shape
)
for
d
in
outputs
if
d
.
dtype
in
FLOAT_TYPES
]
dtypes
=
[
'float32'
]
*
len
(
sizes
)
sizes
=
[
list
(
d
.
shape
)
for
d
in
outputs
if
is_float_tensor
(
d
)]
dtypes
=
[
'float16'
]
*
len
(
sizes
)
if
self
.
use_amp
else
[
'float32'
]
*
len
(
sizes
)
self
.
grad_tensors
=
self
.
_allocate_buffers
(
sizes
,
dtypes
,
num_
buffer
s
=
1
)[
0
]
sizes
,
dtypes
,
num_
cache
s
=
1
)[
0
]
if
isinstance
(
self
.
grad_tensors
,
paddle
.
Tensor
):
paddle
.
distributed
.
recv
(
self
.
grad_tensors
,
self
.
next_stage_id
)
paddle
.
distributed
.
recv
(
self
.
grad_tensors
,
self
.
next_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
else
:
assert
isinstance
(
outputs
,
tuple
)
for
d
in
self
.
grad_tensors
:
paddle
.
distributed
.
recv
(
d
,
self
.
next_stage_id
)
def
_step
(
self
,
lr_kwargs
=
None
):
self
.
_modifying_grad
=
True
paddle
.
distributed
.
recv
(
d
,
self
.
next_stage_id
,
use_calc_stream
=
True
,
group
=
self
.
pp_group
)
def
_step
(
self
):
self
.
_allreduce_grads
()
self
.
optimizer
.
step
()
self
.
optimizer
.
clear_gradients
()
self
.
_modifying_grad
=
False
def
_clear_grads
(
self
,
inputs
):
if
isinstance
(
inputs
,
paddle
.
Tensor
):
...
...
@@ -372,26 +450,24 @@ class PipelineParallel(MetaParallelBase):
def
_allocate_zeros
(
self
,
shape
,
dtype
):
return
paddle
.
zeros
(
shape
,
dtype
)
def
_allocate_buffer
(
self
,
shape
,
dtype
,
num_
buffers
=-
1
,
**
kwargs
):
buffer
s
=
[]
if
num_
buffer
s
==
-
1
:
num_
buffer
s
=
self
.
num_caches
for
count
in
range
(
num_
buffer
s
):
buffer
s
.
append
(
self
.
_allocate_zeros
(
shape
,
dtype
))
return
buffer
s
def
_allocate_buffers
(
self
,
shapes
,
dtypes
,
num_
buffer
s
=-
1
):
buffer
s
=
[]
if
num_
buffer
s
==
-
1
:
num_
buffer
s
=
self
.
num_caches
for
count
in
range
(
num_
buffer
s
):
buffer
=
[]
def
_allocate_buffer
(
self
,
shape
,
dtype
,
num_
caches
=-
1
):
cache
s
=
[]
if
num_
cache
s
==
-
1
:
num_
cache
s
=
self
.
num_caches
for
count
in
range
(
num_
cache
s
):
cache
s
.
append
(
self
.
_allocate_zeros
(
shape
,
dtype
))
return
cache
s
def
_allocate_buffers
(
self
,
shapes
,
dtypes
,
num_
cache
s
=-
1
):
cache
s
=
[]
if
num_
cache
s
==
-
1
:
num_
cache
s
=
self
.
num_caches
for
count
in
range
(
num_
cache
s
):
cache
=
[]
for
shape
,
dtype
in
zip
(
shapes
,
dtypes
):
buffer
.
append
(
self
.
_allocate_zeros
(
shape
,
dtype
,
requires_grad
=
requires_grad
))
buffers
.
append
(
buffer
)
return
buffers
cache
.
append
(
self
.
_allocate_zeros
(
shape
,
dtype
))
caches
.
append
(
cache
)
return
caches
def
save_state_dict
(
self
,
model_path
):
state_dict
=
self
.
_layers
.
state_dict
()
...
...
@@ -403,25 +479,9 @@ class PipelineParallel(MetaParallelBase):
_COMMAND_MAP
=
{
utils
.
Optimize
:
_step
,
#utils.ReduceGrads: _allreduce_grads,
utils
.
Forward
:
_forward
,
utils
.
Backward
:
_backward
,
}
def
_pre_forward
(
self
,
*
inputs
,
**
kwargs
):
pass
def
forward
(
self
,
*
inputs
,
**
kwargs
):
raise
RuntimeError
(
"Call train_batch for pipeline instead of forward."
)
def
_post_forward
(
self
,
output
):
pass
def
_pre_backward
(
self
,
loss
):
pass
def
backward_impl
(
self
,
loss
,
parameters
):
pass
def
_post_backward
(
self
,
loss
):
pass
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
浏览文件 @
6b86e966
...
...
@@ -16,7 +16,21 @@ import abc
import
paddle
from
...utils
import
hybrid_parallel_util
as
hp_util
__all__
=
[
'get_tensor_bytes'
,
]
__all__
=
[
'get_tensor_bytes'
,
'is_float_tensor'
,
]
FLOAT_TYPES
=
[
paddle
.
float16
,
paddle
.
float32
,
paddle
.
float64
,
]
def
is_float_tensor
(
tensor
):
"""Is a float tensor"""
return
tensor
.
dtype
in
FLOAT_TYPES
def
get_tensor_bytes
(
tensor
):
...
...
@@ -48,10 +62,6 @@ class Generator():
self
.
stage_id
=
stage_id
self
.
prev_stage
=
self
.
stage_id
-
1
self
.
next_stage
=
self
.
stage_id
+
1
assert
self
.
micro_batches
>=
self
.
stages
,
(
"micro_batches {} "
"must be greater than or equal to {}"
.
format
(
self
.
micro_batches
,
self
.
stages
))
@
abc
.
abstractmethod
def
generate
(
self
):
...
...
@@ -73,18 +83,25 @@ class TrainGenerator(Generator):
cmds
=
[]
forward_steps
=
0
backward_steps
=
0
while
(
forward_steps
<
startup_steps
):
cmds
.
append
(
Forward
)
forward_steps
+=
1
#while (forward_steps < startup_steps):
# cmds.append(Forward(cache_id=forward_steps))
# forward_steps += 1
#while (forward_steps < self.micro_batches):
# cmds.append(Forward(cache_id=forward_steps))
# forward_steps += 1
# cmds.append(Backward(cache_id=backward_steps))
# backward_steps += 1
#while (backward_steps < self.micro_batches):
# cmds.append(Backward(cache_id=backward_steps))
# backward_steps += 1
#cmds.append(Optimize())
while
(
forward_steps
<
self
.
micro_batches
):
cmds
.
append
(
Forward
)
cmds
.
append
(
Forward
(
cache_id
=
forward_steps
)
)
forward_steps
+=
1
cmds
.
append
(
Backward
)
backward_steps
+=
1
while
(
backward_steps
<
self
.
micro_batches
):
cmds
.
append
(
Backward
)
cmds
.
append
(
Backward
(
cache_id
=
backward_steps
)
)
backward_steps
+=
1
cmds
.
append
(
Optimize
)
cmds
.
append
(
Optimize
()
)
yield
cmds
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录