未验证 提交 8126a41d 编写于 作者: L lilong12 提交者: GitHub

Fix the bug where pipeline gradients are all-reduced multiple times (#30437)

* update, test=develop
上级 621bc4f7
......@@ -233,6 +233,7 @@ class PipelineOptimizer(MetaOptimizerBase):
block = self.main_program_list[ring_id - 1]['program'].global_block()
origin_block = self.main_program.global_block()
grad = None
processed_param_name = set()
for idx, op in reversed(list(enumerate(block.ops))):
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
......@@ -242,7 +243,10 @@ class PipelineOptimizer(MetaOptimizerBase):
assert len(op_role_var) % 2 == 0
offset = idx
for i in range(0, len(op_role_var), 2):
param_name = op_role_var[i]
param = block.vars[op_role_var[i]]
if param_name in processed_param_name: continue
processed_param_name.add(param_name)
grad = block.vars[op_role_var[i + 1]]
origin_param = origin_block.vars[op_role_var[i]]
if origin_param.is_distributed:
......
......@@ -10,7 +10,7 @@ if(NOT WITH_NCCL)
endif()
string(REPLACE ".py" "" DIST_TEST_OPS "${DIST_TEST_OPS}")
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mnist)
#list(APPEND DIST_TEST_OPS test_pipeline)
list(APPEND DIST_TEST_OPS test_pipeline)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
......@@ -62,7 +62,6 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()
list(REMOVE_ITEM TEST_OPS test_pipeline)
if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op)
......@@ -826,9 +825,9 @@ if(WITH_GPU AND NOT WIN32)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120)
# if(WITH_DISTRIBUTE)
# set_tests_properties(test_pipeline PROPERTIES TIMEOUT 120)
# endif()
if(WITH_DISTRIBUTE)
set_tests_properties(test_pipeline PROPERTIES TIMEOUT 120)
endif()
set_tests_properties(test_reducescatter_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_broadcast PROPERTIES TIMEOUT 120)
set_tests_properties(test_reducescatter PROPERTIES TIMEOUT 120)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册