未验证 提交 e82c3a5f 编写于 作者: A Aurelius84 提交者: GitHub

Support No DataTransform From GetKernelTypeForVar (#36571)

* Add kQueueSync.synchronize_run_ logic

* Support No DataTransform From GetKernelTypeForVar
上级 ded3e705
...@@ -118,6 +118,8 @@ void InterpreterCore::Convert() { ...@@ -118,6 +118,8 @@ void InterpreterCore::Convert() {
temp_inst.input_index_ = vec_func_list_[i].input_index; temp_inst.input_index_ = vec_func_list_[i].input_index;
temp_inst.output_index_ = vec_func_list_[i].output_index; temp_inst.output_index_ = vec_func_list_[i].output_index;
temp_inst.type_ = vec_func_list_[i].type_; temp_inst.type_ = vec_func_list_[i].type_;
temp_inst.no_data_transform_index_ =
vec_func_list_[i].no_data_transform_index;
OpInOutInfo info; OpInOutInfo info;
......
...@@ -278,6 +278,7 @@ void build_op_func_list(const platform::Place& place, ...@@ -278,6 +278,7 @@ void build_op_func_list(const platform::Place& place,
// step 3. Insert memcpy_op if needed // step 3. Insert memcpy_op if needed
VariableValueMap& ins_map_temp = runtime_context.inputs; VariableValueMap& ins_map_temp = runtime_context.inputs;
std::unordered_set<int> no_data_transform_index;
for (auto& var_name_item : ins_map_temp) { for (auto& var_name_item : ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i]; auto var = var_name_item.second[i];
...@@ -289,8 +290,14 @@ void build_op_func_list(const platform::Place& place, ...@@ -289,8 +290,14 @@ void build_op_func_list(const platform::Place& place,
static_cast<const framework::OperatorWithKernel*>(op_base) static_cast<const framework::OperatorWithKernel*>(op_base)
->GetKernelTypeForVar(var_name_item.first, *tensor_in, ->GetKernelTypeForVar(var_name_item.first, *tensor_in,
expected_kernel_key); expected_kernel_key);
if (!platform::is_same_place(kernel_type_for_var.place_, if (platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_key.place_)) { expected_kernel_key.place_)) {
// record no need data transformer input var_id
auto& var_name = inputs_names[var_name_item.first][i];
VLOG(3) << op->Type() << " found no data_transform var: " << var_name
<< " with id: " << var_scope->name2id[var_name];
no_data_transform_index.emplace(var_scope->name2id[var_name]);
} else {
if (op_base->Type() == "fetch_v2") { if (op_base->Type() == "fetch_v2") {
op_base->SetAttr("deepcopy", false); op_base->SetAttr("deepcopy", false);
} }
...@@ -385,6 +392,7 @@ void build_op_func_list(const platform::Place& place, ...@@ -385,6 +392,7 @@ void build_op_func_list(const platform::Place& place,
} }
} }
} }
op_func_node.no_data_transform_index = std::move(no_data_transform_index);
// step 4. Run op kernel // step 4. Run op kernel
op_list->push_back(op_base); op_list->push_back(op_base);
VLOG(3) << op_base->Type() VLOG(3) << op_base->Type()
......
...@@ -511,6 +511,8 @@ struct Instruction { ...@@ -511,6 +511,8 @@ struct Instruction {
std::map<std::string, std::vector<int>> input_index_; std::map<std::string, std::vector<int>> input_index_;
std::map<std::string, std::vector<int>> output_index_; std::map<std::string, std::vector<int>> output_index_;
std::unordered_set<int> no_data_transform_index_;
std::vector<size_t> gc_check_var_list; std::vector<size_t> gc_check_var_list;
NextInstruction next_instruction_; NextInstruction next_instruction_;
...@@ -527,6 +529,7 @@ struct OpFuncNode { ...@@ -527,6 +529,7 @@ struct OpFuncNode {
// int unsed; // int unsed;
std::map<std::string, std::vector<int>> input_index; std::map<std::string, std::vector<int>> input_index;
std::map<std::string, std::vector<int>> output_index; std::map<std::string, std::vector<int>> output_index;
std::unordered_set<int> no_data_transform_index;
OpKernelComputeFunc kernel_func_; OpKernelComputeFunc kernel_func_;
platform::DeviceContext* dev_ctx_; // not owned platform::DeviceContext* dev_ctx_; // not owned
......
...@@ -38,7 +38,8 @@ std::vector<size_t> StreamAnalyzer::ParseEventVarIds( ...@@ -38,7 +38,8 @@ std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
std::vector<size_t> new_event_var_ids; std::vector<size_t> new_event_var_ids;
for (auto& item : next_instr.input_index_) { for (auto& item : next_instr.input_index_) {
for (auto var_id : item.second) { for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0) { if (unique_var_ids.count(var_id) > 0 &&
next_instr.no_data_transform_index_.count(var_id) == 0) {
new_event_var_ids.push_back(var_id); new_event_var_ids.push_back(var_id);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册