未验证 提交 ec6519e8 编写于 作者: G gongweibao 提交者: GitHub

Fix allreducedep bug (#16443)

上级 7c5319ba
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -52,13 +53,28 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
// Note that must assert topology sort is stable
auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
for (auto* op_desc : ops) {
auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) {
for (auto& v : o_it.second) { // values
vars[v] = order;
try {
bool is_bk_op =
static_cast<bool>(boost::get<int>(op_desc->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
auto backward_vars =
boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) {
for (auto& v : o_it.second) { // values
vars[v] = order;
VLOG(1) << "in all_reduce_deps_pass:" << v;
}
}
order++;
} catch (boost::bad_get e) {
}
order++;
}
std::vector<OpHandleBase*> dist_ops;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册