未验证 提交 ba1bbe8e 编写于 作者: R Ruibiao Chen 提交者: GitHub

Fix op_happen_before_ update bug for AddDownstreamOp (#46486)

上级 38e82868
...@@ -70,15 +70,15 @@ const std::map<int, std::set<int>>& DependencyBuilder::Build( ...@@ -70,15 +70,15 @@ const std::map<int, std::set<int>>& DependencyBuilder::Build(
BuildOpHappensBefore(); BuildOpHappensBefore();
ShrinkDownstreamMap(); ShrinkDownstreamMap();
if (is_sequential_run) {
AddDependencyForSequentialRun();
}
AddDependencyForCoalesceTensorOp(); AddDependencyForCoalesceTensorOp();
AddDependencyForCommunicationOp(); AddDependencyForCommunicationOp();
AddDependencyForRandomOp(); AddDependencyForRandomOp();
AddDependencyForReadOp(); AddDependencyForReadOp();
if (is_sequential_run) {
AddDependencyForSequentialRun();
}
is_build_ = true; is_build_ = true;
VLOG(8) << "Finish build dependency"; VLOG(8) << "Finish build dependency";
...@@ -335,6 +335,10 @@ void DependencyBuilder::AddDownstreamOp(int prior_op_idx, ...@@ -335,6 +335,10 @@ void DependencyBuilder::AddDownstreamOp(int prior_op_idx,
if (op_happens_before_.size() != 0) { if (op_happens_before_.size() != 0) {
for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) { for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) {
if (op_happens_before_[op_idx][prior_op_idx]) {
op_happens_before_[op_idx][posterior_op_idx] = true;
}
if (op_happens_before_[posterior_op_idx][op_idx]) { if (op_happens_before_[posterior_op_idx][op_idx]) {
op_happens_before_[prior_op_idx][op_idx] = true; op_happens_before_[prior_op_idx][op_idx] = true;
} }
...@@ -461,10 +465,6 @@ void DependencyBuilder::BuildDownstreamMap() { ...@@ -461,10 +465,6 @@ void DependencyBuilder::BuildDownstreamMap() {
AddDownstreamOp(dep_op, op); AddDownstreamOp(dep_op, op);
} }
} }
VLOG(6) << "downstream count: " << CountDownstreamMap(op_downstream_map_);
VLOG(6) << "downstream_map: " << std::endl
<< StringizeDownstreamMap(op_downstream_map_);
} }
void DependencyBuilder::BuildOpHappensBefore() { void DependencyBuilder::BuildOpHappensBefore() {
...@@ -542,8 +542,9 @@ void DependencyBuilder::ShrinkDownstreamMap() { ...@@ -542,8 +542,9 @@ void DependencyBuilder::ShrinkDownstreamMap() {
} }
op_downstream_map_.at(i) = minumum_nexts; op_downstream_map_.at(i) = minumum_nexts;
} }
VLOG(6) << "downstream count: " << CountDownstreamMap(op_downstream_map_); VLOG(8) << "Finish shrink downstream map";
VLOG(6) << "downstream_map: " << std::endl VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map_);
VLOG(8) << "downstream_map: " << std::endl
<< StringizeDownstreamMap(op_downstream_map_); << StringizeDownstreamMap(op_downstream_map_);
} }
......
...@@ -268,7 +268,7 @@ void create_all_ops(const framework::BlockDesc& block, ...@@ -268,7 +268,7 @@ void create_all_ops(const framework::BlockDesc& block,
std::vector<std::unique_ptr<OperatorBase>>* ops) { std::vector<std::unique_ptr<OperatorBase>>* ops) {
for (auto& op : block.AllOps()) { for (auto& op : block.AllOps()) {
auto op_type = op->Type(); auto op_type = op->Type();
VLOG(1) << "CreateOp from : " << op_type; VLOG(8) << "CreateOp from : " << op_type;
auto& info = OpInfoMap::Instance().Get(op_type); auto& info = OpInfoMap::Instance().Get(op_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册