Unverified · Commit a6e99dc7 authored by Aurelius84, committed by GitHub

Refactor InterpreterCore and modify it to work on BlockDesc (#37056)

Parent 993ec76a
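For orientation, a minimal usage sketch (not part of this commit; assembled only from the signatures that appear in the diff below): the caller copies the ProgramDesc, appends fetch ops into block 0 with interpreter::add_fetch, and hands that BlockDesc to InterpreterCore instead of a whole program.

// Sketch under the assumptions above -- it mirrors StandaloneExecutor::GetInterpreterCore
// and the new constructor further down in this diff.
auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog);  // copy, since add_fetch mutates the block
auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block);                           // append fetch_v2 ops
InterpreterCore core(place, block, &global_scope, feed_names);        // takes BlockDesc*, not ProgramDesc
framework::FetchList fetch_list = core.Run(feed_tensors);             // results land in "fetch_vars"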
......@@ -33,13 +33,11 @@ namespace framework {
// NOTE(Aurelius84): Need a better strategy to determine it.
static constexpr size_t kHostNumThreads = 4;
InterpreterCore::InterpreterCore(const platform::Place& place,
const ProgramDesc& main_prog,
InterpreterCore::InterpreterCore(const platform::Place& place, BlockDesc* block,
VariableScope* global_scope,
const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names)
const std::vector<std::string>& feed_names)
: place_(place),
main_program_(main_prog),
block_(block),
global_scope_(global_scope),
stream_analyzer_(place),
async_work_queue_(kHostNumThreads, &main_thread_blocker_) {
......@@ -50,9 +48,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
exception_notifier_ = main_thread_blocker_.RegisterEvent(
kExceptionCaught, [this]() { return exception_holder_.IsCaught(); });
// Step1: add feedop and fetchop to main_program
AddFetch(fetch_names);
// prune
// optmize graph pass
......@@ -60,24 +55,6 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
// convert to run graph
}
void InterpreterCore::AddFetch(const std::vector<std::string>& fetch_names) {
auto* fetch_holder = main_program_.MutableBlock(0)->Var("fetch_vars");
fetch_holder->SetType(proto::VarType::FETCH_LIST);
fetch_holder->SetPersistable(true);
int i = 0;
for (auto& fetch_name : fetch_names) {
// append fetch op
auto* op = main_program_.MutableBlock(0)->AppendOp();
op->SetType("fetch_v2");
op->SetInput("X", {fetch_name});
op->SetOutput("Out", {"fetch_vars"});
op->SetAttr("col", {static_cast<int>(i)});
op->CheckAttrs();
i++;
}
}
paddle::framework::FetchList InterpreterCore::Run(
const std::vector<framework::LoDTensor>& feed_tensors) {
auto FeedInput = [&] {
......@@ -90,11 +67,11 @@ paddle::framework::FetchList InterpreterCore::Run(
};
if (is_build_ == false) {
paddle::framework::interpretercore::build_variable_scope(main_program_,
paddle::framework::interpreter::build_variable_scope(*block_,
global_scope_);
FeedInput();
paddle::framework::interpretercore::build_op_func_list(
place_, main_program_, &vec_func_list_, global_scope_);
paddle::framework::interpreter::build_op_func_list(
place_, *block_, &vec_func_list_, global_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert();
......@@ -104,7 +81,7 @@ paddle::framework::FetchList InterpreterCore::Run(
}
// return Fetch Tensors
auto* fetch_var = global_scope_->Var("fetch_vars");
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return *(fetch_var->GetMutable<framework::FetchList>());
}
......@@ -172,8 +149,7 @@ void InterpreterCore::Convert() {
std::vector<size_t> vec_temp;
for (auto& item : vec_instruction_[i].Outputs()) {
for (auto id : item.second) {
vec_temp =
interpretercore::merge_vector(vec_temp, input_var2op_info_[id]);
vec_temp = interpreter::merge_vector(vec_temp, input_var2op_info_[id]);
}
}
......@@ -438,8 +414,8 @@ void InterpreterCore::RunNextInstructions(
[&, next_id] { RunInstructionAsync(next_id); });
}
}
auto direct_run_ops = interpretercore::merge_vector(
next_instr.SyncRunIds(), next_instr.DirectRunIds());
auto direct_run_ops = interpreter::merge_vector(next_instr.SyncRunIds(),
next_instr.DirectRunIds());
size_t first_op = 0;
for (auto next_id : direct_run_ops) {
if (IsReady(next_id)) {
......@@ -538,11 +514,11 @@ void InterpreterCore::DryRunPrepare(
};
if (is_build_ == false) {
paddle::framework::interpretercore::build_variable_scope(main_program_,
paddle::framework::interpreter::build_variable_scope(*block_,
global_scope_);
FeedInput();
paddle::framework::interpretercore::build_op_func_list(
place_, main_program_, &vec_func_list_, global_scope_);
paddle::framework::interpreter::build_op_func_list(
place_, *block_, &vec_func_list_, global_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert();
......
......@@ -40,10 +40,9 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
class InterpreterCore {
public:
InterpreterCore(const platform::Place& place, const ProgramDesc& main_prog,
InterpreterCore(const platform::Place& place, BlockDesc* block,
VariableScope* global_scope,
const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names);
const std::vector<std::string>& feed_names);
paddle::framework::FetchList Run(
const std::vector<framework::LoDTensor>& feed_tensors);
......@@ -72,15 +71,14 @@ class InterpreterCore {
void RunInstructionAsync(size_t instr_id);
void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops);
void AddFetch(const std::vector<std::string>& fetch_names);
void BuildSkipShareLoDInfo();
bool is_build_;
const platform::Place& place_;
ProgramDesc main_program_;
VariableScope* global_scope_;
BlockDesc* block_; // not owned
VariableScope* global_scope_; // not owned
std::vector<paddle::framework::OpFuncNode> vec_func_list_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
......@@ -88,7 +86,6 @@ class InterpreterCore {
InstructionInfo instruction_info_;
std::vector<size_t> dependecy_count_;
std::vector<std::vector<size_t>> input_var2op_info_;
std::vector<VariableMetaInfo> ref_coun_info_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<std::string> feed_names_;
......@@ -97,7 +94,7 @@ class InterpreterCore {
StreamAnalyzer stream_analyzer_;
EventManager event_manager_;
EventsWaiter main_thread_blocker_;
interpretercore::AsyncWorkQueue async_work_queue_;
interpreter::AsyncWorkQueue async_work_queue_;
details::ExceptionHolder exception_holder_;
std::shared_ptr<EventsWaiter::EventNotifier> exception_notifier_{nullptr};
......
......@@ -18,7 +18,7 @@
namespace paddle {
namespace framework {
namespace interpretercore {
namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>;
AtomicVectorSizeT& AsyncWorkQueue::PrepareAtomicDeps(
......@@ -129,11 +129,9 @@ std::string get_memcpy_type(const platform::Place& src_place,
}
}
void build_variable_scope(const framework::ProgramDesc& pdesc,
void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
for (auto& var_desc : global_block.AllVars()) {
for (auto& var_desc : block.AllVars()) {
auto var_name = var_desc->Name();
if (var_name == framework::kEmptyVarName) {
continue;
......@@ -360,9 +358,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
std::vector<OpFuncNode> apply_data_transform(
const OpKernelType& expected_kernel_key, const platform::Place& place,
VariableValueMap& ins_map_temp, VariableScope* var_scope,
OpFuncNode& op_func_node) {
auto& op_base = op_func_node.operator_base_;
VariableValueMap* ins_map_temp, VariableScope* var_scope,
OpFuncNode* op_func_node) {
auto& op_base = op_func_node->operator_base_;
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base is null, please pass a valid "
"op_base in apply_data_transform."));
......@@ -372,7 +370,7 @@ std::vector<OpFuncNode> apply_data_transform(
no_data_transform_index; // record the no need transform variable index.
std::vector<OpFuncNode> copy_func_nodes; // return all the copy opfuncnode.
for (auto& var_name_item : ins_map_temp) {
for (auto& var_name_item : *ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i];
auto& var_name = inputs_names[var_name_item.first].at(i);
......@@ -394,8 +392,8 @@ std::vector<OpFuncNode> apply_data_transform(
std::tie(new_var_name, copy_op_func_node) =
apply_place_transform_for_var(
kernel_type_for_var, expected_kernel_key, place, var_name,
var_name_item.first, op_func_node, var, var_scope);
op_func_node.input_index[var_name_item.first][i] =
var_name_item.first, *op_func_node, var, var_scope);
op_func_node->input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name);
copy_func_nodes.push_back(copy_op_func_node);
var_name_item.second[i] = var_scope->Var(new_var_name);
......@@ -414,23 +412,22 @@ std::vector<OpFuncNode> apply_data_transform(
}
}
}
op_func_node.no_data_transform_index = std::move(no_data_transform_index);
op_func_node->no_data_transform_index = std::move(no_data_transform_index);
return copy_func_nodes;
}
void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc,
const framework::BlockDesc& block,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
// Step 1: create all ops for global block.
auto ops = create_all_ops(global_block);
auto unused_var_map = get_unused_vars(global_block, ops);
// Step 1: create all ops for current block.
auto ops = create_all_ops(block);
auto unused_var_map = get_unused_vars(block, ops);
size_t ops_index = 0;
for (auto& op : global_block.AllOps()) {
for (auto& op : block.AllOps()) {
VLOG(6) << "Build OpFuncNode from : " << op->Type();
auto op_base = ops[ops_index++];
......@@ -498,7 +495,7 @@ void build_op_func_list(const platform::Place& place,
// apply_data_transform.
op_func_node.operator_base_ = op_base;
copy_op_to_insert = apply_data_transform(
expected_kernel_key, place, ins_map_temp, var_scope, op_func_node);
expected_kernel_key, place, &ins_map_temp, var_scope, &op_func_node);
for (auto& item : copy_op_to_insert) {
vec_func_list->push_back(item);
}
......@@ -576,6 +573,25 @@ void build_op_func_list(const platform::Place& place,
}
}
void add_fetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block) {
auto* fetch_holder = block->Var(kFetchVarName);
fetch_holder->SetType(proto::VarType::FETCH_LIST);
fetch_holder->SetPersistable(true);
int i = 0;
for (auto& fetch_name : fetch_names) {
// append fetch op
auto* op = block->AppendOp();
op->SetType("fetch_v2");
op->SetInput("X", {fetch_name});
op->SetOutput("Out", {kFetchVarName});
op->SetAttr("col", {static_cast<int>(i)});
op->CheckAttrs();
i++;
}
}
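A hedged usage example of the new helper (the fetch names here are made up; the call pattern mirrors StandaloneExecutor::GetInterpreterCore later in this diff):

// Hypothetical fetch targets; only the call shape is taken from this diff.
std::vector<std::string> fetch_names = {"loss", "acc"};
interpreter::add_fetch(fetch_names, new_prog->MutableBlock(0));
// The block now ends with one fetch_v2 op per name:
//   {X: "loss", Out: kFetchVarName, col: 0}
//   {X: "acc",  Out: kFetchVarName, col: 1}
// InterpreterCore::Run reads the results back via
// global_scope_->Var(interpreter::kFetchVarName)->GetMutable<FetchList>().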
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
const std::vector<size_t>& second) {
std::vector<size_t> out(first.size() + second.size());
......@@ -590,6 +606,6 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
return out;
}
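The body of merge_vector is mostly elided by this diff; below is a plausible stand-alone sketch, assuming both inputs are sorted op-id lists and the result should be their sorted, de-duplicated union (which is how the dependency-tracking code above consumes it). It is not the committed implementation.

#include <algorithm>
#include <vector>

// Sketch only -- assumptions stated above.
std::vector<size_t> merge_vector_sketch(const std::vector<size_t>& first,
                                        const std::vector<size_t>& second) {
  std::vector<size_t> out(first.size() + second.size());
  std::merge(first.begin(), first.end(), second.begin(), second.end(),
             out.begin());                        // sorted union, with duplicates
  out.erase(std::unique(out.begin(), out.end()),
            out.end());                           // drop repeated ids
  return out;
}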
} // namespace interpretercore
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -48,9 +48,10 @@
namespace paddle {
namespace framework {
namespace interpretercore {
namespace interpreter {
using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
static constexpr char kFetchVarName[] = "fetch_vars";
class AsyncWorkQueue {
public:
......@@ -96,17 +97,20 @@ class AsyncWorkQueue {
std::string get_memcpy_type(const platform::Place& src_place,
const platform::Place& dst_place);
void build_variable_scope(const framework::ProgramDesc& pdesc,
void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope);
void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc,
const framework::BlockDesc& block,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope);
void add_fetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block);
std::vector<size_t> merge_vector(const std::vector<size_t>& first,
const std::vector<size_t>& second);
} // namespace interpretercore
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -776,7 +776,7 @@ class Instruction {
std::vector<std::pair<Variable*, Variable*>> vec_inplace_in_to_out_;
};
namespace interpretercore {
namespace interpreter {
static constexpr char kMemcpyH2D[] = "memcpy_h2d";
static constexpr char kMemcpyD2H[] = "memcpy_d2h";
......@@ -787,7 +787,7 @@ static bool IsMemcpyH2D(const Instruction& instr) {
static bool IsMemcpyD2H(const Instruction& instr) {
return instr.OpBase()->Type() == kMemcpyD2H;
}
} // namespace interpretercore
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -41,8 +41,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
// run startup program
std::vector<paddle::framework::OpFuncNode> vec_func_list;
paddle::framework::interpretercore::build_op_func_list(
place_, startup_prog, &vec_func_list, &global_scope_);
paddle::framework::interpreter::build_op_func_list(
place_, startup_prog.Block(0), &vec_func_list, &global_scope_);
}
paddle::framework::FetchList StandaloneExecutor::Run(
......@@ -96,8 +96,15 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
if (iter == interpretercores_.end()) {
VLOG(3) << "create interpreter_core for " << oss.str();
auto core = std::make_shared<InterpreterCore>(
place_, main_prog_, &global_scope_, feed_names, fetch_names);
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a
// new program.
auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog_);
auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block);
auto core = std::make_shared<InterpreterCore>(place_, block, &global_scope_,
feed_names);
programs_.emplace(oss.str(), new_prog);
interpretercores_.emplace(oss.str(), core);
return core;
} else {
......
......@@ -62,6 +62,7 @@ class StandaloneExecutor : public ExecutorBase {
Scope* outer_scope_;
VariableScope global_scope_;
std::unordered_map<std::string, std::shared_ptr<ProgramDesc>> programs_;
std::unordered_map<std::string, std::shared_ptr<InterpreterCore>>
interpretercores_;
};
......
......@@ -101,10 +101,10 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node) {
auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_;
if (op_type == interpretercore::kMemcpyH2D) {
if (op_type == interpreter::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctx_pool_.Get(place_);
} else if (op_type == interpretercore::kMemcpyD2H) {
} else if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctx_pool_.Get(place_);
}
......@@ -122,8 +122,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) {
return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
interpretercore::IsMemcpyD2H(cur_instr) ||
interpretercore::IsMemcpyH2D(next_instr));
interpreter::IsMemcpyD2H(cur_instr) ||
interpreter::IsMemcpyH2D(next_instr));
}
platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
......