未验证 提交 c00303ec 编写于 作者: S sneaxiy 提交者: GitHub

fix test allreduce tests (#39166)

上级 7874d0a5
...@@ -603,9 +603,9 @@ static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies( ...@@ -603,9 +603,9 @@ static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
for (const auto *op_desc : block_ops) { for (const auto *op_desc : block_ops) {
size_t op_idx = op_id_to_idx.size(); size_t op_idx = op_id_to_idx.size();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_id_to_idx.emplace(op_desc->Id(), op_idx).second, true, op_id_to_idx.emplace(op_desc->OriginalId(), op_idx).second, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"There should not be duplicate op id: %d", op_desc->Id())); "There should not be duplicate op id: %d", op_desc->OriginalId()));
} }
std::vector<std::vector<ir::Node::Dep>> dep_matrix(op_num); std::vector<std::vector<ir::Node::Dep>> dep_matrix(op_num);
...@@ -624,9 +624,9 @@ static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies( ...@@ -624,9 +624,9 @@ static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
for (const auto &pair : all_preceding_ops) { for (const auto &pair : all_preceding_ops) {
const auto *cur_op_node = pair.first; const auto *cur_op_node = pair.first;
size_t op_idx_1 = get_op_idx_by_id(cur_op_node->Op()->Id()); size_t op_idx_1 = get_op_idx_by_id(cur_op_node->Op()->OriginalId());
for (const auto *preceding_op_node : pair.second) { for (const auto *preceding_op_node : pair.second) {
size_t op_idx_2 = get_op_idx_by_id(preceding_op_node->Op()->Id()); size_t op_idx_2 = get_op_idx_by_id(preceding_op_node->Op()->OriginalId());
dep_matrix[op_idx_1][op_idx_2] = ir::Node::Dep::kAfter; dep_matrix[op_idx_1][op_idx_2] = ir::Node::Dep::kAfter;
dep_matrix[op_idx_2][op_idx_1] = ir::Node::Dep::kBefore; dep_matrix[op_idx_2][op_idx_1] = ir::Node::Dep::kBefore;
} }
......
...@@ -4,6 +4,6 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") ...@@ -4,6 +4,6 @@ string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
list(APPEND DIST_TEST_OPS ${TEST_OP}) list(APPEND DIST_TEST_OPS ${TEST_OP})
set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 90) set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 120)
set_tests_properties(${TEST_OP} PROPERTIES LABELS "RUN_TYPE=DIST") set_tests_properties(${TEST_OP} PROPERTIES LABELS "RUN_TYPE=DIST")
endforeach(TEST_OP) endforeach(TEST_OP)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册