PaddlePaddle / Paddle

Commit 5a9214d8 (unverified)
Authored by ShenLiang on Aug 31, 2023; committed via GitHub on Aug 31, 2023.
[Distributed]Fix cache p2p in pp (#56796)

* add usecache
* add p2p cache fix
* add cache
Parent: dfcfc8b7

Showing 2 changed files with 248 additions and 198 deletions:

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py  (+47, -20)
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py  (+201, -178)
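Taken together, the two files move the p2p shape/dtype cache out of module-level globals (_send_recv_meta, _use_cache) and into a per-trainer helper object. A minimal before/after sketch of a typical call site, using only names that appear in the hunks below (excerpt for illustration, not a runnable script):

    # before: module-level helper backed by the global _send_recv_meta / _use_cache
    input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())

    # after: the PipelineParallel instance owns a P2pHelper that carries that state
    input_tensor = self._p2p_helper.recv_forward(self.is_pipeline_first_stage())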
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -199,11 +199,13 @@ class PipelineParallel(MetaParallelBase):
         p2p.initialize_p2p_groups(
             hcg,
-            self._using_cache,
             self._enable_partial_send_recv,
             self._enable_timer,
         )
 
+        # construct pipeline meta info
+        self._p2p_helper = p2p.P2pHelper(self._using_cache)
+
         self.global_rank = self._hcg.get_global_rank()
         self.micro_batch_id = 0
@@ -349,10 +351,14 @@ class PipelineParallel(MetaParallelBase):
         micro_dataset = self._wrap_data(data)
 
         for step_id in range(startup_steps):
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage()
+            )
 
             output_tensor = self._forward_step(input_tensor, micro_dataset)
-            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
+            self._p2p_helper.send_forward(
+                output_tensor, self.is_pipeline_last_stage()
+            )
 
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
@@ -361,14 +367,16 @@ class PipelineParallel(MetaParallelBase):
             self._release_output(output_tensor)
 
         if steady_steps > 0:
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage()
+            )
 
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)
 
             output_tensor = self._forward_step(input_tensor, micro_dataset)
 
-            output_tensor_grad = p2p.send_forward_recv_backward(
+            output_tensor_grad = self._p2p_helper.send_forward_recv_backward(
                 output_tensor, self.is_pipeline_last_stage()
             )
@@ -388,11 +396,11 @@ class PipelineParallel(MetaParallelBase):
 
             if last_iter:
                 input_tensor = None
-                p2p.send_backward(
+                self._p2p_helper.send_backward(
                     input_tensor_grad, self.is_pipeline_first_stage()
                 )
             else:
-                input_tensor = p2p.send_backward_recv_forward(
+                input_tensor = self._p2p_helper.send_backward_recv_forward(
                     input_tensor_grad, self.is_pipeline_first_stage()
                 )
@@ -400,14 +408,16 @@ class PipelineParallel(MetaParallelBase):
             input_tensor = input_buffers.pop(0)
             output_tensor = output_buffers.pop(0)
 
-            output_tensor_grad = p2p.recv_backward(
+            output_tensor_grad = self._p2p_helper.recv_backward(
                 self.is_pipeline_last_stage()
             )
 
             input_tensor_grad = self._backward_step(
                 input_tensor, output_tensor, output_tensor_grad
             )
-            p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())
+            self._p2p_helper.send_backward(
+                input_tensor_grad, self.is_pipeline_first_stage()
+            )
 
         if self._comm_overlap:
             assert (
@@ -513,28 +523,38 @@ class PipelineParallel(MetaParallelBase):
         micro_dataset = self._wrap_data(data)
 
         for step_id in range(startup_steps):
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage()
+            )
 
             output_tensor = self._forward_step(input_tensor, micro_dataset)
-            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
+            self._p2p_helper.send_forward(
+                output_tensor, self.is_pipeline_last_stage()
+            )
 
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
 
         if steady_steps > 0:
-            input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+            input_tensor = self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage()
+            )
 
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)
 
             output_tensor = self._forward_step(input_tensor, micro_dataset)
-            p2p.send_forward(output_tensor, self.is_pipeline_last_stage())
+            self._p2p_helper.send_forward(
+                output_tensor, self.is_pipeline_last_stage()
+            )
 
             input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
 
             if not last_iter:
-                input_tensor = p2p.recv_forward(self.is_pipeline_first_stage())
+                input_tensor = self._p2p_helper.recv_forward(
+                    self.is_pipeline_first_stage()
+                )
 
         if self._compute_loss:
             self.train_loss = self._broadcast_final_loss()
@@ -859,6 +879,11 @@ class PipelineParallelWithInterleave(PipelineParallel):
             not forward_only
         ), "compute_loss can only be set to False when forward_only is set to True"
 
+        # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
+        assert (
+            self._using_cache
+        ), "cache should be enabled for pipeline with interleave"
+
         # init some attributes for this batch run
         self.scaler = scaler
         self.total_loss = None
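The new assertion ties the interleaved (ring_exchange) schedule to the p2p cache. What the cache flag actually controls is the latch in P2pHelper._send_meta / _recv_meta (second file below): with use_cache enabled, the shape/dtype handshake runs once and is skipped for later micro-batches. A self-contained sketch of that latch pattern with a stand-in meta object (hypothetical names, illustration only, not the Paddle module itself):

    class _MetaStub:
        """Stand-in for SendRecvMeta: only the latch flag and a send counter."""

        def __init__(self):
            self.has_send_meta = False
            self.sends = 0

        def send_meta(self):
            self.sends += 1


    def send_step(meta, use_cache):
        # Mirrors P2pHelper._send_meta: handshake only while the latch is open;
        # with use_cache=True the flag sticks after the first micro-batch.
        if not meta.has_send_meta:
            meta.send_meta()
            meta.has_send_meta = use_cache


    meta = _MetaStub()
    for _ in range(4):
        send_step(meta, use_cache=True)
    assert meta.sends == 1  # handshake ran once; the cache short-circuits the rest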
@@ -904,7 +929,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self.set_virtual_pipeline_rank(0)
         self.input_tensors[0].append(
-            p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False)
+            self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage(), sync_recv=False
+            )
         )
 
         # run startup steps
@@ -942,7 +969,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                 (
                     input_tensor,
                     output_tensor_grad,
-                ) = p2p.send_forward_backward_recv_forward_backward(
+                ) = self._p2p_helper.send_forward_backward_recv_forward_backward(
                     output_tensor,
                     input_tensor_grad,
                     recv_prev=recv_prev,
@@ -952,7 +979,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                     output_tensor_grad
                 )
             else:
-                input_tensor = p2p.send_forward_recv_forward(
+                input_tensor = self._p2p_helper.send_forward_recv_forward(
                     output_tensor, recv_prev=recv_prev
                 )
             self.input_tensors[next_virtual_pp_rank].append(input_tensor)
@@ -1033,7 +1060,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
             (
                 input_tensor,
                 output_tensor_grad,
-            ) = p2p.send_forward_backward_recv_forward_backward(
+            ) = self._p2p_helper.send_forward_backward_recv_forward_backward(
                 output_tensor,
                 input_tensor_grad,
                 recv_prev=recv_prev,
@@ -1057,7 +1084,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
         if not forward_only:
             if all_startup_steps:
                 self.output_tensor_grads[self.num_model_chunks - 1].append(
-                    p2p.recv_backward(
+                    self._p2p_helper.recv_backward(
                         self.is_pipeline_last_stage(), sync_recv=False
                     )
                 )
@@ -1080,7 +1107,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                     recv_next = False
                 self.output_tensor_grads[next_backward_virtual_pp_rank].append(
-                    p2p.send_backward_recv_backward(
+                    self._p2p_helper.send_backward_recv_backward(
                         input_tensor_grad, recv_next=recv_next
                     )
                 )
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py
@@ -31,17 +31,16 @@ from ...utils import timer_helper as timer
 from .utils import number_2_dtype, paddle_2_number
 
 _hcg = None
-_use_cache = False
+# _use_cache = False
 _enable_partial_send_recv = True
 _timers = None
 
 
 def initialize_p2p_groups(
-    hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False
+    hcg, enable_partial_send_recv=True, enable_timer=False
 ):
-    global _hcg, _use_cache, _enable_partial_send_recv, _timers
+    global _hcg, _enable_partial_send_recv, _timers
     _hcg = hcg
-    _use_cache = use_cache
     _enable_partial_send_recv = enable_partial_send_recv
     if enable_timer:
         _timers = timer.get_timers()
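With use_cache removed from initialize_p2p_groups, the flag now reaches the communication layer through the helper object instead of a module global. A minimal sketch of the new wiring, mirroring the pipeline_parallel.py hunk near the top of this commit (hcg stands for the hybrid communication group that fleet builds; illustration only):

    p2p.initialize_p2p_groups(
        hcg,
        enable_partial_send_recv=True,
        enable_timer=False,
    )
    p2p_helper = p2p.P2pHelper(use_cache=True)  # the cache flag lives on the helper now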
@@ -170,8 +169,14 @@ class SendRecvMeta:
             ]
         )
 
-
-_send_recv_meta = SendRecvMeta()
+    def __repr__(self):
+        return "send_shape_message: {}, send_dtype_message: {}, recv_shape_message: {}, recv_dtype_message: {}, recv_stop_gradient: {}".format(
+            self.send_shape_message,
+            self.send_dtype_message,
+            self.recv_shape_message,
+            self.recv_dtype_message,
+            self.recv_stop_gradient,
+        )
 
 
 def _is_valid_send_recv_partial(tensor, mp_degree):
@@ -303,7 +308,12 @@ def _process_p2p_tuple_or_tensor(
 
 
 def _p2p_helper(
-    tensor_send_next, tensor_send_prev, recv_prev, recv_next, sync_recv=True
+    tensor_send_next,
+    tensor_send_prev,
+    recv_prev,
+    recv_next,
+    sync_recv=True,
+    send_recv_meta=None,
 ):
     global _hcg
@@ -311,12 +321,13 @@ def _p2p_helper(
     tensor_recv_next = None
 
     # send / recv message
-    recv_shape_msg = _send_recv_meta.recv_shape_message
-    recv_dtype_msg = _send_recv_meta.recv_dtype_message
-    recv_stop_gradient = _send_recv_meta.recv_stop_gradient
+    assert send_recv_meta is not None, "send_recv_meta should not be None"
+    recv_shape_msg = send_recv_meta.recv_shape_message
+    recv_dtype_msg = send_recv_meta.recv_dtype_message
+    recv_stop_gradient = send_recv_meta.recv_stop_gradient
 
-    send_shape_msg = _send_recv_meta.send_shape_message
-    send_dtype_msg = _send_recv_meta.send_dtype_message
+    send_shape_msg = send_recv_meta.send_shape_message
+    send_dtype_msg = send_recv_meta.send_dtype_message
 
     # model parallel message
     mp_group = _hcg.get_model_parallel_group()
@@ -441,183 +452,195 @@ def _p2p_helper(
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(pp_first_stage, sync_recv=True):
-    global _timers
-    if _timers is not None:
-        _timers("recv_forward").start()
-    if pp_first_stage:
-        input_tensor = None
-    else:
-        if not _send_recv_meta.has_recv_meta:
-            _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
-            _send_recv_meta.has_recv_meta = _use_cache
-
-        input_tensor, _ = _p2p_helper(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_prev=True,
-            recv_next=False,
-            sync_recv=sync_recv,
-        )
-    if _timers is not None:
-        _timers("recv_forward").stop()
-    return input_tensor
-
-
-def recv_backward(pp_last_stage, sync_recv=True):
-    global _timers
-    if _timers is not None:
-        _timers("recv_backward").start()
-    if pp_last_stage:
-        output_tensor_grad = None
-    else:
-        _, output_tensor_grad = _p2p_helper(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_prev=False,
-            recv_next=True,
-            sync_recv=sync_recv,
-        )
-    if _timers is not None:
-        _timers("recv_backward").stop()
-    return output_tensor_grad
-
-
-def send_forward(output_tensor, pp_last_stage):
-    global _timers
-    if _timers is not None:
-        _timers("send_forward").start()
-    if not pp_last_stage:
-        if not _send_recv_meta.has_send_meta:
-            _send_recv_meta.set_send_message(output_tensor)
-            _send_recv_meta.send_meta(
-                output_tensor, _hcg.get_pipe_parallel_group()
-            )
-            _send_recv_meta.has_send_meta = _use_cache
-
-        _p2p_helper(
-            tensor_send_next=output_tensor,
-            tensor_send_prev=None,
-            recv_prev=False,
-            recv_next=False,
-        )
-    if _timers is not None:
-        _timers("send_forward").stop()
-
-
-def send_backward(input_tensor_grad, pp_first_stage):
-    global _timers
-    if _timers is not None:
-        _timers("send_backward").start()
-    if not pp_first_stage:
-        _p2p_helper(
-            tensor_send_next=None,
-            tensor_send_prev=input_tensor_grad,
-            recv_prev=False,
-            recv_next=False,
-        )
-    if _timers is not None:
-        _timers("send_backward").stop()
-
-
-def send_forward_recv_backward(output_tensor, pp_last_stage):
-    global _timers
-    if _timers is not None:
-        _timers("send_forward_recv_backward").start()
-    if pp_last_stage:
-        output_tensor_grad = None
-    else:
-        _, output_tensor_grad = _p2p_helper(
-            tensor_send_next=output_tensor,
-            tensor_send_prev=None,
-            recv_prev=False,
-            recv_next=True,
-        )
-    if _timers is not None:
-        _timers("send_forward_recv_backward").stop()
-    return output_tensor_grad
-
-
-def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
-    global _timers
-    if _timers is not None:
-        _timers("send_backward_recv_forward").start()
-    if pp_first_stage:
-        input_tensor = None
-    else:
-        input_tensor, _ = _p2p_helper(
-            tensor_send_next=None,
-            tensor_send_prev=input_tensor_grad,
-            recv_prev=True,
-            recv_next=False,
-        )
-    if _timers is not None:
-        _timers("send_backward_recv_forward").stop()
-    return input_tensor
-
-
-def send_forward_backward_recv_forward_backward(
-    output_tensor, input_tensor_grad, recv_prev, recv_next
-):
-    # always have to send dytpe info to downstream
-    global _timers
-    if _timers is not None:
-        _timers("send_forward_backward_recv_forward_backward").start()
-    if not _send_recv_meta.has_send_meta:
-        _send_recv_meta.set_send_message(output_tensor)
-        _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group())
-        _send_recv_meta.has_send_meta = _use_cache
-    if recv_prev and not _send_recv_meta.has_recv_meta:
-        _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
-        _send_recv_meta.has_recv_meta = _use_cache
-    input_tensor, output_tensor_grad = _p2p_helper(
-        tensor_send_next=output_tensor,
-        tensor_send_prev=input_tensor_grad,
-        recv_prev=recv_prev,
-        recv_next=recv_next,
-        sync_recv=False,
-    )
-    if _timers is not None:
-        _timers("send_forward_backward_recv_forward_backward").stop()
-    return input_tensor, output_tensor_grad
-
-
-def send_forward_recv_forward(output_tensor, recv_prev):
-    # always have to send dytpe info to downstream
-    global _timers
-    if _timers is not None:
-        _timers("send_forward_recv_forward").start()
-    if not _send_recv_meta.has_send_meta:
-        _send_recv_meta.set_send_message(output_tensor)
-        _send_recv_meta.send_meta(output_tensor, _hcg.get_pipe_parallel_group())
-        _send_recv_meta.has_send_meta = _use_cache
-    if recv_prev and not _send_recv_meta.has_recv_meta:
-        _send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
-        _send_recv_meta.has_recv_meta = _use_cache
-    input_tensor, _ = _p2p_helper(
-        tensor_send_next=output_tensor,
-        tensor_send_prev=None,
-        recv_prev=recv_prev,
-        recv_next=False,
-        sync_recv=False,
-    )
-    if _timers is not None:
-        _timers("send_forward_recv_forward").stop()
-    return input_tensor
-
-
-def send_backward_recv_backward(input_tensor_grad, recv_next):
-    global _timers
-    if _timers is not None:
-        _timers("send_backward_recv_backward").start()
-    _, output_tensor_grad = _p2p_helper(
-        tensor_send_next=None,
-        tensor_send_prev=input_tensor_grad,
-        recv_prev=False,
-        recv_next=recv_next,
-        sync_recv=False,
-    )
-    if _timers is not None:
-        _timers("send_backward_recv_backward").stop()
-    return output_tensor_grad
+class P2pHelper:
+    def __init__(self, use_cache=True):
+        self._send_recv_meta = SendRecvMeta()
+        self._use_cache = use_cache
+
+    def _send_meta(self, output_tensor):
+        if not self._send_recv_meta.has_send_meta:
+            self._send_recv_meta.set_send_message(output_tensor)
+            self._send_recv_meta.send_meta(
+                output_tensor, _hcg.get_pipe_parallel_group()
+            )
+            self._send_recv_meta.has_send_meta = self._use_cache
+
+    def _recv_meta(self):
+        if not self._send_recv_meta.has_recv_meta:
+            self._send_recv_meta.recv_meta(_hcg.get_pipe_parallel_group())
+            self._send_recv_meta.has_recv_meta = self._use_cache
+
+    def recv_forward(self, pp_first_stage, sync_recv=True):
+        global _timers
+        if _timers is not None:
+            _timers("recv_forward").start()
+        if pp_first_stage:
+            input_tensor = None
+        else:
+            self._recv_meta()
+
+            input_tensor, _ = _p2p_helper(
+                tensor_send_next=None,
+                tensor_send_prev=None,
+                recv_prev=True,
+                recv_next=False,
+                sync_recv=sync_recv,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("recv_forward").stop()
+        return input_tensor
+
+    def recv_backward(self, pp_last_stage, sync_recv=True):
+        global _timers
+        if _timers is not None:
+            _timers("recv_backward").start()
+        if pp_last_stage:
+            output_tensor_grad = None
+        else:
+            _, output_tensor_grad = _p2p_helper(
+                tensor_send_next=None,
+                tensor_send_prev=None,
+                recv_prev=False,
+                recv_next=True,
+                sync_recv=sync_recv,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("recv_backward").stop()
+        return output_tensor_grad
+
+    def send_forward(self, output_tensor, pp_last_stage):
+        global _timers
+        if _timers is not None:
+            _timers("send_forward").start()
+        if not pp_last_stage:
+            self._send_meta(output_tensor)
+
+            _p2p_helper(
+                tensor_send_next=output_tensor,
+                tensor_send_prev=None,
+                recv_prev=False,
+                recv_next=False,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("send_forward").stop()
+
+    def send_backward(self, input_tensor_grad, pp_first_stage):
+        global _timers
+        if _timers is not None:
+            _timers("send_backward").start()
+        if not pp_first_stage:
+            _p2p_helper(
+                tensor_send_next=None,
+                tensor_send_prev=input_tensor_grad,
+                recv_prev=False,
+                recv_next=False,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("send_backward").stop()
+
+    def send_forward_recv_backward(self, output_tensor, pp_last_stage):
+        global _timers
+        if _timers is not None:
+            _timers("send_forward_recv_backward").start()
+        if pp_last_stage:
+            output_tensor_grad = None
+        else:
+            _, output_tensor_grad = _p2p_helper(
+                tensor_send_next=output_tensor,
+                tensor_send_prev=None,
+                recv_prev=False,
+                recv_next=True,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("send_forward_recv_backward").stop()
+        return output_tensor_grad
+
+    def send_backward_recv_forward(self, input_tensor_grad, pp_first_stage):
+        global _timers
+        if _timers is not None:
+            _timers("send_backward_recv_forward").start()
+        if pp_first_stage:
+            input_tensor = None
+        else:
+            input_tensor, _ = _p2p_helper(
+                tensor_send_next=None,
+                tensor_send_prev=input_tensor_grad,
+                recv_prev=True,
+                recv_next=False,
+                send_recv_meta=self._send_recv_meta,
+            )
+        if _timers is not None:
+            _timers("send_backward_recv_forward").stop()
+        return input_tensor
+
+    def send_forward_backward_recv_forward_backward(
+        self, output_tensor, input_tensor_grad, recv_prev, recv_next
+    ):
+        # always have to send dytpe info to downstream
+        global _timers
+        if _timers is not None:
+            _timers("send_forward_backward_recv_forward_backward").start()
+        self._send_meta(output_tensor)
+        if recv_prev:
+            self._recv_meta()
+
+        input_tensor, output_tensor_grad = _p2p_helper(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=input_tensor_grad,
+            recv_prev=recv_prev,
+            recv_next=recv_next,
+            sync_recv=False,
+            send_recv_meta=self._send_recv_meta,
+        )
+        if _timers is not None:
+            _timers("send_forward_backward_recv_forward_backward").stop()
+        return input_tensor, output_tensor_grad
+
+    def send_forward_recv_forward(self, output_tensor, recv_prev):
+        # always have to send dytpe info to downstream
+        global _timers
+        if _timers is not None:
+            _timers("send_forward_recv_forward").start()
+
+        self._send_meta(output_tensor)
+        if recv_prev:
+            self._recv_meta()
+
+        input_tensor, _ = _p2p_helper(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=None,
+            recv_prev=recv_prev,
+            recv_next=False,
+            sync_recv=False,
+            send_recv_meta=self._send_recv_meta,
+        )
+        if _timers is not None:
+            _timers("send_forward_recv_forward").stop()
+        return input_tensor
+
+    def send_backward_recv_backward(self, input_tensor_grad, recv_next):
+        global _timers
+        if _timers is not None:
+            _timers("send_backward_recv_backward").start()
+
+        _, output_tensor_grad = _p2p_helper(
+            tensor_send_next=None,
+            tensor_send_prev=input_tensor_grad,
+            recv_prev=False,
+            recv_next=recv_next,
+            sync_recv=False,
+            send_recv_meta=self._send_recv_meta,
+        )
+        if _timers is not None:
+            _timers("send_backward_recv_backward").stop()
+        return output_tensor_grad
+
+    def __repr__(self):
+        debug_str = f"using cache: {self._use_cache} \n"
+        debug_str += repr(self._send_recv_meta)
+        return debug_str
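A short usage sketch of the class added above, shaped like the warm-up loop in pipeline_parallel.py (is_first_stage, is_last_stage, startup_steps and forward_step are placeholders for the trainer's own state; illustration only, not a runnable script):

    helper = P2pHelper(use_cache=True)

    for _ in range(startup_steps):
        # Returns None on the first stage, otherwise receives from the previous stage.
        input_tensor = helper.recv_forward(is_first_stage)
        output_tensor = forward_step(input_tensor)
        # No-op on the last stage, otherwise sends to the next stage.
        helper.send_forward(output_tensor, is_last_stage)

    print(helper)  # __repr__ reports the use_cache flag plus the cached send/recv meta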