未验证 提交 a6e99dc7 编写于 作者: A Aurelius84 提交者: GitHub

Refactor InterpreterCore and Modify into BlockDesc (#37056)

上级 993ec76a
...@@ -33,13 +33,11 @@ namespace framework { ...@@ -33,13 +33,11 @@ namespace framework {
// NOTE(Aurelius84): Need a better strategy to determine it. // NOTE(Aurelius84): Need a better strategy to determine it.
static constexpr size_t kHostNumThreads = 4; static constexpr size_t kHostNumThreads = 4;
InterpreterCore::InterpreterCore(const platform::Place& place, InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block,
const ProgramDesc& main_prog,
VariableScope* global_scope, VariableScope* global_scope,
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names)
const std::vector<std::string>& fetch_names)
: place_(place), : place_(place),
main_program_(main_prog), block_(block),
global_scope_(global_scope), global_scope_(global_scope),
stream_analyzer_(place), stream_analyzer_(place),
async_work_queue_(kHostNumThreads, &main_thread_blocker_) { async_work_queue_(kHostNumThreads, &main_thread_blocker_) {
...@@ -50,9 +48,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -50,9 +48,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
exception_notifier_ = main_thread_blocker_.RegisterEvent( exception_notifier_ = main_thread_blocker_.RegisterEvent(
kExceptionCaught, [this]() { return exception_holder_.IsCaught(); }); kExceptionCaught, [this]() { return exception_holder_.IsCaught(); });
// Step1: add feedop and fetchop to main_program
AddFetch(fetch_names);
// prune // prune
// optimize graph pass // optimize graph pass
...@@ -60,24 +55,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -60,24 +55,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
// convert to run graph // convert to run graph
} }
// Registers the fetch targets on block 0 of main_program_.
//
// First declares the variable "fetch_vars" as a persistable FETCH_LIST, then
// appends one `fetch_v2` op per entry of `fetch_names`. Each op takes the
// named variable as input "X", writes to "fetch_vars" through output "Out",
// and stores the entry's position within `fetch_names` in its "col" attribute.
void InterpreterCore::AddFetch(const std::vector<std::string>& fetch_names) {
  auto* global_block = main_program_.MutableBlock(0);

  auto* holder = global_block->Var("fetch_vars");
  holder->SetType(proto::VarType::FETCH_LIST);
  holder->SetPersistable(true);

  for (size_t col = 0; col < fetch_names.size(); ++col) {
    // One fetch op per requested variable, in the order given by the caller.
    auto* fetch_op = global_block->AppendOp();
    fetch_op->SetType("fetch_v2");
    fetch_op->SetInput("X", {fetch_names[col]});
    fetch_op->SetOutput("Out", {"fetch_vars"});
    fetch_op->SetAttr("col", {static_cast<int>(col)});
    fetch_op->CheckAttrs();
  }
}
paddle::framework::FetchList InterpreterCore::Run( paddle::framework::FetchList InterpreterCore::Run(
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
auto FeedInput = [&] { auto FeedInput = [&] {
...@@ -90,11 +67,11 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -90,11 +67,11 @@ paddle::framework::FetchList InterpreterCore::Run(
}; };
if (is_build_ == false) { if (is_build_ == false) {
paddle::framework::interpretercore::build_variable_scope(main_program_, paddle::framework::interpreter::build_variable_scope(*block_,
global_scope_); global_scope_);
FeedInput(); FeedInput();
paddle::framework::interpretercore::build_op_func_list( paddle::framework::interpreter::build_op_func_list(
place_, main_program_, &vec_func_list_, global_scope_); place_, *block_, &vec_func_list_, global_scope_);
is_build_ = true; is_build_ = true;
// convert vec func_list to graph // convert vec func_list to graph
Convert(); Convert();
...@@ -104,7 +81,7 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -104,7 +81,7 @@ paddle::framework::FetchList InterpreterCore::Run(
} }
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var("fetch_vars"); auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return *(fetch_var->GetMutable<framework::FetchList>()); return *(fetch_var->GetMutable<framework::FetchList>());
} }
...@@ -172,8 +149,7 @@ void InterpreterCore::Convert() { ...@@ -172,8 +149,7 @@ void InterpreterCore::Convert() {
std::vector<size_t> vec_temp; std::vector<size_t> vec_temp;
for (auto& item : vec_instruction_[i].Outputs()) { for (auto& item : vec_instruction_[i].Outputs()) {
for (auto id : item.second) { for (auto id : item.second) {
vec_temp = vec_temp = interpreter::merge_vector(vec_temp, input_var2op_info_[id]);
interpretercore::merge_vector(vec_temp, input_var2op_info_[id]);
} }
} }
...@@ -438,8 +414,8 @@ void InterpreterCore::RunNextInstructions( ...@@ -438,8 +414,8 @@ void InterpreterCore::RunNextInstructions(
[&, next_id] { RunInstructionAsync(next_id); }); [&, next_id] { RunInstructionAsync(next_id); });
} }
} }
auto direct_run_ops = interpretercore::merge_vector( auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(),
next_instr.SyncRunIds(), next_instr.DirectRunIds()); next_instr.DirectRunIds());
size_t first_op = 0; size_t first_op = 0;
for (auto next_id : direct_run_ops) { for (auto next_id : direct_run_ops) {
if (IsReady(next_id)) { if (IsReady(next_id)) {
...@@ -538,11 +514,11 @@ void InterpreterCore::DryRunPrepare( ...@@ -538,11 +514,11 @@ void InterpreterCore::DryRunPrepare(
}; };
if (is_build_ == false) { if (is_build_ == false) {
paddle::framework::interpretercore::build_variable_scope(main_program_, paddle::framework::interpreter::build_variable_scope(*block_,
global_scope_); global_scope_);
FeedInput(); FeedInput();
paddle::framework::interpretercore::build_op_func_list( paddle::framework::interpreter::build_op_func_list(
place_, main_program_, &vec_func_list_, global_scope_); place_, *block_, &vec_func_list_, global_scope_);
is_build_ = true; is_build_ = true;
// convert vec func_list to graph // convert vec func_list to graph
Convert(); Convert();
......
...@@ -40,10 +40,9 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>; ...@@ -40,10 +40,9 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
class InterpreterCore { class InterpreterCore {
public: public:
InterpreterCore(const platform::Place& place, const ProgramDesc& main_prog, InterpreterCore(const platform::Place& place, BlockDesc* block,
VariableScope* global_scope, VariableScope* global_scope,
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names);
const std::vector<std::string>& fetch_names);
paddle::framework::FetchList Run( paddle::framework::FetchList Run(
const std::vector<framework::LoDTensor>& feed_tensors); const std::vector<framework::LoDTensor>& feed_tensors);
...@@ -72,15 +71,14 @@ class InterpreterCore { ...@@ -72,15 +71,14 @@ class InterpreterCore {
void RunInstructionAsync(size_t instr_id); void RunInstructionAsync(size_t instr_id);
void RunNextInstructions(const Instruction& instr_id, void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops); std::queue<size_t>* reserved_next_ops);
void AddFetch(const std::vector<std::string>& fetch_names);
void BuildSkipShareLoDInfo(); void BuildSkipShareLoDInfo();
bool is_build_; bool is_build_;
const platform::Place& place_; const platform::Place& place_;
ProgramDesc main_program_; BlockDesc* block_; // not owned
VariableScope* global_scope_; VariableScope* global_scope_; // not owned
std::vector<paddle::framework::OpFuncNode> vec_func_list_; std::vector<paddle::framework::OpFuncNode> vec_func_list_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
...@@ -88,7 +86,6 @@ class InterpreterCore { ...@@ -88,7 +86,6 @@ class InterpreterCore {
InstructionInfo instruction_info_; InstructionInfo instruction_info_;
std::vector<size_t> dependecy_count_; std::vector<size_t> dependecy_count_;
std::vector<std::vector<size_t>> input_var2op_info_; std::vector<std::vector<size_t>> input_var2op_info_;
std::vector<VariableMetaInfo> ref_coun_info_;
std::vector<VariableMetaInfo> vec_meta_info_; std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<std::string> feed_names_; std::vector<std::string> feed_names_;
...@@ -97,7 +94,7 @@ class InterpreterCore { ...@@ -97,7 +94,7 @@ class InterpreterCore {
StreamAnalyzer stream_analyzer_; StreamAnalyzer stream_analyzer_;
EventManager event_manager_; EventManager event_manager_;
EventsWaiter main_thread_blocker_; EventsWaiter main_thread_blocker_;
interpretercore::AsyncWorkQueue async_work_queue_; interpreter::AsyncWorkQueue async_work_queue_;
details::ExceptionHolder exception_holder_; details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr}; std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpretercore { namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>; using VariableIdMap = std::map<std::string, std::vector<int>>;
AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps( AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps(
...@@ -129,11 +129,9 @@ std::string get_memcpy_type(const platform::Place& src_place, ...@@ -129,11 +129,9 @@ std::string get_memcpy_type(const platform::Place& src_place,
} }
} }
void build_variable_scope(const framework::ProgramDesc& pdesc, void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope) { VariableScope* var_scope) {
auto& global_block = pdesc.Block(0); for (auto& var_desc : block.AllVars()) {
for (auto& var_desc : global_block.AllVars()) {
auto var_name = var_desc->Name(); auto var_name = var_desc->Name();
if (var_name == framework::kEmptyVarName) { if (var_name == framework::kEmptyVarName) {
continue; continue;
...@@ -360,9 +358,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var( ...@@ -360,9 +358,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
std::vector<OpFuncNode> apply_data_transform( std::vector<OpFuncNode> apply_data_transform(
const OpKernelType& expected_kernel_key, const platform::Place& place, const OpKernelType& expected_kernel_key, const platform::Place& place,
VariableValueMap& ins_map_temp, VariableScope* var_scope, VariableValueMap* ins_map_temp, VariableScope* var_scope,
OpFuncNode& op_func_node) { OpFuncNode* op_func_node) {
auto& op_base = op_func_node.operator_base_; auto& op_base = op_func_node->operator_base_;
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base is null, please pass a valid " "op_base is null, please pass a valid "
"op_base in apply_data_transform.")); "op_base in apply_data_transform."));
...@@ -372,7 +370,7 @@ std::vector<OpFuncNode> apply_data_transform( ...@@ -372,7 +370,7 @@ std::vector<OpFuncNode> apply_data_transform(
no_data_transform_index; // record the no need transform variable index. no_data_transform_index; // record the no need transform variable index.
std::vector<OpFuncNode> copy_func_nodes; // return all the copy opfuncnode. std::vector<OpFuncNode> copy_func_nodes; // return all the copy opfuncnode.
for (auto& var_name_item : ins_map_temp) { for (auto& var_name_item : *ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i]; auto var = var_name_item.second[i];
auto& var_name = inputs_names[var_name_item.first].at(i); auto& var_name = inputs_names[var_name_item.first].at(i);
...@@ -394,8 +392,8 @@ std::vector<OpFuncNode> apply_data_transform( ...@@ -394,8 +392,8 @@ std::vector<OpFuncNode> apply_data_transform(
std::tie(new_var_name, copy_op_func_node) = std::tie(new_var_name, copy_op_func_node) =
apply_place_transform_for_var( apply_place_transform_for_var(
kernel_type_for_var, expected_kernel_key, place, var_name, kernel_type_for_var, expected_kernel_key, place, var_name,
var_name_item.first, op_func_node, var, var_scope); var_name_item.first, *op_func_node, var, var_scope);
op_func_node.input_index[var_name_item.first][i] = op_func_node->input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name); var_scope->VarId(new_var_name);
copy_func_nodes.push_back(copy_op_func_node); copy_func_nodes.push_back(copy_op_func_node);
var_name_item.second[i] = var_scope->Var(new_var_name); var_name_item.second[i] = var_scope->Var(new_var_name);
...@@ -414,23 +412,22 @@ std::vector<OpFuncNode> apply_data_transform( ...@@ -414,23 +412,22 @@ std::vector<OpFuncNode> apply_data_transform(
} }
} }
} }
op_func_node.no_data_transform_index = std::move(no_data_transform_index); op_func_node->no_data_transform_index = std::move(no_data_transform_index);
return copy_func_nodes; return copy_func_nodes;
} }
void build_op_func_list(const platform::Place& place, void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc, const framework::BlockDesc& block,
std::vector<OpFuncNode>* vec_func_list, std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) { VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
// Step 1: create all ops for global block. // Step 1: create all ops for current block.
auto ops = create_all_ops(global_block); auto ops = create_all_ops(block);
auto unused_var_map = get_unused_vars(global_block, ops); auto unused_var_map = get_unused_vars(block, ops);
size_t ops_index = 0; size_t ops_index = 0;
for (auto& op : global_block.AllOps()) { for (auto& op : block.AllOps()) {
VLOG(6) << "Build OpFuncNode from : " << op->Type(); VLOG(6) << "Build OpFuncNode from : " << op->Type();
auto op_base = ops[ops_index++]; auto op_base = ops[ops_index++];
...@@ -498,7 +495,7 @@ void build_op_func_list(const platform::Place& place, ...@@ -498,7 +495,7 @@ void build_op_func_list(const platform::Place& place,
// apply_data_transform. // apply_data_transform.
op_func_node.operator_base_ = op_base; op_func_node.operator_base_ = op_base;
copy_op_to_insert = apply_data_transform( copy_op_to_insert = apply_data_transform(
expected_kernel_key, place, ins_map_temp, var_scope, op_func_node); expected_kernel_key, place, &ins_map_temp, var_scope, &op_func_node);
for (auto& item : copy_op_to_insert) { for (auto& item : copy_op_to_insert) {
vec_func_list->push_back(item); vec_func_list->push_back(item);
} }
...@@ -576,6 +573,25 @@ void build_op_func_list(const platform::Place& place, ...@@ -576,6 +573,25 @@ void build_op_func_list(const platform::Place& place,
} }
} }
// Registers the fetch targets on `block`.
//
// First declares the variable named kFetchVarName as a persistable
// FETCH_LIST, then appends one `fetch_v2` op per entry of `fetch_names`. Each
// op takes the named variable as input "X", writes to kFetchVarName through
// output "Out", and stores the entry's position within `fetch_names` in its
// "col" attribute.
//
// `block` is modified in place and is dereferenced without a null check.
void add_fetch(const std::vector<std::string>& fetch_names,
               framework::BlockDesc* block) {
  auto* holder = block->Var(kFetchVarName);
  holder->SetType(proto::VarType::FETCH_LIST);
  holder->SetPersistable(true);

  for (size_t col = 0; col < fetch_names.size(); ++col) {
    // One fetch op per requested variable, in the order given by the caller.
    auto* fetch_op = block->AppendOp();
    fetch_op->SetType("fetch_v2");
    fetch_op->SetInput("X", {fetch_names[col]});
    fetch_op->SetOutput("Out", {kFetchVarName});
    fetch_op->SetAttr("col", {static_cast<int>(col)});
    fetch_op->CheckAttrs();
  }
}
std::vector<size_t> merge_vector(const std::vector<size_t>& first, std::vector<size_t> merge_vector(const std::vector<size_t>& first,
const std::vector<size_t>& second) { const std::vector<size_t>& second) {
std::vector<size_t> out(first.size() + second.size()); std::vector<size_t> out(first.size() + second.size());
...@@ -590,6 +606,6 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first, ...@@ -590,6 +606,6 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
return out; return out;
} }
} // namespace interpretercore } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -48,9 +48,10 @@ ...@@ -48,9 +48,10 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace interpretercore { namespace interpreter {
using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>; using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
static constexpr char kFetchVarName[] = "fetch_vars";
class AsyncWorkQueue { class AsyncWorkQueue {
public: public:
...@@ -96,17 +97,20 @@ class AsyncWorkQueue { ...@@ -96,17 +97,20 @@ class AsyncWorkQueue {
std::string get_memcpy_type(const platform::Place& src_place, std::string get_memcpy_type(const platform::Place& src_place,
const platform::Place& dst_place); const platform::Place& dst_place);
void build_variable_scope(const framework::ProgramDesc& pdesc, void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope); VariableScope* var_scope);
void build_op_func_list(const platform::Place& place, void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc, const framework::BlockDesc& block,
std::vector<OpFuncNode>* vec_func_list, std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope); VariableScope* var_scope);
void add_fetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block);
std::vector<size_t> merge_vector(const std::vector<size_t>& first, std::vector<size_t> merge_vector(const std::vector<size_t>& first,
const std::vector<size_t>& second); const std::vector<size_t>& second);
} // namespace interpretercore } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -776,7 +776,7 @@ class Instruction { ...@@ -776,7 +776,7 @@ class Instruction {
std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_; std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
}; };
namespace interpretercore { namespace interpreter {
static constexpr char kMemcpyH2D[] = "memcpy_h2d"; static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h"; static constexpr char kMemcpyD2H[] = "memcpy_d2h";
...@@ -787,7 +787,7 @@ static bool IsMemcpyH2D(const Instruction& instr) { ...@@ -787,7 +787,7 @@ static bool IsMemcpyH2D(const Instruction& instr) {
static bool IsMemcpyD2H(const Instruction& instr) { static bool IsMemcpyD2H(const Instruction& instr) {
return instr.OpBase()->Type() == kMemcpyD2H; return instr.OpBase()->Type() == kMemcpyD2H;
} }
} // namespace interpretercore } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -41,8 +41,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -41,8 +41,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
// run startup program // run startup program
std::vector<paddle::framework::OpFuncNode> vec_func_list; std::vector<paddle::framework::OpFuncNode> vec_func_list;
paddle::framework::interpretercore::build_op_func_list( paddle::framework::interpreter::build_op_func_list(
place_, startup_prog, &vec_func_list, &global_scope_); place_, startup_prog.Block(0), &vec_func_list, &global_scope_);
} }
paddle::framework::FetchList StandaloneExecutor::Run( paddle::framework::FetchList StandaloneExecutor::Run(
...@@ -96,8 +96,15 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore( ...@@ -96,8 +96,15 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
if (iter == interpretercores_.end()) { if (iter == interpretercores_.end()) {
VLOG(3) << "create interpreter_core for " << oss.str(); VLOG(3) << "create interpreter_core for " << oss.str();
auto core = std::make_shared<InterpreterCore>( // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a
place_, main_prog_, &global_scope_, feed_names, fetch_names); // new program.
auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog_);
auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block);
auto core = std::make_shared<InterpreterCore>(place_, block, &global_scope_,
feed_names);
programs_.emplace(oss.str(), new_prog);
interpretercores_.emplace(oss.str(), core); interpretercores_.emplace(oss.str(), core);
return core; return core;
} else { } else {
......
...@@ -62,6 +62,7 @@ class StandaloneExecutor : public ExecutorBase { ...@@ -62,6 +62,7 @@ class StandaloneExecutor : public ExecutorBase {
Scope* outer_scope_; Scope* outer_scope_;
VariableScope global_scope_; VariableScope global_scope_;
std::unordered_map<std::string, std::shared_ptr<ProgramDesc>> programs_;
std::unordered_map<std::string, std::shared_ptr<InterpreterCore>> std::unordered_map<std::string, std::shared_ptr<InterpreterCore>>
interpretercores_; interpretercores_;
}; };
......
...@@ -101,10 +101,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -101,10 +101,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node) { const OpFuncNode& op_func_node) {
auto& op_type = op_func_node.operator_base_->Type(); auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_; auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpretercore::kMemcpyH2D) { if (op_type == interpreter::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_"; VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctx_pool_.Get(place_); dev_ctx = d2h_ctx_pool_.Get(place_);
} else if (op_type == interpretercore::kMemcpyD2H) { } else if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_"; VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctx_pool_.Get(place_); dev_ctx = h2d_ctx_pool_.Get(place_);
} }
...@@ -122,8 +122,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -122,8 +122,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) { const Instruction& next_instr) {
return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() || return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
interpretercore::IsMemcpyD2H(cur_instr) || interpreter::IsMemcpyD2H(cur_instr) ||
interpretercore::IsMemcpyH2D(next_instr)); interpreter::IsMemcpyH2D(next_instr));
} }
platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册