From e82c3a5f6da3348845a65670d412d5607c7b9c14 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 21 Oct 2021 10:10:49 +0800 Subject: [PATCH] Support No DataTransform From GetKernelTypeForVar (#36571) * Add kQueueSync.synchronize_run_ logic * Support No DataTransform From GetKernelTypeForVar --- .../fluid/framework/new_executor/interpretercore.cc | 2 ++ .../framework/new_executor/interpretercore_util.cc | 12 ++++++++++-- .../fluid/framework/new_executor/new_executor_defs.h | 3 +++ .../fluid/framework/new_executor/stream_analyzer.cc | 3 ++- 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index f6157367cd4..b26d213ddf7 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -118,6 +118,8 @@ void InterpreterCore::Convert() { temp_inst.input_index_ = vec_func_list_[i].input_index; temp_inst.output_index_ = vec_func_list_[i].output_index; temp_inst.type_ = vec_func_list_[i].type_; + temp_inst.no_data_transform_index_ = + vec_func_list_[i].no_data_transform_index; OpInOutInfo info; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 3438fc3bd4d..7bb0429c622 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -278,6 +278,7 @@ void build_op_func_list(const platform::Place& place, // step 3. Insert memcpy_op if needed VariableValueMap& ins_map_temp = runtime_context.inputs; + std::unordered_set no_data_transform_index; for (auto& var_name_item : ins_map_temp) { for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto var = var_name_item.second[i]; @@ -289,8 +290,14 @@ void build_op_func_list(const platform::Place& place, static_cast(op_base) ->GetKernelTypeForVar(var_name_item.first, *tensor_in, expected_kernel_key); - if (!platform::is_same_place(kernel_type_for_var.place_, - expected_kernel_key.place_)) { + if (platform::is_same_place(kernel_type_for_var.place_, + expected_kernel_key.place_)) { + // record no need data transformer input var_id + auto& var_name = inputs_names[var_name_item.first][i]; + VLOG(3) << op->Type() << " found no data_transform var: " << var_name + << " with id: " << var_scope->name2id[var_name]; + no_data_transform_index.emplace(var_scope->name2id[var_name]); + } else { if (op_base->Type() == "fetch_v2") { op_base->SetAttr("deepcopy", false); } @@ -385,6 +392,7 @@ void build_op_func_list(const platform::Place& place, } } } + op_func_node.no_data_transform_index = std::move(no_data_transform_index); // step 4. Run op kernel op_list->push_back(op_base); VLOG(3) << op_base->Type() diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 19b7b6d5dc2..e6cff353a65 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -511,6 +511,8 @@ struct Instruction { std::map> input_index_; std::map> output_index_; + std::unordered_set no_data_transform_index_; + std::vector gc_check_var_list; NextInstruction next_instruction_; @@ -527,6 +529,7 @@ struct OpFuncNode { // int unsed; std::map> input_index; std::map> output_index; + std::unordered_set no_data_transform_index; OpKernelComputeFunc kernel_func_; platform::DeviceContext* dev_ctx_; // not owned diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index a9322d8fc88..ffc2da499e1 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -38,7 +38,8 @@ std::vector StreamAnalyzer::ParseEventVarIds( std::vector new_event_var_ids; for (auto& item : next_instr.input_index_) { for (auto var_id : item.second) { - if (unique_var_ids.count(var_id) > 0) { + if (unique_var_ids.count(var_id) > 0 && + next_instr.no_data_transform_index_.count(var_id) == 0) { new_event_var_ids.push_back(var_id); } } -- GitLab