Unverified commit fe35496b, authored by Aurelius84, committed by GitHub

Modify H2D and D2H as kQueue::Sync and Polish Schedule logic (#35866)

* Modify H2D and D2H as kQueue::Sync

* fix interface error
Parent ae65257d
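In short, the patch does two things: memcpy_h2d / memcpy_d2h op nodes are re-tagged as OpFuncType::kQueueSync, and the boolean is_sync flag on EventInter is replaced by an explicit platform::DeviceType waiter_type, which StreamAnalyzer::GetWaiterType derives from the downstream instruction's queue type. A minimal, self-contained sketch of that rule (the enums and the standalone function below are simplified stand-ins, not the real Paddle declarations):

#include <iostream>

// Stand-ins for the framework types touched by this commit.
enum class OpFuncType { kQueueSync, kQueueAsync };
enum class DeviceType { kCPU, kCUDA };

// Mirrors StreamAnalyzer::GetWaiterType in the diff below: a synchronous
// downstream op waits on the host (kCPU); an asynchronous one waits on the
// device stream (kCUDA).
DeviceType GetWaiterType(OpFuncType type) {
  return type == OpFuncType::kQueueSync ? DeviceType::kCPU : DeviceType::kCUDA;
}

int main() {
  // memcpy_d2h is now tagged kQueueSync, so its producers get a host-side wait.
  std::cout << (GetWaiterType(OpFuncType::kQueueSync) == DeviceType::kCPU) << "\n";
  return 0;
}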
@@ -24,9 +24,12 @@ void EventManager::WaitEvent(const Instruction& instruction,
VLOG(3) << "Deal StreamWaitEventOrSync for "
<< instruction.kernel_func_.operator_base_->Type();
auto* dev_ctx = instruction.dev_ctx_;
WaitOrSync(instruction.intput_events_, dev_ctx);
for (auto& event_iter : instruction.intput_events_) {
VLOG(3) << "wait var_id: " << event_iter.var_id_
<< " 's event with waiter_type: " << event_iter.waiter_type_;
event_iter.event_->Wait(event_iter.waiter_type_, instruction.dev_ctx_);
}
}
void EventManager::RecordEvent(const Instruction& instruction,
@@ -40,18 +43,5 @@ void EventManager::RecordEvent(const Instruction& instruction,
}
}
void EventManager::WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx) {
for (auto& event_iter : events) {
if (event_iter.is_sync_) {
VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCPU, dev_ctx);
} else {
VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCUDA, dev_ctx);
}
}
}
} // namespace framework
} // namespace paddle
@@ -24,10 +24,6 @@ class EventManager {
const platform::Place& place);
void WaitEvent(const Instruction& instruction, const platform::Place& place);
private:
void WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx);
};
} // namespace framework
......
@@ -183,8 +183,7 @@ void InterpreterCore::Convert() {
}
}
stream_analyzer_.Schedule(vec_func_list_, filter_next, i,
&vec_instruction_);
stream_analyzer_.Schedule(filter_next, &vec_instruction_, i);
for (auto inst_id : filter_next) {
dependecy_count_[inst_id]++;
......
@@ -99,7 +99,6 @@ class InterpreterCore {
InterpreterCoreGarbageCollector gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
std::unique_ptr<WorkQueueGroup> group_thread_pool_;
};
} // namespace framework
} // namespace paddle
@@ -365,7 +365,9 @@ void build_op_func_list(const platform::Place& place,
OpKernelComputeFunc(kernel_iter->second);
copy_op_func_node.kernel_func_(copy_exec_ctx);
VLOG(3) << "Run " << memcpy_op_type << " done.";
copy_op_func_node.type_ = OpFuncType::kQueueAsync;
// NOTE(Aurelius84): memcpy_op is expensive operation, so we tag them
// as kQueueSync and execute them in thread pool.
copy_op_func_node.type_ = OpFuncType::kQueueSync;
copy_op_func_node.dev_ctx_ = dev_ctx;
op_list->push_back(copy_op);
vec_func_list->push_back(copy_op_func_node);
......
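The NOTE above is the heart of the first commit bullet: data-transfer ops are now queued synchronously, so the copy is handed to a thread-pool worker after a host-side event wait. A hypothetical helper naming that rule (IsDataTransferOp does not exist in the patch; the string literals match the kMemcpyH2D / kMemcpyD2H constants defined in the next hunk):

#include <cassert>
#include <string>

// Hypothetical predicate: both copy directions count as sync-queue work.
bool IsDataTransferOp(const std::string& op_type) {
  return op_type == "memcpy_h2d" || op_type == "memcpy_d2h";
}

int main() {
  assert(IsDataTransferOp("memcpy_d2h") && IsDataTransferOp("memcpy_h2d"));
  // In build_op_func_list (hunk above), this is the point where the node is
  // re-tagged: copy_op_func_node.type_ = OpFuncType::kQueueSync;
  return 0;
}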
@@ -25,11 +25,6 @@
namespace paddle {
namespace framework {
namespace interpretercore {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
} // namespace interpretercore
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap =
std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
@@ -496,11 +491,11 @@ struct NextInstruction {
struct EventInter {
explicit EventInter(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
bool is_sync)
: var_id_(var_id), event_(event), is_sync_(is_sync) {}
platform::DeviceType waiter_type)
: var_id_(var_id), event_(event), waiter_type_(waiter_type) {}
size_t var_id_;
std::shared_ptr<platform::DeviceEvent> event_;
bool is_sync_;
platform::DeviceType waiter_type_;
};
struct InstructionInfo {
@@ -543,5 +538,18 @@ struct OpFuncNode {
OpFuncType type_;
};
namespace interpretercore {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
static bool IsMemcpyH2D(const Instruction& instr) {
return instr.kernel_func_.operator_base_->Type() == kMemcpyH2D;
}
static bool IsMemcpyD2H(const Instruction& instr) {
return instr.kernel_func_.operator_base_->Type() == kMemcpyD2H;
}
} // namespace interpretercore
} // namespace framework
} // namespace paddle
@@ -22,7 +22,7 @@ namespace framework {
* Parse the var_ids that need to be associated with an event.
* The caller should guarantee front_op and back_op satisfy the
* following conditions:
* 1. kQueueAsync -> kQueueAsync
* 1. kQueueSync -> kQueueAsync
* 2. kQueueAsync -> kQueueSync
*
* For example: matmul(gpu) -> out_var -> memcpy_d2h
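Under the new Schedule logic later in this diff, that example plays out as follows (a hand trace derived from the code in this patch, not an executable test):

1. matmul(gpu) is kQueueAsync on the CUDA compute context; memcpy_d2h is kQueueSync and runs on its own D2H context (see d2h_ctx_pool_ in stream_analyzer.h), so their dev_ctx_ differ.
2. IsDirectRun(matmul, memcpy_d2h) is false: the contexts differ, the current op is not memcpy_d2h, and the next op is not memcpy_h2d.
3. ParseEventVarIds returns {out_var}, so an event is attached to out_var and recorded in matmul's output_events_.
4. GetWaiterType(memcpy_d2h) is platform::kCPU because the op is kQueueSync, so memcpy_d2h is placed in synchronize_run_ and EventManager::WaitEvent blocks the host on that event before the copy runs.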
@@ -48,7 +48,7 @@ std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
void StreamAnalyzer::AssociateInputWithEvents(
const std::vector<size_t>& new_event_var_id, Instruction* next_instr,
bool is_sync) {
platform::DeviceType waiter_type) {
for (auto var_id : new_event_var_id) {
if (var_id2event_.count(var_id) == 0) {
auto device_event = std::make_shared<platform::DeviceEvent>(
@@ -57,52 +57,43 @@ void StreamAnalyzer::AssociateInputWithEvents(
}
// Add events for next_instr.inputs
next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id),
is_sync);
waiter_type);
}
}
void StreamAnalyzer::Schedule(const std::vector<OpFuncNode>& op_func_nodes,
const std::vector<size_t>& downstream_ops,
size_t op_index,
std::vector<Instruction>* instructions) {
auto& op_func_type = op_func_nodes[op_index].type_;
void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
std::vector<Instruction>* instructions,
size_t op_index) {
auto& cur_instr = instructions->at(op_index);
auto& next_instruction = cur_instr.next_instruction_;
if (op_func_type == OpFuncType::kQueueSync) {
// all downstream ops of kQueueSync can directly run, such as CPU -> Any
next_instruction.direct_run_ = downstream_ops;
} else { // kQueueAsync
std::vector<size_t> event_var_ids;
for (auto next_op_id : downstream_ops) {
auto& next_instr = instructions->at(next_op_id);
// case 1: GPU -> GPU(same stream)
if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
if (IsDirectRun(cur_instr, next_instr)) {
next_instruction.direct_run_.emplace_back(next_op_id);
continue;
}
} else {
// Always insert events between different stream
auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
new_event_var_ids.end());
bool is_sync =
(op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
AssociateInputWithEvents(new_event_var_ids, &next_instr, is_sync);
auto waiter_type = GetWaiterType(next_instr);
AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type);
if (is_sync) { // GPU -> CPU
if (waiter_type == platform::kCPU) { // GPU -> CPU
next_instruction.synchronize_run_.emplace_back(next_op_id);
} else { // GPU -> GPU(different stream)
next_instruction.event_wait_run_.emplace_back(next_op_id);
}
}
}
// Create events for these cross-stream vars
VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
<< " event_var_ids.size: " << event_var_ids.size();
for (auto var_id : event_var_ids) {
cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id),
false /*not used*/);
}
platform::kCUDA /*not used*/);
}
}
@@ -121,5 +112,27 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
return dev_ctx;
}
/*
* NOTE(dev): The following cases are considered as directly run:
*
* 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU
* 2. D2H -> CPU
* 3. CPU -> H2D
*/
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) {
return (cur_instr.dev_ctx_ == next_instr.dev_ctx_ ||
interpretercore::IsMemcpyD2H(cur_instr) ||
interpretercore::IsMemcpyH2D(next_instr));
}
platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
if (instr.type_ == OpFuncType::kQueueSync) {
return platform::kCPU;
} else {
return platform::kCUDA;
}
}
} // namespace framework
} // namespace paddle
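Putting the pieces together, the per-edge decision made by the new Schedule can be condensed into a single classification step. The sketch below is a simplified, self-contained rendering of IsDirectRun plus GetWaiterType (InstrView, RunMode, and Classify are invented for illustration; the real code operates on Instruction and platform::DeviceContext objects):

#include <cassert>
#include <string>

enum class OpFuncType { kQueueSync, kQueueAsync };
enum class RunMode { kDirectRun, kSynchronizeRun, kEventWaitRun };

struct InstrView {          // stand-in for the Instruction fields used here
  std::string op_type;      // operator_base_->Type()
  const void* dev_ctx;      // dev_ctx_
  OpFuncType type;          // type_
};

RunMode Classify(const InstrView& cur, const InstrView& next) {
  // IsDirectRun: same dev_ctx, current op is memcpy_d2h, or next op is
  // memcpy_h2d (cases 1-3 in the NOTE above).
  if (cur.dev_ctx == next.dev_ctx || cur.op_type == "memcpy_d2h" ||
      next.op_type == "memcpy_h2d") {
    return RunMode::kDirectRun;        // no event inserted
  }
  // Otherwise an event is inserted; the waiter side follows GetWaiterType.
  if (next.type == OpFuncType::kQueueSync) {
    return RunMode::kSynchronizeRun;   // kCPU waiter: host blocks on the event
  }
  return RunMode::kEventWaitRun;       // kCUDA waiter: stream waits on the event
}

int main() {
  int gpu_marker = 0, d2h_marker = 0;  // distinct dummy context addresses
  InstrView matmul{"matmul", &gpu_marker, OpFuncType::kQueueAsync};
  InstrView d2h{"memcpy_d2h", &d2h_marker, OpFuncType::kQueueSync};
  assert(Classify(matmul, d2h) == RunMode::kSynchronizeRun);
  return 0;
}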
@@ -29,9 +29,8 @@ class StreamAnalyzer {
~StreamAnalyzer() {}
void Schedule(const std::vector<OpFuncNode>& op_func_nodes,
const std::vector<size_t>& downstream_ops, size_t op_index,
std::vector<Instruction>* instructions);
void Schedule(const std::vector<size_t>& downstream_ops,
std::vector<Instruction>* instructions, size_t op_index);
platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node,
const OperatorBase& op_base);
@@ -41,7 +40,14 @@
const Instruction& next_instr);
void AssociateInputWithEvents(const std::vector<size_t>& new_event_var_id,
Instruction* next_instr, bool is_sync);
Instruction* next_instr,
platform::DeviceType waiter_type);
bool IsDirectRun(Instruction& cur_instr, // NOLINT
const Instruction& next_instr);
platform::DeviceType GetWaiterType(const Instruction& instr);
platform::Place place_;
platform::DeviceContextPool d2h_ctx_pool_;
platform::DeviceContextPool h2d_ctx_pool_;
......