未验证 提交 8528dd9f 编写于 作者: A Aurelius84 提交者: GitHub

Refactor StreamAnalyzer and EventManager from InterpreterCore (#35711)

上级 85e4f45a
...@@ -5,7 +5,9 @@ graph_to_program_pass variable_helper timer monitor) ...@@ -5,7 +5,9 @@ graph_to_program_pass variable_helper timer monitor)
cc_library(workqueue SRCS workqueue.cc DEPS enforce) cc_library(workqueue SRCS workqueue.cc DEPS enforce)
cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS}) cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS})
cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS}) cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS})
cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector) cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog)
cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context)
cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager)
cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore) cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
# cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/event_manager.h"
namespace paddle {
namespace framework {
void EventManager::WaitEvent(const Instruction& instruction,
const platform::Place& place) {
// If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return;
VLOG(3) << "Deal StreamWaitEventOrSync for "
<< instruction.kernel_func_.operator_base_->Type();
auto* dev_ctx = instruction.dev_ctx_;
WaitOrSync(instruction.intput_events_, dev_ctx);
}
void EventManager::RecordEvent(const Instruction& instruction,
const OpFuncNode& op_func_node,
const platform::Place& place) {
// If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place)) return;
for (auto& event : instruction.output_events_) {
VLOG(3) << "Record event in out_var_id: " << event.var_id_;
event.event_->Record(instruction.dev_ctx_);
}
}
void EventManager::WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx) {
for (auto& event_iter : events) {
if (event_iter.is_sync_) {
VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCPU, dev_ctx);
} else {
VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCUDA, dev_ctx);
}
}
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
namespace paddle {
namespace framework {
class EventManager {
public:
void RecordEvent(const Instruction& instruction,
const OpFuncNode& op_func_node,
const platform::Place& place);
void WaitEvent(const Instruction& instruction, const platform::Place& place);
private:
void WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx);
};
} // namespace framework
} // namespace paddle
...@@ -20,101 +20,6 @@ ...@@ -20,101 +20,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace {
/*
* Parse the var_ids that need to be associated with an event.
* The caller should guarantee front_op and back_op satisfy the
* following conditions:
* 1. kQueueAsync -> kQueueAsync
* 2. kQueueAsync -> kQueueSync
*
* For example: matmul(gpu) -> out_var -> memcpy_d2h
* out_var should be associated with an event.
*/
std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
const Instruction& next_instr) {
std::unordered_set<size_t> unique_var_ids;
for (auto& item : cur_instr.output_index_) {
unique_var_ids.insert(item.second.begin(), item.second.end());
}
std::vector<size_t> new_event_var_ids;
for (auto& item : next_instr.input_index_) {
for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0) {
new_event_var_ids.push_back(var_id);
}
}
}
return new_event_var_ids;
}
void AssociateInputWithEvents(
const platform::Place& place, const std::vector<size_t>& new_event_var_id,
Instruction* next_instr,
std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
bool is_sync) {
for (auto var_id : new_event_var_id) {
if (var_id2event->count(var_id) == 0) {
auto device_event = std::make_shared<platform::DeviceEvent>(
place, platform::GenerateDeviceEventFlag());
var_id2event->emplace(var_id, std::move(device_event));
}
// Add events for next_instr.inputs
next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id),
is_sync);
}
}
void ParseDirectAndEventRunOps(
const platform::Place& place, const std::vector<OpFuncNode>& op_func_nodes,
const std::vector<size_t>& downstream_ops, size_t op_index,
std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
std::vector<Instruction>* instructions) {
auto& op_func_type = op_func_nodes[op_index].type_;
auto& cur_instr = instructions->at(op_index);
auto& next_instruction = cur_instr.next_instruction_;
if (op_func_type == OpFuncType::kQueueSync) {
// all downstream ops of kQueueSync can directly run, such as CPU -> Any
next_instruction.direct_run_ = downstream_ops;
} else { // kQueueAsync
std::vector<size_t> event_var_ids;
for (auto next_op_id : downstream_ops) {
auto& next_instr = instructions->at(next_op_id);
// case 1: GPU -> GPU(same stream)
if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
next_instruction.direct_run_.emplace_back(next_op_id);
continue;
}
// Always insert events between different stream
auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
new_event_var_ids.end());
bool is_sync =
(op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
AssociateInputWithEvents(place, new_event_var_ids, &next_instr,
var_id2event, is_sync);
if (is_sync) { // GPU -> CPU
next_instruction.synchronize_run_.emplace_back(next_op_id);
} else { // GPU -> GPU(different stream)
next_instruction.event_wait_run_.emplace_back(next_op_id);
}
}
// Create events for these cross-stream vars
VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
<< " event_var_ids.size: " << event_var_ids.size();
for (auto var_id : event_var_ids) {
cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id),
false /*not used*/);
}
}
}
} // namespace
InterpreterCore::InterpreterCore(const platform::Place& place, InterpreterCore::InterpreterCore(const platform::Place& place,
const ProgramDesc& main_prog, const ProgramDesc& main_prog,
VariableScope* global_scope, VariableScope* global_scope,
...@@ -123,8 +28,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -123,8 +28,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
: place_(place), : place_(place),
main_program_(main_prog), main_program_(main_prog),
global_scope_(global_scope), global_scope_(global_scope),
d2h_ctx_pool_({place}), stream_analyzer_(place) {
h2d_ctx_pool_({place}) {
is_build_ = false; is_build_ = false;
feed_names_ = feed_names; feed_names_ = feed_names;
...@@ -199,7 +103,7 @@ void InterpreterCore::Convert() { ...@@ -199,7 +103,7 @@ void InterpreterCore::Convert() {
Instruction temp_inst; Instruction temp_inst;
auto* op_base = op_list_[i]; auto* op_base = op_list_[i];
temp_inst.dev_ctx_ = temp_inst.dev_ctx_ =
ParseDeviceContextForInstruction(vec_func_list_[i], *op_base); stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base);
temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_; temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_;
temp_inst.kernel_func_.operator_base_ = op_base; temp_inst.kernel_func_.operator_base_ = op_base;
temp_inst.input_index_ = vec_func_list_[i].input_index; temp_inst.input_index_ = vec_func_list_[i].input_index;
...@@ -270,8 +174,8 @@ void InterpreterCore::Convert() { ...@@ -270,8 +174,8 @@ void InterpreterCore::Convert() {
} }
} }
ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i, stream_analyzer_.Schedule(vec_func_list_, filter_next, i,
&var_id2event_, &vec_instruction_); &vec_instruction_);
for (auto inst_id : filter_next) { for (auto inst_id : filter_next) {
dependecy_count_[inst_id]++; dependecy_count_[inst_id]++;
...@@ -361,7 +265,7 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -361,7 +265,7 @@ void InterpreterCore::ExecuteInstructionList(
working_queue.pop(); working_queue.pop();
auto& instr_node = vec_instr[instr_id]; auto& instr_node = vec_instr[instr_id];
// step1 : stream_wait (non-block host) or sync (block host) // step1 : stream_wait (non-block host) or sync (block host)
StreamWaitEventOrSync(instr_node); event_manager_.WaitEvent(instr_node, place_);
// step2: run instruction // step2: run instruction
RunInstruction(instr_node); RunInstruction(instr_node);
++run_op_number; ++run_op_number;
...@@ -371,7 +275,7 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -371,7 +275,7 @@ void InterpreterCore::ExecuteInstructionList(
} }
// step3: insert event for out_vars if needed // step3: insert event for out_vars if needed
RecordEventInstruction(instr_node, vec_func_list_[instr_id]); event_manager_.RecordEvent(instr_node, vec_func_list_[instr_id], place_);
// step4: update working_queue // step4: update working_queue
auto& next_instr = instr_node.next_instruction_.all_next_ops_; auto& next_instr = instr_node.next_instruction_.all_next_ops_;
...@@ -450,54 +354,5 @@ const CostInfo& InterpreterCore::DryRun( ...@@ -450,54 +354,5 @@ const CostInfo& InterpreterCore::DryRun(
return dry_run_profiler_.GetCostInfo(); return dry_run_profiler_.GetCostInfo();
} }
platform::DeviceContext* InterpreterCore::ParseDeviceContextForInstruction(
const OpFuncNode& op_func_node, const OperatorBase& op_base) {
auto& op_type = op_base.Type();
auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpretercore::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctx_pool_.Get(place_);
} else if (op_type == interpretercore::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctx_pool_.Get(place_);
}
return dev_ctx;
}
void InterpreterCore::RecordEventInstruction(const Instruction& instruction,
const OpFuncNode& op_func_node) {
// If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place_)) return;
for (auto& event : instruction.output_events_) {
VLOG(3) << "Record event in out_var_id: " << event.var_id_;
event.event_->Record(instruction.dev_ctx_);
}
}
void InterpreterCore::WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx) {
for (auto& event_iter : events) {
if (event_iter.is_sync_) {
VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCPU, dev_ctx);
} else {
VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
event_iter.event_->Wait(platform::kCUDA, dev_ctx);
}
}
}
void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) {
// If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(place_)) return;
VLOG(3) << "Deal StreamWaitEventOrSync for "
<< instruction.kernel_func_.operator_base_->Type();
auto* dev_ctx = instruction.dev_ctx_;
WaitOrSync(instruction.intput_events_, dev_ctx);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/new_executor/event_manager.h"
#include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h" #include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h" #include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h" #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/profiler.h" #include "paddle/fluid/framework/new_executor/profiler.h"
#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
#include "paddle/fluid/framework/new_executor/workqueue.h" #include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -64,17 +66,6 @@ class InterpreterCore { ...@@ -64,17 +66,6 @@ class InterpreterCore {
const VariableScope& var_scope, const platform::Place& place, const VariableScope& var_scope, const platform::Place& place,
std::vector<VariableMetaInfo>& working_var_ref); // NOLINT std::vector<VariableMetaInfo>& working_var_ref); // NOLINT
platform::DeviceContext* ParseDeviceContextForInstruction(
const OpFuncNode& op_func_node, const OperatorBase& op_base);
void RecordEventInstruction(const Instruction& instruction,
const OpFuncNode& op_func_node);
void WaitOrSync(const std::vector<EventInter>& events,
const platform::DeviceContext* dev_ctx);
void StreamWaitEventOrSync(const Instruction& instruction);
void AddFetch(const std::vector<std::string>& fetch_names); void AddFetch(const std::vector<std::string>& fetch_names);
bool is_build_; bool is_build_;
...@@ -83,9 +74,6 @@ class InterpreterCore { ...@@ -83,9 +74,6 @@ class InterpreterCore {
ProgramDesc main_program_; ProgramDesc main_program_;
VariableScope* global_scope_; VariableScope* global_scope_;
platform::DeviceContextPool d2h_ctx_pool_;
platform::DeviceContextPool h2d_ctx_pool_;
std::vector<Instruction> vec_instruction_; std::vector<Instruction> vec_instruction_;
InstructionInfo instruction_info_; InstructionInfo instruction_info_;
std::vector<size_t> dependecy_count_; std::vector<size_t> dependecy_count_;
...@@ -99,8 +87,8 @@ class InterpreterCore { ...@@ -99,8 +87,8 @@ class InterpreterCore {
std::vector<std::string> feed_names_; std::vector<std::string> feed_names_;
InterpreterProfiler dry_run_profiler_; InterpreterProfiler dry_run_profiler_;
StreamAnalyzer stream_analyzer_;
std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_; EventManager event_manager_;
InterpreterCoreGarbageCollector gc_; InterpreterCoreGarbageCollector gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_; std::vector<paddle::platform::DeviceEvent> gc_event_;
......
...@@ -476,9 +476,6 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -476,9 +476,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
namespace interpretercore { namespace interpretercore {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
std::string get_memcpy_type(const platform::Place& src_place, std::string get_memcpy_type(const platform::Place& src_place,
const platform::Place& dst_place); const platform::Place& dst_place);
......
...@@ -25,6 +25,11 @@ ...@@ -25,6 +25,11 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpretercore {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
} // namespace interpretercore
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>; using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>; std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
#include <unordered_set>
namespace paddle {
namespace framework {
/*
* Parse the var_ids that need to be associated with an event.
* The caller should guarantee front_op and back_op satisfy the
* following conditions:
* 1. kQueueAsync -> kQueueAsync
* 2. kQueueAsync -> kQueueSync
*
* For example: matmul(gpu) -> out_var -> memcpy_d2h
* out_var should be associated with an event.
*/
std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
const Instruction& cur_instr, const Instruction& next_instr) {
std::unordered_set<size_t> unique_var_ids;
for (auto& item : cur_instr.output_index_) {
unique_var_ids.insert(item.second.begin(), item.second.end());
}
std::vector<size_t> new_event_var_ids;
for (auto& item : next_instr.input_index_) {
for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0) {
new_event_var_ids.push_back(var_id);
}
}
}
return new_event_var_ids;
}
void StreamAnalyzer::AssociateInputWithEvents(
const std::vector<size_t>& new_event_var_id, Instruction* next_instr,
bool is_sync) {
for (auto var_id : new_event_var_id) {
if (var_id2event_.count(var_id) == 0) {
auto device_event = std::make_shared<platform::DeviceEvent>(
place_, platform::GenerateDeviceEventFlag());
var_id2event_.emplace(var_id, std::move(device_event));
}
// Add events for next_instr.inputs
next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id),
is_sync);
}
}
void StreamAnalyzer::Schedule(const std::vector<OpFuncNode>& op_func_nodes,
const std::vector<size_t>& downstream_ops,
size_t op_index,
std::vector<Instruction>* instructions) {
auto& op_func_type = op_func_nodes[op_index].type_;
auto& cur_instr = instructions->at(op_index);
auto& next_instruction = cur_instr.next_instruction_;
if (op_func_type == OpFuncType::kQueueSync) {
// all downstream ops of kQueueSync can directly run, such as CPU -> Any
next_instruction.direct_run_ = downstream_ops;
} else { // kQueueAsync
std::vector<size_t> event_var_ids;
for (auto next_op_id : downstream_ops) {
auto& next_instr = instructions->at(next_op_id);
// case 1: GPU -> GPU(same stream)
if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
next_instruction.direct_run_.emplace_back(next_op_id);
continue;
}
// Always insert events between different stream
auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
new_event_var_ids.end());
bool is_sync =
(op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
AssociateInputWithEvents(new_event_var_ids, &next_instr, is_sync);
if (is_sync) { // GPU -> CPU
next_instruction.synchronize_run_.emplace_back(next_op_id);
} else { // GPU -> GPU(different stream)
next_instruction.event_wait_run_.emplace_back(next_op_id);
}
}
// Create events for these cross-stream vars
VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
<< " event_var_ids.size: " << event_var_ids.size();
for (auto var_id : event_var_ids) {
cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id),
false /*not used*/);
}
}
}
platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node, const OperatorBase& op_base) {
auto& op_type = op_base.Type();
auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpretercore::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctx_pool_.Get(place_);
} else if (op_type == interpretercore::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctx_pool_.Get(place_);
}
return dev_ctx;
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_event.h"
namespace paddle {
namespace framework {
class StreamAnalyzer {
public:
explicit StreamAnalyzer(const platform::Place& place)
: place_(place), d2h_ctx_pool_({place}), h2d_ctx_pool_({place}) {}
~StreamAnalyzer() {}
void Schedule(const std::vector<OpFuncNode>& op_func_nodes,
const std::vector<size_t>& downstream_ops, size_t op_index,
std::vector<Instruction>* instructions);
platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node,
const OperatorBase& op_base);
private:
std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
const Instruction& next_instr);
void AssociateInputWithEvents(const std::vector<size_t>& new_event_var_id,
Instruction* next_instr, bool is_sync);
platform::Place place_;
platform::DeviceContextPool d2h_ctx_pool_;
platform::DeviceContextPool h2d_ctx_pool_;
std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
};
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册