Unverified commit 3b0d8a7b, authored by wanghuancoder, committed by GitHub

cache runtime ctx for executor, test=develop (#35108)

Parent cb28753c
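In short, this change caches the per-instruction RuntimeContext, RuntimeInferShapeContext, and ExecutionContext on the Instruction itself: Convert() now calls BuildInstructionCtx() once per instruction, and RunInstruction() reuses the cached objects instead of rebuilding them on every step. A minimal sketch of the pattern, with hypothetical names (Context, BuildCtx, Run) rather than the actual Paddle classes:

// Sketch only: hypothetical names, not Paddle's API.
#include <memory>
#include <vector>

struct Context {};  // stands in for RuntimeContext / ExecutionContext etc.

struct Instruction {
  std::shared_ptr<Context> ctx_;  // cached across runs
};

// Expensive setup (name -> variable resolution), done once per instruction.
void BuildCtx(Instruction* instr) { instr->ctx_ = std::make_shared<Context>(); }

// Hot path: reuses the cached context instead of rebuilding it each step.
void Run(const Instruction& instr) { (void)instr; /* kernel(instr.ctx_) */ }

int main() {
  std::vector<Instruction> program(3);
  for (auto& instr : program) BuildCtx(&instr);  // like Convert()
  for (int step = 0; step < 100; ++step) {
    for (const auto& instr : program) Run(instr);  // like RunInstruction()
  }
  return 0;
}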
@@ -153,15 +153,19 @@ void InterpreterCore::Convert() {
       dependecy_count_[inst_id]++;
     }
   }
+
+  for (size_t i = 0; i < vec_instruction_.size(); ++i) {
+    BuildInstructionCtx(&vec_instruction_[i], *global_scope_, place_);
+  }
 }
-void InterpreterCore::RunInstruction(const Instruction& instr_node,
-                                     const VariableScope& var_scope,
-                                     const platform::Place& place) {
-  auto op_base = instr_node.kernel_func_.operator_base_;
+// build runtime ctx
+void InterpreterCore::BuildInstructionCtx(Instruction* instr_node,
+                                          const VariableScope& var_scope,
+                                          const platform::Place& place) {
+  auto op_base = instr_node->kernel_func_.operator_base_;
   VariableValueMap ins_map;
-  for (auto& var_name_item : instr_node.input_index_) {
+  for (auto& var_name_item : instr_node->input_index_) {
     std::vector<Variable*> input_vars;
     input_vars.reserve(var_name_item.second.size());
@@ -172,7 +176,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node,
   }
   VariableValueMap outs_map;
-  for (auto& var_name_item : instr_node.output_index_) {
+  for (auto& var_name_item : instr_node->output_index_) {
     std::vector<Variable*> out_vars;
     out_vars.reserve(var_name_item.second.size());
@@ -182,23 +186,27 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node,
     outs_map.emplace(var_name_item.first, std::move(out_vars));
   }
-  RuntimeContext runtime_context({}, {});
-  runtime_context.inputs.swap(ins_map);
-  runtime_context.outputs.swap(outs_map);
+  instr_node->runtime_ctx_.reset(new RuntimeContext({}, {}));
+  instr_node->runtime_ctx_->inputs.swap(ins_map);
+  instr_node->runtime_ctx_->outputs.swap(outs_map);
-  RuntimeInferShapeContext infer_shape_ctx(*op_base, runtime_context);
-  static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
-      &infer_shape_ctx);
+  instr_node->infershape_ctx_.reset(
+      new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get()));
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
   Scope scope;
-  auto exec_context =
-      ExecutionContext(*op_base, scope, *dev_ctx, runtime_context);
+  instr_node->execution_ctx_.reset(new ExecutionContext(
+      *op_base, scope, *dev_ctx, *instr_node->runtime_ctx_.get()));
+}
+void InterpreterCore::RunInstruction(const Instruction& instr_node) {
+  static_cast<const framework::OperatorWithKernel*>(
+      instr_node.kernel_func_.operator_base_)
+      ->InferShape(instr_node.infershape_ctx_.get());
-  instr_node.kernel_func_.compute_func_(exec_context);
+  instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
 }
 void InterpreterCore::ExecuteInstructionList(
@@ -219,7 +227,7 @@ void InterpreterCore::ExecuteInstructionList(
     auto instr_id = working_queue.front();
     working_queue.pop();
     auto& instr_node = vec_instr[instr_id];
-    RunInstruction(instr_node, var_scope, place);
+    RunInstruction(instr_node);
     auto& next_instr = instr_node.next_instruction_.direct_run_;
     ++run_op_number;
......
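Note the resulting division of labor in this file: BuildInstructionCtx does the name-to-Variable* resolution and allocates the three context objects once, while RunInstruction still invokes InferShape on every execution, only now against the cached RuntimeInferShapeContext rather than a freshly constructed one. Only InferShape and the kernel compute call remain on the per-step path.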
@@ -47,9 +47,11 @@ class InterpreterCore {
  private:
   void Convert();
-  void RunInstruction(const Instruction& instr_node,
-                      const VariableScope& var_scope,
-                      const platform::Place& place);
+  void BuildInstructionCtx(Instruction* instr_node,
+                           const VariableScope& var_scope,
+                           const platform::Place& place);
+  void RunInstruction(const Instruction& instr_node);
   void ExecuteInstructionList(const std::vector<Instruction>& vec_instr,
                               const VariableScope& var_scope,
......
@@ -59,6 +59,9 @@ struct EventRun {
 struct Instruction {
   OpKernelFunc kernel_func_;
+  std::shared_ptr<RuntimeContext> runtime_ctx_;
+  std::shared_ptr<RuntimeInferShapeContext> infershape_ctx_;
+  std::shared_ptr<ExecutionContext> execution_ctx_;
   std::map<std::string, std::vector<int>> input_index_;
   std::map<std::string, std::vector<int>> output_index_;
......
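The three new members are shared_ptr rather than unique_ptr, plausibly so that Instruction stays copyable while being stored by value in std::vector<Instruction>; copies then share one set of cached contexts. A self-contained sketch of that property (hypothetical struct, not the Paddle one):

#include <cassert>
#include <memory>
#include <vector>

struct Ctx { int steps = 0; };

struct Instruction {
  std::shared_ptr<Ctx> runtime_ctx_;  // shared, so Instruction stays copyable
};

int main() {
  std::vector<Instruction> v(1);
  v[0].runtime_ctx_ = std::make_shared<Ctx>();
  Instruction copy = v[0];       // copying an instruction is cheap...
  copy.runtime_ctx_->steps = 7;  // ...and shares the cached context
  assert(v[0].runtime_ctx_->steps == 7);
  return 0;
}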
@@ -23,6 +23,43 @@
 #include "paddle/fluid/framework/new_executor/standalone_executor.h"
 USE_OP(fill_constant);
 USE_OP(uniform_random);
+USE_OP(lookup_table);
+USE_OP(transpose2);
+USE_OP(reshape2);
+USE_OP(split);
+USE_OP(slice);
+USE_OP(concat);
+USE_OP(matmul);
+USE_OP(elementwise_add);
+USE_OP(sigmoid);
+USE_OP(tanh);
+USE_OP(elementwise_mul);
+USE_OP(softmax_with_cross_entropy);
+USE_OP(reduce_mean);
+USE_OP(reduce_sum);
+USE_OP(reduce_sum_grad);
+USE_OP(reduce_mean_grad);
+USE_OP(reshape2_grad);
+USE_OP(softmax_with_cross_entropy_grad);
+USE_OP(elementwise_add_grad);
+USE_OP(matmul_grad);
+USE_OP(square);
+USE_OP(transpose2_grad);
+USE_OP(concat_grad);
+USE_OP(elementwise_mul_grad);
+USE_OP(sigmoid_grad);
+USE_OP(tanh_grad);
+USE_OP(sum);
+USE_OP(slice_grad);
+USE_OP(lookup_table_grad);
+USE_OP(sqrt);
+USE_OP(elementwise_max);
+USE_OP(elementwise_div);
+USE_OP(sgd);
+USE_OP(squared_l2_norm);
 paddle::framework::ProgramDesc load_from_file(const std::string& file_name) {
   std::ifstream fin(file_name, std::ios::in | std::ios::binary);
   fin.seekg(0, std::ios::end);
......
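For context: the USE_OP macros force the listed operators' registrations to be linked into the test binary, so the executor can find their kernels at run time. The truncated load_from_file most likely reads the whole file into a buffer and deserializes it; a hedged sketch of that common pattern follows (an assumption about the elided lines, not the verbatim test code):

// Assumed continuation of load_from_file (sketch, not verbatim):
#include <fstream>
#include <string>
#include "paddle/fluid/framework/program_desc.h"

paddle::framework::ProgramDesc load_from_file(const std::string& file_name) {
  std::ifstream fin(file_name, std::ios::in | std::ios::binary);
  fin.seekg(0, std::ios::end);           // find the file size
  std::string buffer(fin.tellg(), ' ');  // allocate a buffer that large
  fin.seekg(0, std::ios::beg);
  fin.read(&buffer[0], buffer.size());   // slurp the serialized program
  fin.close();
  // ProgramDesc can be constructed from a serialized proto string.
  return paddle::framework::ProgramDesc(buffer);
}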