From 3b0d8a7bca61a8d38e369d4a3eb96ccab98566e9 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 24 Aug 2021 15:17:41 +0800 Subject: [PATCH] cache runtime ctx for executor, test=develop (#35108) --- .../framework/new_executor/interpretercore.cc | 44 +++++++++++-------- .../framework/new_executor/interpretercore.h | 8 ++-- .../new_executor/new_executor_defs.h | 3 ++ .../new_executor/standalone_executor_test.cc | 37 ++++++++++++++++ 4 files changed, 71 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 7f6091742f0..ffcb1b9f3dd 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -153,15 +153,19 @@ void InterpreterCore::Convert() { dependecy_count_[inst_id]++; } } + + for (size_t i = 0; i < vec_instruction_.size(); ++i) { + BuildInstructionCtx(&vec_instruction_[i], *global_scope_, place_); + } } -void InterpreterCore::RunInstruction(const Instruction& instr_node, - const VariableScope& var_scope, - const platform::Place& place) { - auto op_base = instr_node.kernel_func_.operator_base_; - // build runtime cost +void InterpreterCore::BuildInstructionCtx(Instruction* instr_node, + const VariableScope& var_scope, + const platform::Place& place) { + auto op_base = instr_node->kernel_func_.operator_base_; + VariableValueMap ins_map; - for (auto& var_name_item : instr_node.input_index_) { + for (auto& var_name_item : instr_node->input_index_) { std::vector input_vars; input_vars.reserve(var_name_item.second.size()); @@ -172,7 +176,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node, } VariableValueMap outs_map; - for (auto& var_name_item : instr_node.output_index_) { + for (auto& var_name_item : instr_node->output_index_) { std::vector out_vars; out_vars.reserve(var_name_item.second.size()); @@ -182,23 +186,27 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node, outs_map.emplace(var_name_item.first, std::move(out_vars)); } - RuntimeContext runtime_context({}, {}); - runtime_context.inputs.swap(ins_map); - runtime_context.outputs.swap(outs_map); + instr_node->runtime_ctx_.reset(new RuntimeContext({}, {})); + instr_node->runtime_ctx_->inputs.swap(ins_map); + instr_node->runtime_ctx_->outputs.swap(outs_map); - RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context); - - static_cast(op_base)->InferShape( - &infer_shape_ctx); + instr_node->infershape_ctx_.reset( + new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get())); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); Scope scope; - auto exec_context = - ExecutionContext(*op_base, scope, *dev_ctx, runtime_context); + instr_node->execution_ctx_.reset(new ExecutionContext( + *op_base, scope, *dev_ctx, *instr_node->runtime_ctx_.get())); +} + +void InterpreterCore::RunInstruction(const Instruction& instr_node) { + static_cast( + instr_node.kernel_func_.operator_base_) + ->InferShape(instr_node.infershape_ctx_.get()); - instr_node.kernel_func_.compute_func_(exec_context); + instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get()); } void InterpreterCore::ExecuteInstructionList( @@ -219,7 +227,7 @@ void InterpreterCore::ExecuteInstructionList( auto instr_id = working_queue.front(); working_queue.pop(); auto& instr_node = vec_instr[instr_id]; - RunInstruction(instr_node, var_scope, place); + RunInstruction(instr_node); auto& next_instr = instr_node.next_instruction_.direct_run_; ++run_op_number; diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 4d3369c8947..c102916e92b 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -47,9 +47,11 @@ class InterpreterCore { private: void Convert(); - void RunInstruction(const Instruction& instr_node, - const VariableScope& var_scope, - const platform::Place& place); + void BuildInstructionCtx(Instruction* instr_node, + const VariableScope& var_scope, + const platform::Place& place); + + void RunInstruction(const Instruction& instr_node); void ExecuteInstructionList(const std::vector& vec_instr, const VariableScope& var_scope, diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index fb8a96aaca4..33ea943fdb2 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -59,6 +59,9 @@ struct EventRun { struct Instruction { OpKernelFunc kernel_func_; + std::shared_ptr runtime_ctx_; + std::shared_ptr infershape_ctx_; + std::shared_ptr execution_ctx_; std::map> input_index_; std::map> output_index_; diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index 9e831147903..eff505d164a 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -23,6 +23,43 @@ #include "paddle/fluid/framework/new_executor/standalone_executor.h" +USE_OP(fill_constant); +USE_OP(uniform_random); +USE_OP(lookup_table); +USE_OP(transpose2); +USE_OP(reshape2); +USE_OP(split); +USE_OP(slice); +USE_OP(concat); +USE_OP(matmul); +USE_OP(elementwise_add); +USE_OP(sigmoid); +USE_OP(tanh); +USE_OP(elementwise_mul); +USE_OP(softmax_with_cross_entropy); +USE_OP(reduce_mean); +USE_OP(reduce_sum); +USE_OP(reduce_sum_grad); +USE_OP(reduce_mean_grad); +USE_OP(reshape2_grad); +USE_OP(softmax_with_cross_entropy_grad); +USE_OP(elementwise_add_grad); +USE_OP(matmul_grad); +USE_OP(square); +USE_OP(transpose2_grad); +USE_OP(concat_grad); +USE_OP(elementwise_mul_grad); +USE_OP(sigmoid_grad); +USE_OP(tanh_grad); +USE_OP(sum); +USE_OP(slice_grad); +USE_OP(lookup_table_grad); +USE_OP(sqrt); +USE_OP(elementwise_max); +USE_OP(elementwise_div); +USE_OP(sgd); +USE_OP(squared_l2_norm); + paddle::framework::ProgramDesc load_from_file(const std::string& file_name) { std::ifstream fin(file_name, std::ios::in | std::ios::binary); fin.seekg(0, std::ios::end); -- GitLab