未验证 提交 eeec4b1b 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] shrink downstream map (#41471)

* shrink downstream map

* shrink last live ops of var

* add comment

* fix bug
上级 34f30f79
...@@ -192,7 +192,8 @@ void InterpreterCore::BuildOperatorDependences() { ...@@ -192,7 +192,8 @@ void InterpreterCore::BuildOperatorDependences() {
// Schedule // Schedule
auto op_nums = vec_instruction_.size(); auto op_nums = vec_instruction_.size();
dependecy_count_.resize(op_nums); dependecy_count_.resize(op_nums);
auto op2downstream = interpreter::build_op_downstream_map(vec_instruction_); auto op2downstream = interpreter::build_op_downstream_map(
vec_instruction_, &op_happens_before_);
for (size_t op = 0; op < vec_instruction_.size(); ++op) { for (size_t op = 0; op < vec_instruction_.size(); ++op) {
auto op_list = op2downstream[op]; auto op_list = op2downstream[op];
std::vector<size_t> downsteam_vector(op_list.begin(), op_list.end()); std::vector<size_t> downsteam_vector(op_list.begin(), op_list.end());
...@@ -213,18 +214,21 @@ void InterpreterCore::Convert( ...@@ -213,18 +214,21 @@ void InterpreterCore::Convert(
auto op_nums = nodes.size(); auto op_nums = nodes.size();
vec_instruction_.reserve(op_nums); vec_instruction_.reserve(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx]; auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
auto& instr = vec_instruction_.back(); }
BuildOperatorDependences();
// calculate last_live_ops_
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& instr = vec_instruction_[op_idx];
OpInOutInfo info; OpInOutInfo info;
std::vector<size_t> gc_check_input_list; std::set<size_t> gc_check_inputs;
for (auto& item : op_func_node.input_index) { for (auto& item : instr.Inputs()) {
for (auto id : item.second) { for (auto id : item.second) {
if (id == kEmptyVarIndex) { if (id == kEmptyVarIndex) {
continue; continue;
...@@ -232,38 +236,24 @@ void InterpreterCore::Convert( ...@@ -232,38 +236,24 @@ void InterpreterCore::Convert(
input_var2op_info_.at(id).push_back(op_idx); input_var2op_info_.at(id).push_back(op_idx);
// var can be gc-ed // var can be gc-ed
if (!info.IsBuilt()) { if (!info.IsBuilt()) {
info.Build(op_func_node.operator_base_.get()); info.Build(instr.OpBase());
} }
auto* var_desc = global_scope_->VarDesc(id); auto* var_desc = global_scope_->VarDesc(id);
if (var_desc) { if (var_desc) {
if (info.IsInArgBufferNeeded(var_desc->Name())) { if (info.IsInArgBufferNeeded(var_desc->Name())) {
gc_check_input_list.push_back(id); gc_check_inputs.insert(id);
} }
} else { } else {
gc_check_input_list.push_back(id); gc_check_inputs.insert(id);
} }
} }
} }
std::sort(gc_check_input_list.begin(), gc_check_input_list.end());
auto last =
std::unique(gc_check_input_list.begin(), gc_check_input_list.end());
gc_check_input_list.erase(last, gc_check_input_list.end());
for (auto var_id : gc_check_input_list) { for (auto var_id : gc_check_inputs) {
paddle::framework::Variable* var = global_scope_->Var(var_id); paddle::framework::Variable* var = global_scope_->Var(var_id);
if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() || if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>()) { var->IsType<LoDTensorArray>()) {
vec_meta_info[var_id].var_ref_count_++; last_live_ops_[var_id].insert(op_idx);
// TODO(zhiqiu): not all var needs to be checked, var need to be checked
// only
// after the last_live_op. For example,
// b = op1(a)
// c = op2(a, b)
// in this case, a is the input of op1 and op2, we only need to check
// a after op2, because op2 always uses a after op1.
instr.AddGCCheckVar(var_id);
VLOG(4) << "clear " << global_scope_->GetNameById(var_id) << " after "
<< instr.OpBase()->Type();
} else { } else {
VLOG(4) << "not clear " << global_scope_->GetNameById(var_id) VLOG(4) << "not clear " << global_scope_->GetNameById(var_id)
<< " after " << instr.OpBase()->Type() << " after " << instr.OpBase()->Type()
...@@ -276,19 +266,45 @@ void InterpreterCore::Convert( ...@@ -276,19 +266,45 @@ void InterpreterCore::Convert(
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
// checkout ouput // checkout ouput
for (auto& item : vec_instruction_[i].Outputs()) { for (auto& item : vec_instruction_[i].Outputs()) {
for (auto id : item.second) { for (auto var_id : item.second) {
if (input_var2op_info_.at(id).size() == 0) { if (input_var2op_info_.at(var_id).size() == 0) {
// output var not be used by any kernel last_live_ops_[var_id].insert(i);
vec_instruction_[i].AddGCCheckVar(id);
VLOG(4) << "clear " << global_scope_->GetNameById(id) << " after "
<< vec_instruction_[i].OpBase()->Type();
vec_meta_info[id].var_ref_count_++;
} }
} }
} }
} }
BuildOperatorDependences(); // shrink: for each var, keep only the last-live ops that do not
// happen before any other last-live op of the same var.
// For example,
// b = op1(a)
// c = op2(a, b)
// in this case, a is the input of op1 and op2, we only need to check
// a after op2, because op2 always uses a after op1.
for (size_t i = 0; i < last_live_ops_.size(); ++i) {
std::set<size_t> minumum_last_live_ops;
for (size_t item : last_live_ops_[i]) {
bool not_before_any = true;
// keep the op only if it does not happen before any other op in the set
for (size_t other_item : last_live_ops_[i]) {
if (op_happens_before_[item][other_item]) {
VLOG(8) << "happens_before: " << item << "->" << other_item
<< ", so skip " << item;
not_before_any = false;
break;
}
}
if (not_before_any) {
VLOG(8) << "last live op of var " << i << " "
<< global_scope_->GetNameById(i) << " : " << item << " "
<< vec_instruction_[item].OpBase()->Type();
minumum_last_live_ops.insert(item);
vec_instruction_[item].AddGCCheckVar(i);
}
}
last_live_ops_[i] = minumum_last_live_ops;
vec_meta_info[i].var_ref_count_ = last_live_ops_[i].size();
}
for (size_t i = 0; i < vec_instruction_.size(); ++i) { for (size_t i = 0; i < vec_instruction_.size(); ++i) {
BuildAndCacheInstructionCtx(&vec_instruction_[i]); BuildAndCacheInstructionCtx(&vec_instruction_[i]);
......
...@@ -109,6 +109,11 @@ class InterpreterCore { ...@@ -109,6 +109,11 @@ class InterpreterCore {
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
// op_happens_before_[i][j] == true means op[i] happens before op[j]
std::vector<std::vector<bool>> op_happens_before_;
// last_live_ops_[i] contains the ids of the operators that last access var[i]
std::map<size_t, std::set<size_t>> last_live_ops_;
std::vector<size_t> dependecy_count_; std::vector<size_t> dependecy_count_;
std::atomic<size_t> unfinished_op_numer_{0}; std::atomic<size_t> unfinished_op_numer_{0};
std::vector<std::vector<size_t>> input_var2op_info_; std::vector<std::vector<size_t>> input_var2op_info_;
......
...@@ -614,23 +614,125 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences, ...@@ -614,23 +614,125 @@ void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
} }
std::map<int, std::list<int>> get_downstream_map( std::map<int, std::list<int>> get_downstream_map(
const std::map<int, std::set<int>>& op2dependences) { const std::map<int, std::set<int>>& op2dependences,
// op2dependences is op -> it's dependences. we want to get op -> [ops] map, std::vector<std::vector<bool>>* op_happens_before) {
// step1: convert op2dependences to downstream_map directly
// op2dependences is op -> its dependencies.
// we want to get op -> [next ops] map,
// where ops is the next instruction of op. // where ops is the next instruction of op.
std::map<int, std::list<int>> result; std::map<int, std::list<int>> downstream;
for (auto& item : op2dependences) { for (auto& item : op2dependences) {
int op = item.first; int op = item.first;
for (auto dep_op : item.second) { for (auto dep_op : item.second) {
if (result.find(dep_op) == result.end()) if (downstream.find(dep_op) == downstream.end())
result[dep_op] = std::list<int>(); downstream[dep_op] = std::list<int>();
result[dep_op].push_back(op); downstream[dep_op].push_back(op);
} }
} }
return std::move(result);
auto downstream_map_to_str = [&]() -> std::string {
std::ostringstream oss;
for (auto pair : downstream) {
oss << pair.first << " -> ";
std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " "));
oss << std::endl;
}
return oss.str();
};
auto downstream_map_count = [&]() -> size_t {
size_t count = 0;
for (auto pair : downstream) {
count += pair.second.size();
}
return count;
};
VLOG(6) << "downstream count: " << downstream_map_count();
VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str();
// step2: remove unnecessary downstream ops
// for example, a->b->c
// a: b, c
// b: c
// =>
// a: b
// b: c
// NOTE(zhiqiu): the size of downstream != size of op2dependences
// since there are some ops that have no downstream-op.
auto op_num = op2dependences.size();
// happens_before[i][j] means i should be executed before j
op_happens_before->resize(op_num);
for (size_t i = 0; i < op_num; ++i) {
(*op_happens_before)[i].resize(op_num);
std::fill((*op_happens_before)[i].begin(), (*op_happens_before)[i].end(),
false);
}
// bfs to get all next ops
auto bfs = [&](size_t op_idx) {
std::queue<size_t> q;
std::vector<bool> visited(op_num, false);
q.push(op_idx);
while (!q.empty()) {
size_t op = q.front();
q.pop();
visited[op] = true;
if (!downstream.count(op)) {
continue;
}
for (auto next : downstream[op]) {
if (!visited[next]) {
PADDLE_ENFORCE_EQ((*op_happens_before)[next][op_idx], false,
paddle::platform::errors::AlreadyExists(
"There exists circle in graph, expected "
"%d->%d, but already got %d->%d",
op_idx, next, next, op_idx));
(*op_happens_before)[op_idx][next] = true;
VLOG(8) << "happens before: " << op_idx << " " << next;
q.push(next);
}
}
}
};
for (size_t i = 0; i < op_num; ++i) {
bfs(i);
}
// shrink: keep only the downstream ops that no other op in the
// downstream list happens before
for (size_t i = 0; i < op_num; ++i) {
std::list<int> minumum_nexts;
for (size_t item : downstream[i]) {
bool not_after_any = true;
// keep the op only if it is not executed after any other op in the list
for (size_t other_item : downstream[i]) {
if ((*op_happens_before)[other_item][item]) {
VLOG(8) << "happens_before: " << other_item << "->" << item
<< ", so skip " << item;
not_after_any = false;
break;
}
}
if (not_after_any) {
VLOG(8) << "downstream op of " << i << ": " << item;
minumum_nexts.push_back(item);
}
}
downstream[i] = minumum_nexts;
}
VLOG(6) << "downstream count: " << downstream_map_count();
VLOG(6) << "downstream_map: " << std::endl << downstream_map_to_str();
return std::move(downstream);
} }
std::map<int, std::list<int>> build_op_downstream_map( std::map<int, std::list<int>> build_op_downstream_map(
const std::vector<Instruction>& vec_instruction) { const std::vector<Instruction>& vec_instruction,
std::vector<std::vector<bool>>* op_happens_before) {
auto var2min_rw_op = std::map< auto var2min_rw_op = std::map<
int, std::list<int>>(); // # map from variable id to read / write op id. int, std::list<int>>(); // # map from variable id to read / write op id.
auto var2recent_write_op = auto var2recent_write_op =
...@@ -873,13 +975,13 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -873,13 +975,13 @@ std::map<int, std::list<int>> build_op_downstream_map(
} }
} }
for (auto pair : op2dependences) { for (auto pair : op2dependences) {
VLOG(10) << pair.first << " Depends on " << pair.second.size();
std::ostringstream oss; std::ostringstream oss;
oss << pair.first << " Depends on " << pair.second.size() << " ops: ";
std::copy(pair.second.begin(), pair.second.end(), std::copy(pair.second.begin(), pair.second.end(),
std::ostream_iterator<int>(oss, " ")); std::ostream_iterator<int>(oss, " "));
VLOG(10) << oss.str(); VLOG(10) << oss.str();
} }
return std::move(get_downstream_map(op2dependences)); return std::move(get_downstream_map(op2dependences, op_happens_before));
} }
} // namespace interpreter } // namespace interpreter
......
...@@ -116,7 +116,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -116,7 +116,8 @@ void build_op_func_list(const platform::Place& place,
VariableScope* var_scope, bool use_local_scope = true); VariableScope* var_scope, bool use_local_scope = true);
std::map<int, std::list<int>> build_op_downstream_map( std::map<int, std::list<int>> build_op_downstream_map(
const std::vector<Instruction>& vec_instruction); const std::vector<Instruction>& vec_instruction,
std::vector<std::vector<bool>>* op_happens_before);
void add_fetch(const std::vector<std::string>& fetch_names, void add_fetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block); framework::BlockDesc* block);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册