未验证 提交 3b0d8a7b 编写于 作者: W wanghuancoder 提交者: GitHub

cache runtime ctx for executor, test=develop (#35108)

上级 cb28753c
...@@ -153,15 +153,19 @@ void InterpreterCore::Convert() { ...@@ -153,15 +153,19 @@ void InterpreterCore::Convert() {
dependecy_count_[inst_id]++; dependecy_count_[inst_id]++;
} }
} }
for (size_t i = 0; i < vec_instruction_.size(); ++i) {
BuildInstructionCtx(&vec_instruction_[i], *global_scope_, place_);
}
} }
void InterpreterCore::RunInstruction(const Instruction& instr_node, void InterpreterCore::BuildInstructionCtx(Instruction* instr_node,
const VariableScope& var_scope, const VariableScope& var_scope,
const platform::Place& place) { const platform::Place& place) {
auto op_base = instr_node.kernel_func_.operator_base_; auto op_base = instr_node->kernel_func_.operator_base_;
// build runtime cost
VariableValueMap ins_map; VariableValueMap ins_map;
for (auto& var_name_item : instr_node.input_index_) { for (auto& var_name_item : instr_node->input_index_) {
std::vector<Variable*> input_vars; std::vector<Variable*> input_vars;
input_vars.reserve(var_name_item.second.size()); input_vars.reserve(var_name_item.second.size());
...@@ -172,7 +176,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node, ...@@ -172,7 +176,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node,
} }
VariableValueMap outs_map; VariableValueMap outs_map;
for (auto& var_name_item : instr_node.output_index_) { for (auto& var_name_item : instr_node->output_index_) {
std::vector<Variable*> out_vars; std::vector<Variable*> out_vars;
out_vars.reserve(var_name_item.second.size()); out_vars.reserve(var_name_item.second.size());
...@@ -182,23 +186,27 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node, ...@@ -182,23 +186,27 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node,
outs_map.emplace(var_name_item.first, std::move(out_vars)); outs_map.emplace(var_name_item.first, std::move(out_vars));
} }
RuntimeContext runtime_context({}, {}); instr_node->runtime_ctx_.reset(new RuntimeContext({}, {}));
runtime_context.inputs.swap(ins_map); instr_node->runtime_ctx_->inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map); instr_node->runtime_ctx_->outputs.swap(outs_map);
RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context); instr_node->infershape_ctx_.reset(
new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get()));
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
Scope scope; Scope scope;
auto exec_context = instr_node->execution_ctx_.reset(new ExecutionContext(
ExecutionContext(*op_base, scope, *dev_ctx, runtime_context); *op_base, scope, *dev_ctx, *instr_node->runtime_ctx_.get()));
}
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
static_cast<const framework::OperatorWithKernel*>(
instr_node.kernel_func_.operator_base_)
->InferShape(instr_node.infershape_ctx_.get());
instr_node.kernel_func_.compute_func_(exec_context); instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
} }
void InterpreterCore::ExecuteInstructionList( void InterpreterCore::ExecuteInstructionList(
...@@ -219,7 +227,7 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -219,7 +227,7 @@ void InterpreterCore::ExecuteInstructionList(
auto instr_id = working_queue.front(); auto instr_id = working_queue.front();
working_queue.pop(); working_queue.pop();
auto& instr_node = vec_instr[instr_id]; auto& instr_node = vec_instr[instr_id];
RunInstruction(instr_node, var_scope, place); RunInstruction(instr_node);
auto& next_instr = instr_node.next_instruction_.direct_run_; auto& next_instr = instr_node.next_instruction_.direct_run_;
++run_op_number; ++run_op_number;
......
...@@ -47,9 +47,11 @@ class InterpreterCore { ...@@ -47,9 +47,11 @@ class InterpreterCore {
private: private:
void Convert(); void Convert();
void RunInstruction(const Instruction& instr_node, void BuildInstructionCtx(Instruction* instr_node,
const VariableScope& var_scope, const VariableScope& var_scope,
const platform::Place& place); const platform::Place& place);
void RunInstruction(const Instruction& instr_node);
void ExecuteInstructionList(const std::vector<Instruction>& vec_instr, void ExecuteInstructionList(const std::vector<Instruction>& vec_instr,
const VariableScope& var_scope, const VariableScope& var_scope,
......
...@@ -59,6 +59,9 @@ struct EventRun { ...@@ -59,6 +59,9 @@ struct EventRun {
struct Instruction { struct Instruction {
OpKernelFunc kernel_func_; OpKernelFunc kernel_func_;
std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<RuntimeInferShapeContext> infershape_ctx_;
std::shared_ptr<ExecutionContext> execution_ctx_;
std::map<std::string, std::vector<int>> input_index_; std::map<std::string, std::vector<int>> input_index_;
std::map<std::string, std::vector<int>> output_index_; std::map<std::string, std::vector<int>> output_index_;
......
...@@ -23,6 +23,43 @@ ...@@ -23,6 +23,43 @@
#include "paddle/fluid/framework/new_executor/standalone_executor.h" #include "paddle/fluid/framework/new_executor/standalone_executor.h"
// Explicitly pull in the operator registrations needed by this standalone
// executor test binary. USE_OP forces the linker to keep each operator's
// registration object, so the ops are available at runtime even though the
// test links them statically. The list covers the forward ops of the test
// model (fill_constant, uniform_random, lookup_table, matmul, ...), their
// corresponding *_grad backward ops, and the optimizer/regularization ops
// (sgd, squared_l2_norm, elementwise_max, sqrt, elementwise_div).
USE_OP(fill_constant);
USE_OP(uniform_random);
USE_OP(lookup_table);
USE_OP(transpose2);
USE_OP(reshape2);
USE_OP(split);
USE_OP(slice);
USE_OP(concat);
USE_OP(matmul);
USE_OP(elementwise_add);
USE_OP(sigmoid);
USE_OP(tanh);
USE_OP(elementwise_mul);
USE_OP(softmax_with_cross_entropy);
USE_OP(reduce_mean);
USE_OP(reduce_sum);
USE_OP(reduce_sum_grad);
USE_OP(reduce_mean_grad);
USE_OP(reshape2_grad);
USE_OP(softmax_with_cross_entropy_grad);
USE_OP(elementwise_add_grad);
USE_OP(matmul_grad);
USE_OP(square);
USE_OP(transpose2_grad);
USE_OP(concat_grad);
USE_OP(elementwise_mul_grad);
USE_OP(sigmoid_grad);
USE_OP(tanh_grad);
USE_OP(sum);
USE_OP(slice_grad);
USE_OP(lookup_table_grad);
USE_OP(sqrt);
USE_OP(elementwise_max);
USE_OP(elementwise_div);
USE_OP(sgd);
USE_OP(squared_l2_norm);
paddle::framework::ProgramDesc load_from_file(const std::string& file_name) { paddle::framework::ProgramDesc load_from_file(const std::string& file_name) {
std::ifstream fin(file_name, std::ios::in | std::ios::binary); std::ifstream fin(file_name, std::ios::in | std::ios::binary);
fin.seekg(0, std::ios::end); fin.seekg(0, std::ios::end);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册