Unverified commit 1f0512be, authored by Leo Chen, committed by GitHub

[new feature] add local scope for interpretercore (#37379)

Parent 964e20e0
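The patch threads one recurring idiom through the interpreter: persistable variables stay in the outer scope, while temporaries go into a child local scope gated by the new flag FLAGS_new_executor_use_local_scope. A minimal sketch of that scope-selection step, with simplified names that are not part of the patch:

// Sketch only: the selection idiom used by build_variable_scope,
// apply_place_transform_for_var, and build_op_func_list below.
Scope* SelectExecutionScope(VariableScope* var_scope, bool use_local_scope) {
  // Temporaries live in the local scope so they can be dropped between runs;
  // use_local_scope == false falls back to the old single-scope behavior.
  return use_local_scope ? var_scope->GetMutableLocalScope()
                         : var_scope->GetMutableScope();
}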
@@ -22,6 +22,9 @@
 PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
                             "Use inplace in new executor");
+PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope, true,
+                            "Use local_scope in new executor(especially used "
+                            "in UT), can turn off for better performance");
 DECLARE_bool(check_nan_inf);
 DECLARE_bool(benchmark);
@@ -48,6 +51,14 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
   exception_notifier_ = main_thread_blocker_.RegisterEvent(
       kExceptionCaught, [this]() { return exception_holder_.IsCaught(); });
 
+  create_local_scope_ = FLAGS_new_executor_use_local_scope;
+  if (FLAGS_new_executor_use_local_scope) {
+    auto local_scope = &global_scope->GetMutableScope()->NewScope();
+    local_scope->AddListener(global_scope->Listener());
+    local_scope_ = local_scope;
+  }
+  VLOG(4) << "create_local_scope_ is " << create_local_scope_;
+
   // prune
   // optmize graph pass
@@ -62,10 +73,15 @@ InterpreterCore::~InterpreterCore() {
   async_work_queue_.reset(nullptr);
 }
 
+void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
+  copy_program_ = prog;
+}
+
 paddle::framework::FetchList InterpreterCore::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<framework::LoDTensor>& feed_tensors) {
   bool is_build = is_build_;
+  global_scope_->SetLocalScope(local_scope_);
   Prepare(feed_names, feed_tensors, is_build);
 
   if (is_build) {
@@ -79,13 +95,27 @@ paddle::framework::FetchList InterpreterCore::Run(
 
 paddle::framework::FetchList InterpreterCore::Run() {
   if (!is_build_) {
-    paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
+    if (create_local_scope_ &&
+        global_scope_->GetMutableLocalScope() !=
+            global_scope_->GetMutableScope() &&
+        global_scope_->GetMutableLocalScope()) {
+      VLOG(4) << "Clear previous local scope before run";
+      VLOG(4) << global_scope_->GetMutableScope() << " "
+              << global_scope_->GetMutableLocalScope();
+      platform::DeviceContextPool::Instance().Get(place_)->Wait();
+      // TODO(zhiqiu): clear the tensor holder of all vars in previous local
+      // scope?
+    }
+    global_scope_->SetLocalScope(local_scope_);
+    paddle::framework::interpreter::build_variable_scope(block_, global_scope_,
+                                                         create_local_scope_);
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     paddle::framework::interpreter::build_op_func_list(
-        place_, block_, &op_func_nodes, global_scope_);
+        place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
     is_build_ = true;
     // convert vec func_list to graph
     Convert(&op_func_nodes);
   } else {
     ExecuteInstructionList(vec_instruction_);
   }
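The TODO(zhiqiu) above waits for the device but leaves the stale local scope's tensors allocated. A hedged sketch of what the cleanup could look like, assuming Scope::LocalVarNames() and Scope::EraseVars() from framework/scope.h keep their usual signatures (illustrative, not part of the patch):

// Hypothetical cleanup for the TODO above.
auto* prev_local_scope = global_scope_->GetMutableLocalScope();
platform::DeviceContextPool::Instance().Get(place_)->Wait();  // drain device
// Drop every variable (and thus every tensor holder) of the stale scope.
prev_local_scope->EraseVars(prev_local_scope->LocalVarNames());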
@@ -300,7 +330,10 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
 
 void InterpreterCore::RunInstruction(const Instruction& instr_node) {
   auto* op = instr_node.OpBase();
   auto place = instr_node.DeviceContext().GetPlace();
-  VLOG(4) << "Start run" << place << " " << op->DebugStringEx(global_scope_);
+  VLOG(4) << "Start run " << place << " " << op->DebugStringEx(global_scope_);
+  Scope* local_scope = create_local_scope_
+                           ? global_scope_->GetMutableLocalScope()
+                           : global_scope_->GetMutableScope();
 
   auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
   {
@@ -325,13 +358,14 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
   }
 
   {
     platform::RecordEvent compute_event("Compute");
-    if (op_with_kernel == nullptr)
-      instr_node.OpBase()->Run(*global_scope_->GetScope(), place_);
-    else
+    if (op_with_kernel == nullptr) {
+      instr_node.OpBase()->Run(*local_scope, place_);
+    } else {
       instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
+    }
   }
 
-  VLOG(4) << "End run" << place << " " << op->DebugStringEx(global_scope_);
+  VLOG(4) << "End run " << place << " " << op->DebugStringEx(global_scope_);
 
   /*For profiling/benchmark only*/
   if (FLAGS_benchmark) {
@@ -372,8 +406,8 @@ void InterpreterCore::ExecuteInstructionList(
     }
   }
 
-  auto event_id = main_thread_blocker_.WaitEvent();
-  VLOG(3) << "event_id " << event_id;
+  auto event_name = main_thread_blocker_.WaitEvent();
+  VLOG(3) << "event_name: " << event_name;
 
   if (UNLIKELY(exception_holder_.IsCaught())) {
     VLOG(4) << "Exception caught " << exception_holder_.Type();
@@ -526,8 +560,9 @@ void InterpreterCore::Prepare(
     VLOG(4) << "Feed inputs";
     for (size_t i = 0; i < feed_names.size(); ++i) {
       auto* feed_var = global_scope_->FindVar(feed_names[i]);
-      PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
-                                            "feed_var shall not be nullptr."));
+      PADDLE_ENFORCE_NOT_NULL(
+          feed_var, platform::errors::NotFound(
+                        "Variable %s should not be nullptr.", feed_names[i]));
 
       auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
       feed_tensor->ShareDataWith(feed_tensors[i]);
@@ -536,11 +571,12 @@ void InterpreterCore::Prepare(
   };
 
   if (!is_build_) {
-    paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
+    paddle::framework::interpreter::build_variable_scope(block_, global_scope_,
+                                                         create_local_scope_);
     FeedInput();
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     paddle::framework::interpreter::build_op_func_list(
-        place_, block_, &op_func_nodes, global_scope_);
+        place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
     is_build_ = true;
     // convert vec func_list to graph
     Convert(&op_func_nodes);
@@ -556,6 +592,7 @@ void InterpreterCore::Prepare(
 interpreter::CostInfo InterpreterCore::DryRun(
     const std::vector<std::string>& feed_names,
     const std::vector<framework::LoDTensor>& feed_tensors) {
+  global_scope_->SetLocalScope(local_scope_);
   Prepare(feed_names, feed_tensors, true);
   interpreter::CostInfo cost_info;
   {
......
@@ -55,6 +55,8 @@ class InterpreterCore {
       const std::vector<std::string>& feed_names,
       const std::vector<framework::LoDTensor>& feed_tensors);
 
+  void SetCopyProgram(std::shared_ptr<ProgramDesc> prog);
+
  private:
   void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
 
@@ -85,7 +87,13 @@ class InterpreterCore {
   bool is_build_;
 
   const platform::Place& place_;
   const BlockDesc& block_;  // not owned
+  // NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
+  // copy a new program and block, the copy_program_ here is used to
+  // hold the program, otherwise block_ maybe not valid after the
+  // new program is deleted.
+  std::shared_ptr<ProgramDesc> copy_program_{nullptr};
+
   VariableScope* global_scope_;  // not owned
 
   std::vector<Instruction> vec_instruction_;  // deconstruct before OpFuncNode
@@ -102,6 +110,8 @@ class InterpreterCore {
 
   std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
   std::vector<paddle::platform::DeviceEvent> gc_event_;
+  bool create_local_scope_{true};
+  Scope* local_scope_{nullptr};  // not owned
 };
 
 }  // namespace framework
 }  // namespace paddle
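The NOTE on copy_program_ is a lifetime guard: InterpreterCore holds block_ by reference, so a caller that clones the program to append fetch ops must park the clone somewhere that outlives the core. The intended caller-side pattern, reconstructed from GetInterpreterCore further down (sketch, error handling omitted):

auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog);
auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block);  // mutates the cloned BlockDesc
auto core = std::make_shared<InterpreterCore>(place, *block, &global_scope);
core->SetCopyProgram(new_prog);  // block_ stays valid as long as core lives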
@@ -132,23 +132,35 @@ std::string get_memcpy_type(const platform::Place& src_place,
 }
 
 void build_variable_scope(const framework::BlockDesc& block,
-                          VariableScope* var_scope) {
+                          VariableScope* var_scope, bool use_local_scope) {
+  VLOG(3) << "Creating Variables";
+  auto inner_scope = var_scope->GetMutableScope();
+  // NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
+  // created in var_scope.scope_, and the others in the local scope.
+  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
+                                       : var_scope->GetMutableScope();
   for (auto& var_desc : block.AllVars()) {
     auto var_name = var_desc->Name();
     if (var_name == framework::kEmptyVarName) {
       continue;
     }
-
-    if (nullptr == var_scope->FindVar(var_name)) {
-      var_scope->AddVar(var_desc->Name(), var_desc);
+    if (var_desc->Persistable()) {
+      auto* ptr = inner_scope->Var(var_name);
+      VLOG(3) << "Initialize Variable " << var_name;
+      InitializeVariable(ptr, var_desc->GetType());
+      VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
+              << ptr << " type is " << static_cast<int>(var_desc->GetType());
     } else {
-      auto* var_desc_tmp = var_scope->VarDesc(var_name);
-      if (nullptr == var_desc_tmp) {
-        VLOG(3) << "update var:" << var_name << " desc from nullptr into "
-                << var_desc;
-        var_scope->SetVarDesc(var_name, var_desc);
-      }
+      auto* ptr = local_scope->Var(var_name);
+      InitializeVariable(ptr, var_desc->GetType());
+      VLOG(3) << "Create Variable " << var_name
+              << " locally, which pointer is " << ptr << " Variable Type "
+              << static_cast<int>(var_desc->GetType());
     }
+    var_scope->SetVarDesc(var_name, var_desc);
   }
 }
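The placement rule the NOTE describes reduces to one branch per variable; a condensed sketch of the same loop, with logging stripped:

for (auto& var_desc : block.AllVars()) {
  if (var_desc->Name() == framework::kEmptyVarName) continue;
  // Persistables go to the outer scope, temporaries to the local scope
  // (which is the outer scope itself when use_local_scope is false).
  Scope* target = var_desc->Persistable() ? inner_scope : local_scope;
  InitializeVariable(target->Var(var_desc->Name()), var_desc->GetType());
}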
@@ -237,14 +249,14 @@ void apply_device_guard(const OperatorBase* op_base,
 void deal_operator_base(const platform::Place& place,
                         const VariableScope* var_scope,
                         std::shared_ptr<OperatorBase> op_base,
-                        OpFuncNode* op_func_node) {
+                        OpFuncNode* op_func_node, Scope* local_scope) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
   // input, output is prepared. set the other attributes.
   op_func_node->operator_base_ = op_base;
   op_func_node->type_ = OpFuncType::kQueueSync;  // alway Sync
   op_func_node->kernel_func_ = nullptr;
-  op_base->Run(*var_scope->GetScope(), place);  // Run without data transformer.
+  op_base->Run(*local_scope, place);  // Run without data transformer.
 
   std::unordered_set<int> no_data_transform_index;
   for (auto& it : op_func_node->input_index) {
@@ -288,12 +300,21 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
     const OpKernelType& kernel_type_for_var,
     const OpKernelType& expected_kernel_key, const platform::Place& place,
     const std::string& var_name, const std::string& outer_name,
-    const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope) {
+    const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope,
+    bool use_local_scope = true) {
+  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
+                                       : var_scope->GetMutableScope();
   auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
 
   std::string new_var_name =
       var_name + "_copy_" + std::to_string(var_scope->VarSize() + 1);
-  var_scope->AddVar(new_var_name, nullptr);
+
+  auto* ptr = local_scope->Var(new_var_name);
+  InitializeVariable(ptr, static_cast<proto::VarType::Type>(var->Type()));
+  VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
+          << ptr << " Variable Type " << var->Type();
+  var_scope->SetVarDesc(var_name, nullptr);
 
   VariableNameMap copy_in_map;
   copy_in_map["X"] = {var_name};
@@ -368,7 +389,8 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
                           const platform::Place& place,
                           VariableValueMap* ins_map_temp,
                           VariableScope* var_scope, OpFuncNode* op_func_node,
-                          std::vector<OpFuncNode>* copy_func_nodes) {
+                          std::vector<OpFuncNode>* copy_func_nodes,
+                          bool use_local_scope = true) {
   auto op_base = op_func_node->operator_base_.get();
   PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
                                        "op_base is null, please pass a valid "
@@ -402,9 +424,10 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
           std::string new_var_name;
           OpFuncNode copy_op_func_node;
           std::tie(new_var_name, copy_op_func_node) =
-              apply_place_transform_for_var(
-                  kernel_type_for_var, expected_kernel_key, place, var_name,
-                  var_name_item.first, *op_func_node, var, var_scope);
+              apply_place_transform_for_var(kernel_type_for_var,
+                                            expected_kernel_key, place, var_name,
+                                            var_name_item.first, *op_func_node,
+                                            var, var_scope, use_local_scope);
           op_func_node->input_index[var_name_item.first][i] =
               var_scope->VarId(new_var_name);
           copy_func_nodes->emplace_back(copy_op_func_node);
@@ -438,7 +461,9 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
 void build_op_func_list(const platform::Place& place,
                         const framework::BlockDesc& block,
                         std::vector<OpFuncNode>* vec_func_list,
-                        VariableScope* var_scope) {
+                        VariableScope* var_scope, bool use_local_scope) {
+  Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
+                                       : var_scope->GetMutableScope();
   auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
   std::vector<std::shared_ptr<OperatorBase>>
       ops;  // its elements will be moved to vec_func_list
@@ -478,7 +503,7 @@ void build_op_func_list(const platform::Place& place,
 
     if (dynamic_cast<const framework::OperatorWithKernel*>(op) == nullptr) {
       // op is not a operatorwithkernel, so direcly run OperatorBase::Run()
-      deal_operator_base(place, var_scope, ops[i], &op_func_node);
+      deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
     } else {
       // construct RuntimeContext and analysis KernelType
       RuntimeContext runtime_context({}, {});
@@ -520,7 +545,7 @@ void build_op_func_list(const platform::Place& place,
       // apply_data_transform.
       op_func_node.operator_base_ = ops[i];
       apply_data_transform(expected_kernel_key, place, &ins_map_temp, var_scope,
-                           &op_func_node, &copy_op_to_insert);
+                           &op_func_node, &copy_op_to_insert, use_local_scope);
       for (auto& item : copy_op_to_insert) {
         vec_func_list->push_back(item);
       }
@@ -631,16 +656,16 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
 }
 
 void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
-                          std::map<int, std::list<int>>& var2min_rw_op,
+                          std::map<int, std::list<int>>* var2min_rw_op,
                           int cur_op, int rw_var) {
   // rw_var is inputs or outputs of cur_op
   // this function update the var2min_rw_op set .
-  if (var2min_rw_op.find(rw_var) == var2min_rw_op.end())
-    var2min_rw_op[rw_var] = std::list<int>();
+  if (var2min_rw_op->find(rw_var) == var2min_rw_op->end())
+    (*var2min_rw_op)[rw_var] = std::list<int>();
   for (auto dep_op : op2dependences.at(cur_op)) {
-    var2min_rw_op[rw_var].remove(dep_op);
+    (*var2min_rw_op)[rw_var].remove(dep_op);
   }
-  var2min_rw_op[rw_var].push_back(cur_op);
+  (*var2min_rw_op)[rw_var].push_back(cur_op);
 }
 
 std::map<int, std::list<int>> get_downstream_map(
@@ -702,7 +727,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
       for (auto& item :
            vec_instruction[op_idx].Inputs()) {  // for all inputs(read only)
         for (auto var : item.second) {
-          update_var_min_rw_op(op2dependences, var2min_rw_op, op_idx, var);
+          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
           remove_duplicate.insert(var);
         }
       }
@@ -713,7 +738,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
         var2recent_write_op[var] = op_idx;
         if (remove_duplicate.count(var) ==
             0) {  // var in input list and in output list, so remove it.
-          update_var_min_rw_op(op2dependences, var2min_rw_op, op_idx, var);
+          update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
        }
      }
    }
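Switching var2min_rw_op from a reference to a pointer matches the common style rule (e.g., Google's C++ guide) that output parameters be pointers, so the mutation is visible at each call site:

std::map<int, std::list<int>> var2min_rw_op;
// The explicit & now flags var2min_rw_op as an in/out parameter.
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);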
......
@@ -51,7 +51,7 @@ namespace framework {
 namespace interpreter {
 using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
-static constexpr char kFetchVarName[] = "fetch_vars";
+static constexpr char kFetchVarName[] = "fetch";
 
 class AsyncWorkQueue {
  public:
@@ -98,12 +98,13 @@ std::string get_memcpy_type(const platform::Place& src_place,
                             const platform::Place& dst_place);
 
 void build_variable_scope(const framework::BlockDesc& block,
-                          VariableScope* var_scope);
+                          VariableScope* var_scope,
+                          bool use_local_scope = true);
 
 void build_op_func_list(const platform::Place& place,
                         const framework::BlockDesc& block,
                         std::vector<OpFuncNode>* vec_func_list,
-                        VariableScope* var_scope);
+                        VariableScope* var_scope, bool use_local_scope = true);
 
 std::map<int, std::list<int>> build_op_downstream_map(
     const std::vector<Instruction>& vec_instruction);
......
@@ -497,7 +497,14 @@ VariableScope::~VariableScope() {
   }
 }
 
-const Scope* VariableScope::GetScope() const { return scope_; }
+Scope* VariableScope::GetMutableScope() const { return scope_; }
+
+Scope* VariableScope::GetMutableLocalScope() const { return local_scope_; }
+
+void VariableScope::SetLocalScope(Scope* local_scope) {
+  VLOG(4) << "Set local scope: " << local_scope;
+  local_scope_ = local_scope;
+}
 
 Variable* VariableScope::FindVar(const std::string& name) const {
   auto it = name2id_.find(name);
@@ -554,8 +561,9 @@ Variable* VariableScope::Var(const std::string& name) const {
 size_t VariableScope::VarSize() const { return var_list_.size(); }
 
 void VariableScope::AddVar(const std::string& name,
-                           framework::VarDesc* var_desc) {  // NOLINT
-  auto v = scope_->Var(name);
+                           framework::VarDesc* var_desc,
+                           bool local_scope) {  // NOLINT
+  auto v = local_scope ? local_scope_->Var(name) : scope_->Var(name);
   if (nullptr == var_desc) {
     v->GetMutable<LoDTensor>();
   } else {
@@ -606,9 +614,9 @@ VariableScopeListener::VariableScopeListener(VariableScope* var_scope) {
   var_scope_ = var_scope;
 }
 
-void VariableScopeListener::onCreateVariable(const std::string& name) {
-  auto v = var_scope_->scope_->GetVar(name);  // must exsit in outer_scope_
+void VariableScopeListener::onCreateVariable(const std::string& name,
+                                             Variable* v) {
   if (!var_scope_->HasVar(name)) {  // may exist in variable scope.
     VLOG(4) << "Calling VariableScope::onCreateVariable with var_name: "
             << name;
     var_scope_->name2id_[name] = var_scope_->VarSize();
......
@@ -155,7 +155,7 @@ class VariableScope;
 class VariableScopeListener : public ScopeListener {
  public:
   explicit VariableScopeListener(VariableScope* var_scope_);
-  void onCreateVariable(const std::string& name) override;
+  void onCreateVariable(const std::string& name, Variable* v) override;
   void onDeleteVariable(const std::string& name) override;
   void onRenameVariable(const std::string& old_name,
                         const std::string& new_name) override;
@@ -177,7 +177,11 @@ class VariableScope : public ScopeBase {
  public:
   explicit VariableScope(Scope* scope);
 
-  const Scope* GetScope() const;
+  Scope* GetMutableScope() const;
+
+  Scope* GetMutableLocalScope() const;
+
+  void SetLocalScope(Scope* local_scope);
 
   Variable* FindVar(const std::string& name) const;
@@ -199,7 +203,8 @@ class VariableScope : public ScopeBase {
 
   size_t VarSize() const;
 
-  void AddVar(const std::string& name, VarDesc* var_desc);
+  void AddVar(const std::string& name, VarDesc* var_desc,
+              bool local_scope = false);
 
   void AddVar(const std::string& name, const Variable& var);
@@ -219,15 +224,21 @@ class VariableScope : public ScopeBase {
     return vec_meta_info_;
   }
 
+  const std::shared_ptr<VariableScopeListener>& Listener() const {
+    return listener_;
+  }
+
   friend class VariableScopeListener;
 
  private:
   std::vector<Variable*> var_list_;
   std::map<std::string, int> name2id_;
   std::vector<VariableMetaInfo> vec_meta_info_;
-  Scope* scope_ = nullptr;
+  Scope* scope_{nullptr};
+  // TODO(zhiqiu): find a better way to support local scope.
+  Scope* local_scope_{nullptr};
   // mutable RWLock vars_lock_;
-  std::shared_ptr<VariableScopeListener> listener_;
+  std::shared_ptr<VariableScopeListener> listener_{nullptr};
 };
 
 class NextInstruction {
......
@@ -24,30 +24,36 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
       startup_prog_(startup_prog),
       main_prog_(main_prog),
       global_scope_(VariableScope(scope)) {
-  // init scope
-  BuildVariableScope(startup_prog, &global_scope_);
-
-  if (scope != nullptr) {
-    auto name_list = scope->LocalVarNames();
-    for (auto name : name_list) {
-      auto v = scope->Var(name);
-      if (!global_scope_.HasVar(name)) {
-        global_scope_.AddVar(name, *v);
+  // NOTE(zhiqiu): for startup_program, initialize scope and run once;
+  // if startup_program is empty, the scope is initialized during first run
+  if (startup_prog.Block(0).AllOps().size() > 0) {
+    VLOG(4) << "Run startup program";
+    // init scope
+    BuildVariableScope(startup_prog, &global_scope_);
+
+    if (scope != nullptr) {
+      auto name_list = scope->LocalVarNames();
+      for (auto name : name_list) {
+        auto v = scope->Var(name);
+        if (!global_scope_.HasVar(name)) {
+          global_scope_.AddVar(name, *v);
+        }
       }
     }
-  }
-
-  // run startup program
-  std::vector<paddle::framework::OpFuncNode> vec_func_list;
-  paddle::framework::interpreter::build_op_func_list(
-      place_, startup_prog.Block(0), &vec_func_list, &global_scope_);
+
+    std::vector<paddle::framework::OpFuncNode> vec_func_list;
+    // No need to use_local_scope for startup_program, its variables are
+    // persistable
+    paddle::framework::interpreter::build_op_func_list(
+        place_, startup_prog.Block(0), &vec_func_list, &global_scope_, false);
+  }
 }
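Because everything the startup program creates is persistable, build_op_func_list is invoked here with use_local_scope=false; only the main program benefits from a droppable local scope. For contrast, a sketch of the two call shapes (vec_func_list declared as above):

// Startup program: all variables persistable, keep them in the outer scope.
interpreter::build_op_func_list(place_, startup_prog.Block(0), &vec_func_list,
                                &global_scope_, /*use_local_scope=*/false);
// Main program: temporaries may live in the local scope (the default).
interpreter::build_op_func_list(place_, main_prog_.Block(0), &vec_func_list,
                                &global_scope_, /*use_local_scope=*/true);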
 paddle::framework::FetchList StandaloneExecutor::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<framework::LoDTensor>& feed_tensors,
     const std::vector<std::string>& fetch_names) {
-  auto core = GetInterpreterCore(feed_names, fetch_names);
+  auto core = GetInterpreterCore(feed_names, fetch_names, true);
 
   return core->Run(feed_names, feed_tensors);
 }
@@ -55,15 +61,15 @@ paddle::framework::FetchList StandaloneExecutor::Run(
 paddle::framework::FetchList StandaloneExecutor::Run(
     const std::vector<std::string>& feed_names,
     const std::vector<std::string>& fetch_names) {
-  auto core = GetInterpreterCore(feed_names, fetch_names);
+  auto core = GetInterpreterCore(feed_names, fetch_names, false);
+  VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
   return core->Run();
 }
 
 framework::interpreter::CostInfo StandaloneExecutor::DryRun(
     const std::vector<std::string>& feed_names,
     const std::vector<framework::LoDTensor>& feed_tensors) {
-  auto core = GetInterpreterCore(feed_names, {});
+  auto core = GetInterpreterCore(feed_names, {}, true);
 
   return core->DryRun(feed_names, feed_tensors);
 }
@@ -85,7 +91,7 @@ void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
 
 std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
     const std::vector<std::string>& feed_names,
-    const std::vector<std::string>& fetch_names) {
+    const std::vector<std::string>& fetch_names, bool add_fetch_op) {
   std::ostringstream oss;
   oss << "feed:";
   for (auto& feedname : feed_names) {
@@ -100,15 +106,22 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
 
   if (iter == interpretercores_.end()) {
     VLOG(3) << "create interpreter_core for " << oss.str();
-    // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a
-    // new program.
-    auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog_);
-    auto* block = new_prog->MutableBlock(0);
-    interpreter::add_fetch(fetch_names, block);
-
-    auto core =
-        std::make_shared<InterpreterCore>(place_, *block, &global_scope_);
-    programs_.emplace(oss.str(), new_prog);
+    VLOG(3) << "add fetch op: " << add_fetch_op;
+    std::shared_ptr<InterpreterCore> core = nullptr;
+    if (add_fetch_op) {
+      // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy
+      // a new program.
+      auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog_);
+      auto* block = new_prog->MutableBlock(0);
+      interpreter::add_fetch(fetch_names, block);
+      core = std::make_shared<InterpreterCore>(place_, *block, &global_scope_);
+      core->SetCopyProgram(new_prog);
+    } else {
+      core = std::make_shared<InterpreterCore>(place_, main_prog_.Block(0),
+                                               &global_scope_);
+    }
     interpretercores_.emplace(oss.str(), core);
     return core;
   } else {
......
@@ -61,14 +61,13 @@ class StandaloneExecutor : public ExecutorBase {
   std::shared_ptr<InterpreterCore> GetInterpreterCore(
       const std::vector<std::string>& feed_names,
-      const std::vector<std::string>& fetch_names);
+      const std::vector<std::string>& fetch_names, bool add_fetch_op);
 
   const platform::Place& place_;
   const ProgramDesc& startup_prog_;
   const ProgramDesc& main_prog_;
   VariableScope global_scope_;
 
-  std::unordered_map<std::string, std::shared_ptr<ProgramDesc>> programs_;
   std::unordered_map<std::string, std::shared_ptr<InterpreterCore>>
       interpretercores_;
 };
......
@@ -67,7 +67,7 @@ Variable* Scope::Var(const std::string& name) {
     ret = VarInternal(name);
   }
   for (auto l : listeners_) {
-    l->onCreateVariable(name);
+    l->onCreateVariable(name, ret);
   }
   return ret;
 }
@@ -85,7 +85,7 @@ Variable* Scope::Var(std::string* name) {
     ret = VarInternal(new_name);
   }
   for (auto l : listeners_) {
-    l->onCreateVariable(new_name);
+    l->onCreateVariable(new_name, ret);
   }
   return ret;
 }
......
@@ -58,7 +58,7 @@ class ScopeListener {
   // in original Scope.
  public:
   virtual ~ScopeListener() {}
-  virtual void onCreateVariable(const std::string& name) {}
+  virtual void onCreateVariable(const std::string& name, Variable* v) {}
   virtual void onDeleteVariable(const std::string& name) {}
   virtual void onRenameVariable(const std::string& old_name,
                                 const std::string& new_name) {}
......
@@ -112,13 +112,6 @@ class FetchV2Op : public framework::OperatorWithKernel {
   }
 };
 
-class FetchV2InferVarType : public framework::VarTypeInference {
- public:
-  void operator()(framework::InferVarTypeContext *ctx) const override {
-    ctx->SyncTypeAndDataType("X", "Out");
-  }
-};
-
 class FetchV2Kernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -211,7 +204,6 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OPERATOR(
     fetch_v2, ops::FetchV2Op, ops::FetchV2OpProtoMaker,
-    ops::FetchV2InferVarType,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
@@ -288,7 +288,10 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
     return feed_count > 0
 
 
-def has_fetch_operators(block, fetch_targets, fetch_holder_name):
+def has_fetch_operators(block,
+                        fetch_targets,
+                        fetch_holder_name,
+                        fetch_op='fetch'):
     """ Check whether the block already has fetch operators.
 
     Return false if the block does not have any fetch operators.
@@ -303,6 +306,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
         fetch_holder_name: the name of the variable that holds the data of
             all fetch targets. The type of this fetch_holder variable is
             FETCH_LIST, which is essentially vector<LoDTensor>.
+        fetch_op: the operator name of fetch
 
     Return:
         A boolean value that indicates whether a block has fetch operators
@@ -311,7 +315,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
 
     fetch_count = 0
     for op in block.ops:
-        if op.desc.type() == 'fetch':
+        if op.desc.type() == fetch_op:
             fetch_count += 1
             assert op.desc.output('Out')[0] == fetch_holder_name
             fetch_target_name = op.desc.input('X')[0]
@@ -740,7 +744,7 @@ class Executor(object):
                             fetch_list,
                             feed_var_name,
                             fetch_var_name,
-                            skip_fetch=False):
+                            use_fetch_v2=False):
         tmp_program = program.clone()
 
         global_block = tmp_program.global_block()
@@ -775,17 +779,21 @@ class Executor(object):
                     warnings.warn(
                         "The variable %s is not found in program. It is not declared or is pruned."
                         % name)
-        if skip_fetch:
-            return tmp_program
+
+        if use_fetch_v2:
+            fetch_op = 'fetch_v2'
+        else:
+            fetch_op = 'fetch'
 
         # append fetch_operators
-        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
+        if not has_fetch_operators(global_block, fetch_list, fetch_var_name,
+                                   fetch_op):
             for i, var in enumerate(fetch_list):
                 assert isinstance(var, Variable) or isinstance(
                     var, six.string_types), (
                         "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
                 global_block.append_op(
-                    type='fetch',
+                    type=fetch_op,
                     inputs={'X': [var]},
                     outputs={'Out': [fetch_var]},
                     attrs={'col': i})
@@ -1345,7 +1353,13 @@ class Executor(object):
                 fetch_list=fetch_list,
                 feed_var_name=feed_var_name,
                 fetch_var_name=fetch_var_name,
-                skip_fetch=True)
+                use_fetch_v2=True)
+
+            # NOTE(zhiqiu): Construct standalone_executor first, so that
+            # the scope is bound to the variable_scope of standalone_executor
+            new_exe = self._executor_cache._get_exe_from_cache(program,
+                                                               scope)
 
             self._feed_data(program, feed, feed_var_name, scope)
             if hasattr(program, 'lr_sheduler'):
                 from paddle.optimizer.lr import LRScheduler
@@ -1360,9 +1374,7 @@ class Executor(object):
                     lr_sheduler._var_name)
                 tensor.set(data, self.place)
 
-            return self._executor_cache.run(program, scope,
-                                            list(feed.keys()), fetch_list,
-                                            return_numpy)
+            return new_exe.run(list(feed.keys()), fetch_list, return_numpy)
 
         # use_prune can be overrided by putting optimize_ops in fetch_list
         _origin_fetch_list = fetch_list
......
@@ -309,7 +309,7 @@ class TestException(unittest.TestCase):
         feed[1]['data'][0] = np.nan
         self.assertRaises(RuntimeError, self.run_new_executor, feed)
 
-    def test_scope(self):
+    def test_scope_find_temp_var(self):
         feed = [{
             'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
             'data': np.array([1, 2, 3]).astype(np.float32),
@@ -318,7 +318,7 @@ class TestException(unittest.TestCase):
             'data': np.array([2, 2, 2]).astype(np.float32),
         }]
         self.run_new_executor(feed)
-        self.assertIsNotNone(paddle.static.global_scope().find_var(
+        self.assertIsNone(paddle.static.global_scope().find_var(
             self.fetch_vars.name))
......