未验证 提交 e7d459ac 编写于 作者: R Ruibiao Chen 提交者: GitHub

Remove kSyncRun in StreamAnalyzer (#48425)

* Remove kSyncRun in StreamAnalyzer

* Update code
上级 2bae75ed
......@@ -64,8 +64,6 @@ class ContextManager {
inline std::string RunTypeToString(DownstreamRunType run_type) {
if (run_type == DownstreamRunType::kDirectRun) {
return "DirectRun";
} else if (run_type == DownstreamRunType::kSyncRun) {
return "SyncRun";
} else {
return "EventRun";
}
......@@ -238,16 +236,10 @@ void StreamAnalyzer::AnalyseAllEventInfo(
event_info) const {
for (size_t cur_instr_id = 0; cur_instr_id < instructions.size();
++cur_instr_id) {
const Instruction& cur_instr = instructions[cur_instr_id];
const std::vector<std::vector<size_t>>& next_instr_list =
run_type_info[cur_instr_id];
const std::vector<size_t>& next_instr_ids =
run_type_info[cur_instr_id][DownstreamRunType::kEventRun];
std::set<size_t> waiter_instr_ids;
std::vector<size_t> next_instr_ids =
next_instr_list[DownstreamRunType::kSyncRun];
next_instr_ids.insert(next_instr_ids.end(),
next_instr_list[DownstreamRunType::kEventRun].begin(),
next_instr_list[DownstreamRunType::kEventRun].end());
for (size_t next_instr_id : next_instr_ids) {
AnalyseEventInfoForTwoInstructions(instructions,
run_type_info,
......@@ -257,8 +249,9 @@ void StreamAnalyzer::AnalyseAllEventInfo(
}
for (size_t waiter_instr_id : waiter_instr_ids) {
(*event_info)[&(cur_instr.DeviceContext())][waiter_instr_id].insert(
cur_instr_id);
(*event_info)[&(instructions[cur_instr_id].DeviceContext())]
[waiter_instr_id]
.insert(cur_instr_id);
}
}
}
......@@ -284,7 +277,7 @@ void StreamAnalyzer::AnalyseAllRunType(
}
}
// The caller should guarantee cur_instr and next_instr is kSyncRun or kEventRun
// The caller should guarantee cur_instr and next_instr is kEventRun
void StreamAnalyzer::AnalyseEventInfoForTwoInstructions(
const std::vector<Instruction>& instructions,
const std::vector<std::vector<std::vector<size_t>>>& run_type_info,
......@@ -311,7 +304,6 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions(
// can only add event for it with the help of depend_op.
if (HasDataDependency(instructions[cur_instr_id],
instructions[next_instr_id]) ||
run_type_info[next_instr_id][DownstreamRunType::kSyncRun].size() ||
run_type_info[next_instr_id][DownstreamRunType::kEventRun].size() ||
instructions[next_instr_id].OpBase()->Type() == "depend") {
waiter_instr_ids->insert(next_instr_id);
......@@ -319,8 +311,8 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions(
}
// NOTE(Ruibiao): If no data dependency from cur_instr to next_instr, and
// simultaneously next_instr has neither sync_run nor event_run downstream
// instr, we try to recursively add events between cur_instr and next_instr's
// simultaneously next_instr has no event_run downstream instr, we try to
// recursively add events between cur_instr and next_instr's
// direct-run-instrs. This can delay the event wait and achieve better
// scheduling performance in some scenarios. However, when next_instr has too
// many direct-run-instrs, it may perform worse than add event directly
......@@ -393,16 +385,10 @@ DownstreamRunType StreamAnalyzer::AnalyseRunTypeForTwoInstructions(
}
}
if (cur_instr.KernelType() == OpFuncType::kGpuAsync) {
if (next_instr.KernelType() == OpFuncType::kCpuSync) {
return DownstreamRunType::kSyncRun;
} else {
// cross-stream: kGpuAsync -> kGpuSync, kGpuAsync -> kGpuSync
if (&cur_instr.DeviceContext() != &next_instr.DeviceContext()) {
if (cur_instr.KernelType() == OpFuncType::kGpuAsync &&
(&cur_instr.DeviceContext() != &next_instr.DeviceContext())) {
return DownstreamRunType::kEventRun;
}
}
}
return DownstreamRunType::kDirectRun;
}
......
......@@ -26,7 +26,7 @@ namespace paddle {
namespace framework {
namespace interpreter {
enum DownstreamRunType { kDirectRun, kSyncRun, kEventRun };
enum DownstreamRunType { kDirectRun, kEventRun };
class StreamAnalyzer {
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册