未验证 提交 fe35496b 编写于 作者: A Aurelius84 提交者: GitHub

Modify H2D and D2H as kQueue::Sync and Polish Schedule logic (#35866)

* Modify H2D and D2H as kQueue::Sync

* fix interface error
上级 ae65257d
...@@ -24,9 +24,12 @@ void EventManager::WaitEvent(const Instruction& instruction, ...@@ -24,9 +24,12 @@ void EventManager::WaitEvent(const Instruction& instruction,
VLOG(3) << "Deal StreamWaitEventOrSync for " VLOG(3) << "Deal StreamWaitEventOrSync for "
<< instruction.kernel_func_.operator_base_->Type(); << instruction.kernel_func_.operator_base_->Type();
auto* dev_ctx = instruction.dev_ctx_;
WaitOrSync(instruction.intput_events_, dev_ctx); for (auto& event_iter : instruction.intput_events_) {
VLOG(3) << "wait var_id: " << event_iter.var_id_
<< " 's event with waiter_type: " << event_iter.waiter_type_;
event_iter.event_->Wait(event_iter.waiter_type_, instruction.dev_ctx_);
}
} }
void EventManager::RecordEvent(const Instruction& instruction, void EventManager::RecordEvent(const Instruction& instruction,
...@@ -40,18 +43,5 @@ void EventManager::RecordEvent(const Instruction& instruction, ...@@ -40,18 +43,5 @@ void EventManager::RecordEvent(const Instruction& instruction,
} }
} }
void EventManager::WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx) {
for (auto& event_iter : events) {
if (event_iter.is_sync_) {
VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCPU, dev_ctx);
} else {
VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCUDA, dev_ctx);
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -24,10 +24,6 @@ class EventManager { ...@@ -24,10 +24,6 @@ class EventManager {
const platform::Place& place); const platform::Place& place);
void WaitEvent(const Instruction& instruction, const platform::Place& place); void WaitEvent(const Instruction& instruction, const platform::Place& place);
private:
void WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx);
}; };
} // namespace framework } // namespace framework
......
...@@ -183,8 +183,7 @@ void InterpreterCore::Convert() { ...@@ -183,8 +183,7 @@ void InterpreterCore::Convert() {
} }
} }
stream_analyzer_.Schedule(vec_func_list_, filter_next, i, stream_analyzer_.Schedule(filter_next, &vec_instruction_, i);
&vec_instruction_);
for (auto inst_id : filter_next) { for (auto inst_id : filter_next) {
dependecy_count_[inst_id]++; dependecy_count_[inst_id]++;
......
...@@ -99,7 +99,6 @@ class InterpreterCore { ...@@ -99,7 +99,6 @@ class InterpreterCore {
InterpreterCoreGarbageCollector gc_; InterpreterCoreGarbageCollector gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_; std::vector<paddle::platform::DeviceEvent> gc_event_;
std::unique_ptr<WorkQueueGroup> group_thread_pool_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -365,7 +365,9 @@ void build_op_func_list(const platform::Place& place, ...@@ -365,7 +365,9 @@ void build_op_func_list(const platform::Place& place,
OpKernelComputeFunc(kernel_iter->second); OpKernelComputeFunc(kernel_iter->second);
copy_op_func_node.kernel_func_(copy_exec_ctx); copy_op_func_node.kernel_func_(copy_exec_ctx);
VLOG(3) << "Run " << memcpy_op_type << " done."; VLOG(3) << "Run " << memcpy_op_type << " done.";
copy_op_func_node.type_ = OpFuncType::kQueueAsync; // NOTE(Aurelius84): memcpy_op is expensive operation, so we tag them
// as kQueueSync and execute them in thread pool.
copy_op_func_node.type_ = OpFuncType::kQueueSync;
copy_op_func_node.dev_ctx_ = dev_ctx; copy_op_func_node.dev_ctx_ = dev_ctx;
op_list->push_back(copy_op); op_list->push_back(copy_op);
vec_func_list->push_back(copy_op_func_node); vec_func_list->push_back(copy_op_func_node);
......
...@@ -25,11 +25,6 @@ ...@@ -25,11 +25,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpretercore {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
} // namespace interpretercore
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>; using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>; std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
...@@ -496,11 +491,11 @@ struct NextInstruction { ...@@ -496,11 +491,11 @@ struct NextInstruction {
struct EventInter { struct EventInter {
explicit EventInter(size_t var_id, explicit EventInter(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event, std::shared_ptr<platform::DeviceEvent> event,
bool is_sync) platform::DeviceType waiter_type)
: var_id_(var_id), event_(event), is_sync_(is_sync) {} : var_id_(var_id), event_(event), waiter_type_(waiter_type) {}
size_t var_id_; size_t var_id_;
std::shared_ptr<platform::DeviceEvent> event_; std::shared_ptr<platform::DeviceEvent> event_;
bool is_sync_; platform::DeviceType waiter_type_;
}; };
struct InstructionInfo { struct InstructionInfo {
...@@ -543,5 +538,18 @@ struct OpFuncNode { ...@@ -543,5 +538,18 @@ struct OpFuncNode {
OpFuncType type_; OpFuncType type_;
}; };
namespace interpretercore {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
static bool IsMemcpyH2D(const Instruction& instr) {
return instr.kernel_func_.operator_base_->Type() == kMemcpyH2D;
}
static bool IsMemcpyD2H(const Instruction& instr) {
return instr.kernel_func_.operator_base_->Type() == kMemcpyD2H;
}
} // namespace interpretercore
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -22,7 +22,7 @@ namespace framework { ...@@ -22,7 +22,7 @@ namespace framework {
* Parse the var_ids that need to be associated with an event. * Parse the var_ids that need to be associated with an event.
* The caller should guarantee front_op and back_op satisfy the * The caller should guarantee front_op and back_op satisfy the
* following conditions: * following conditions:
* 1. kQueueAsync -> kQueueAsync * 1. kQueueSync -> kQueueAsync
* 2. kQueueAsync -> kQueueSync * 2. kQueueAsync -> kQueueSync
* *
* For example: matmul(gpu) -> out_var -> memcpy_d2h * For example: matmul(gpu) -> out_var -> memcpy_d2h
...@@ -48,7 +48,7 @@ std::vector<size_t> StreamAnalyzer::ParseEventVarIds( ...@@ -48,7 +48,7 @@ std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
void StreamAnalyzer::AssociateInputWithEvents( void StreamAnalyzer::AssociateInputWithEvents(
const std::vector<size_t>& new_event_var_id, Instruction* next_instr, const std::vector<size_t>& new_event_var_id, Instruction* next_instr,
bool is_sync) { platform::DeviceType waiter_type) {
for (auto var_id : new_event_var_id) { for (auto var_id : new_event_var_id) {
if (var_id2event_.count(var_id) == 0) { if (var_id2event_.count(var_id) == 0) {
auto device_event = std::make_shared<platform::DeviceEvent>( auto device_event = std::make_shared<platform::DeviceEvent>(
...@@ -57,52 +57,43 @@ void StreamAnalyzer::AssociateInputWithEvents( ...@@ -57,52 +57,43 @@ void StreamAnalyzer::AssociateInputWithEvents(
} }
// Add events for next_instr.inputs // Add events for next_instr.inputs
next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id), next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id),
is_sync); waiter_type);
} }
} }
void StreamAnalyzer::Schedule(const std::vector<OpFuncNode>& op_func_nodes, void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
const std::vector<size_t>& downstream_ops, std::vector<Instruction>* instructions,
size_t op_index, size_t op_index) {
std::vector<Instruction>* instructions) {
auto& op_func_type = op_func_nodes[op_index].type_;
auto& cur_instr = instructions->at(op_index); auto& cur_instr = instructions->at(op_index);
auto& next_instruction = cur_instr.next_instruction_; auto& next_instruction = cur_instr.next_instruction_;
std::vector<size_t> event_var_ids;
for (auto next_op_id : downstream_ops) {
auto& next_instr = instructions->at(next_op_id);
if (op_func_type == OpFuncType::kQueueSync) { if (IsDirectRun(cur_instr, next_instr)) {
// all downstream ops of kQueueSync can directly run, such as CPU -> Any next_instruction.direct_run_.emplace_back(next_op_id);
next_instruction.direct_run_ = downstream_ops; } else {
} else { // kQueueAsync
std::vector<size_t> event_var_ids;
for (auto next_op_id : downstream_ops) {
auto& next_instr = instructions->at(next_op_id);
// case 1: GPU -> GPU(same stream)
if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
next_instruction.direct_run_.emplace_back(next_op_id);
continue;
}
// Always insert events between different stream // Always insert events between different stream
auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr); auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(), event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
new_event_var_ids.end()); new_event_var_ids.end());
bool is_sync = auto waiter_type = GetWaiterType(next_instr);
(op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync); AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type);
AssociateInputWithEvents(new_event_var_ids, &next_instr, is_sync);
if (is_sync) { // GPU -> CPU if (waiter_type == platform::kCPU) { // GPU -> CPU
next_instruction.synchronize_run_.emplace_back(next_op_id); next_instruction.synchronize_run_.emplace_back(next_op_id);
} else { // GPU -> GPU(different stream) } else { // GPU -> GPU(different stream)
next_instruction.event_wait_run_.emplace_back(next_op_id); next_instruction.event_wait_run_.emplace_back(next_op_id);
} }
} }
// Create events for these cross-stream vars }
VLOG(3) << cur_instr.kernel_func_.operator_base_->Type() // Create events for these cross-stream vars
<< " event_var_ids.size: " << event_var_ids.size(); VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
for (auto var_id : event_var_ids) { << " event_var_ids.size: " << event_var_ids.size();
cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id), for (auto var_id : event_var_ids) {
false /*not used*/); cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id),
} platform::kCUDA /*not used*/);
} }
} }
...@@ -121,5 +112,27 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -121,5 +112,27 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
return dev_ctx; return dev_ctx;
} }
/*
* NOTE(dev): The following cases are considered as directly run:
*
* 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU
* 2. D2H -> CPU
* 3. CPU -> H2D
*/
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) {
return (cur_instr.dev_ctx_ == next_instr.dev_ctx_ ||
interpretercore::IsMemcpyD2H(cur_instr) ||
interpretercore::IsMemcpyH2D(next_instr));
}
platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
if (instr.type_ == OpFuncType::kQueueSync) {
return platform::kCPU;
} else {
return platform::kCUDA;
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -29,9 +29,8 @@ class StreamAnalyzer { ...@@ -29,9 +29,8 @@ class StreamAnalyzer {
~StreamAnalyzer() {} ~StreamAnalyzer() {}
void Schedule(const std::vector<OpFuncNode>& op_func_nodes, void Schedule(const std::vector<size_t>& downstream_ops,
const std::vector<size_t>& downstream_ops, size_t op_index, std::vector<Instruction>* instructions, size_t op_index);
std::vector<Instruction>* instructions);
platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node, platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node,
const OperatorBase& op_base); const OperatorBase& op_base);
...@@ -41,7 +40,14 @@ class StreamAnalyzer { ...@@ -41,7 +40,14 @@ class StreamAnalyzer {
const Instruction& next_instr); const Instruction& next_instr);
void AssociateInputWithEvents(const std::vector<size_t>& new_event_var_id, void AssociateInputWithEvents(const std::vector<size_t>& new_event_var_id,
Instruction* next_instr, bool is_sync); Instruction* next_instr,
platform::DeviceType waiter_type);
bool IsDirectRun(Instruction& cur_instr, // NOLINT
const Instruction& next_instr);
platform::DeviceType GetWaiterType(const Instruction& instr);
platform::Place place_; platform::Place place_;
platform::DeviceContextPool d2h_ctx_pool_; platform::DeviceContextPool d2h_ctx_pool_;
platform::DeviceContextPool h2d_ctx_pool_; platform::DeviceContextPool h2d_ctx_pool_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册