未验证 提交 ec6519e8 编写于 作者: G gongweibao 提交者: GitHub

Fix allreducedep bug (#16443)

上级 7c5319ba
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <algorithm> #include <algorithm>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -52,13 +53,28 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl( ...@@ -52,13 +53,28 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
// Note that must assert topology sort is stable // Note that must assert topology sort is stable
auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs); auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
for (auto* op_desc : ops) { for (auto* op_desc : ops) {
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(op_desc->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
auto backward_vars =
boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
auto outputs = op_desc->Outputs(); auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) { for (auto& o_it : outputs) {
for (auto& v : o_it.second) { // values for (auto& v : o_it.second) { // values
vars[v] = order; vars[v] = order;
VLOG(1) << "in all_reduce_deps_pass:" << v;
} }
} }
order++; order++;
} catch (boost::bad_get e) {
}
} }
std::vector<OpHandleBase*> dist_ops; std::vector<OpHandleBase*> dist_ops;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册