Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
cc007dce
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cc007dce
编写于
3年前
作者:
L
lilong12
提交者:
GitHub
3年前
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move the recv op the beginning of the forward/backward phase for pipeline (#34197)
* mv recv to head, test=develop
上级
9c7f6af5
develop
Ligoml-patch-1
ZHUI-patch-1
add_some_yaml_config
addfile
ascendrelease
cherry_undefined_var
delete_delete_addfile
delete_disable_iterable_dataset_unittest
delete_fix_retry_ci
delete_fix_undefined_var
delete_improve_sccache
delete_paralleltest
delete_prv-disable-more-cache
delete_revert-34910-spinlocks_for_allocator
delete_revert-35069-revert-34910-spinlocks_for_allocator
delete_revert-36057-dev/read_flags_in_ut
dingjiaweiww-patch-1
disable_iterable_dataset_unittest
dy2static
enable_eager_model_test
final_state_gen_python_c
final_state_intermediate
fix-numpy-issue
fix_concat_slice
fix_npu_ci
fix_op_flops
fix_retry_ci
fix_rnn_docs
fix_tensor_type
fix_undefined_var
fixiscan
fixiscan1
fixiscan2
fixiscan3
improve_sccache
incubate/infrt
inplace_addto
make_flag_adding_easier
move_embedding_to_phi
move_histogram_to_pten
move_sgd_to_phi
move_slice_to_pten
move_temporal_shift_to_phi
move_yolo_box_to_phi
npu_fix_alloc
paralleltest
preln_ernie
prv-disable-more-cache
prv-md-even-more
prv-onednn-2.5
pten_tensor_refactor
release/2.2
release/2.3
release/2.3-fc-ernie-fix
release/2.4
revert-33475-fix_cifar_label_dimension
revert-34406-add_copy_from_tensor
revert-34910-spinlocks_for_allocator
revert-35069-revert-34910-spinlocks_for_allocator
revert-36057-dev/read_flags_in_ut
revert-36201-refine_fast_threaded_ssa_graph_executor
revert-36985-add_license
revert-37318-refactor_dygraph_to_eager
revert-37926-eager_coreops_500
revert-37956-revert-37727-pylayer_support_tuple
revert-38100-mingdong
revert-38301-allocation_rearrange_pr
revert-38703-numpy_bf16_package_reupload
revert-38732-remove_useless_header_in_elementwise_mul_grad
revert-38959-Reduce_Grad
revert-39143-adjust_empty
revert-39227-move_trace_op_to_pten
revert-39268-dev/remove_concat_fluid_kernel
revert-40170-support_partial_grad
revert-41056-revert-40727-move_some_activaion_to_phi
revert-41065-revert-40993-mv_ele_floordiv_pow
revert-41068-revert-40790-phi_new
revert-41944-smaller_inference_api_test
revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator
revert-43155-fix_ut_tempfile
revert-43882-revert-41944-smaller_inference_api_test
revert-45808-phi/simplify_size_op
revert-46827-deform_comment
support_weight_transpose
zhiqiu-patch-1
v2.4.0-rc0
v2.3.2
v2.3.1
v2.3.0
v2.3.0-rc0
v2.2.2
v2.2.1
v2.2.0
v2.2.0-rc0
v2.2.0-bak0
无相关合并请求
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
53 addition
and
0 deletion
+53
-0
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+52
-0
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+1
-0
未找到文件。
python/paddle/fluid/optimizer.py
浏览文件 @
cc007dce
...
...
@@ -5280,6 +5280,55 @@ class PipelineOptimizer(object):
attrs
=
{
self
.
_op_role_key
:
self
.
_op_role
.
Backward
})
block
.
_sync_with_cpp
()
def
_mv_head_recv
(
self
,
program
):
"""
A pass to move the recv op to the beginning of
the forward/backward phase
"""
forward_insert_index
=
0
backward_insert_index
=
None
block
=
program
.
global_block
()
num_ops
=
len
(
program
.
global_block
().
ops
)
for
i
in
range
(
num_ops
):
insert_index
=
None
op
=
program
.
global_block
().
ops
[
i
]
op_role
=
int
(
op
.
attr
(
self
.
_op_role_key
))
if
op_role
==
int
(
self
.
_op_role
.
Backward
)
and
backward_insert_index
is
None
:
backward_insert_index
=
i
if
op
.
type
!=
"partial_recv"
and
op
.
type
!=
"partial_allgather"
and
op
.
type
!=
"nop"
and
op
.
type
!=
"recv_v2"
:
continue
if
op_role
==
int
(
self
.
_op_role
.
Forward
):
if
i
==
forward_insert_index
:
forward_insert_index
+=
1
continue
insert_index
=
forward_insert_index
elif
op_role
==
int
(
self
.
_op_role
.
Backward
):
if
i
==
backward_insert_index
:
backward_insert_index
+=
1
continue
insert_index
=
backward_insert_index
else
:
raise
ValueError
(
"Unknown op_role: {}"
.
format
(
op_role
))
op_inputs
=
dict
()
for
name
in
op
.
input_names
:
op_inputs
[
name
]
=
op
.
input
(
name
)
op_outputs
=
dict
()
for
name
in
op
.
output_names
:
op_outputs
[
name
]
=
op
.
output
(
name
)
block
.
_insert_op_without_sync
(
index
=
insert_index
,
type
=
op
.
type
,
inputs
=
op_inputs
,
outputs
=
op_outputs
,
attrs
=
op
.
all_attrs
())
block
.
_remove_op
(
i
+
1
)
if
op_role
==
int
(
self
.
_op_role
.
Forward
):
forward_insert_index
+=
1
elif
op_role
==
int
(
self
.
_op_role
.
Backward
):
backward_insert_index
+=
1
block
.
_sync_with_cpp
()
def
minimize
(
self
,
loss
,
startup_program
=
None
,
...
...
@@ -5393,6 +5442,9 @@ class PipelineOptimizer(object):
place_id
=
int
(
os
.
getenv
(
"FLAGS_selected_gpus"
,
"0"
))
elif
core
.
is_compiled_with_npu
():
place_id
=
int
(
os
.
getenv
(
"FLAGS_selected_npus"
,
"0"
))
# A pass to move the recv op to the beginning of
# the forward/backward phase
self
.
_mv_head_recv
(
program_list
[
self
.
local_rank
])
main_program
.
_pipeline_opt
=
{
"trainer"
:
"PipelineTrainer"
,
"device_worker"
:
"Section"
,
...
...
This diff is collapsed.
Click to expand it.
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
cc007dce
...
...
@@ -144,6 +144,7 @@ class TestDistRunnerBase(object):
loss
=
loss
[
0
]
if
loss
else
None
out_losses
.
append
(
loss
)
print_to_err
(
type
(
self
).
__name__
,
"run step %d finished"
%
i
)
data_loader
.
reset
()
print_to_err
(
type
(
self
).
__name__
,
"trainer run finished"
)
sys
.
stdout
.
buffer
.
write
(
pickle
.
dumps
(
out_losses
))
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部