Unverified commit 1f0512be authored by Leo Chen, committed by GitHub

[new feature] add local scope for interpretercore (#37379)
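In brief: InterpreterCore now keeps persistable variables (parameters) in the outer scope and creates temporaries in a child scope, gated by FLAGS_new_executor_use_local_scope. Below is a minimal sketch of the scope layout this relies on, assuming Paddle's Scope API as used in this codebase (NewScope returns an owned child, and FindVar falls back to the parent scope):

#include "paddle/fluid/framework/scope.h"
using paddle::framework::Scope;

Scope global_scope;                             // holds persistable vars (parameters)
Scope* local_scope = &global_scope.NewScope();  // holds per-run temporaries
local_scope->Var("tmp_0");                      // created in the child only
// FindVar searches the child first, then falls back to the parent, so an
// op run against local_scope resolves both parameters and temporaries.
local_scope->FindVar("tmp_0");                  // found in the child
global_scope.DropKids();                        // frees temporaries, keeps parameters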

Parent 964e20e0
......@@ -22,6 +22,9 @@
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
"Use inplace in new executor");
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope, true,
"Use local_scope in new executor(especially used "
"in UT), can turn off for better performance");
DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark);
......@@ -48,6 +51,14 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
exception_notifier_ = main_thread_blocker_.RegisterEvent(
kExceptionCaught, [this]() { return exception_holder_.IsCaught(); });
create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (FLAGS_new_executor_use_local_scope) {
auto local_scope = &global_scope->GetMutableScope()->NewScope();
local_scope->AddListener(global_scope->Listener());
local_scope_ = local_scope;
}
VLOG(4) << "create_local_scope_ is " << create_local_scope_;
// prune
// optimize graph pass
......@@ -62,10 +73,15 @@ InterpreterCore::~InterpreterCore() {
async_work_queue_.reset(nullptr);
}
void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
copy_program_ = prog;
}
paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
bool is_build = is_build_;
global_scope_->SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, is_build);
if (is_build) {
......@@ -79,13 +95,27 @@ paddle::framework::FetchList InterpreterCore::Run(
paddle::framework::FetchList InterpreterCore::Run() {
if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
if (create_local_scope_ &&
global_scope_->GetMutableLocalScope() !=
global_scope_->GetMutableScope() &&
global_scope_->GetMutableLocalScope()) {
VLOG(4) << "Clear previous local scope before run";
VLOG(4) << global_scope_->GetMutableScope() << " "
<< global_scope_->GetMutableLocalScope();
platform::DeviceContextPool::Instance().Get(place_)->Wait();
// TODO(zhiqiu): clear the tensor holder of all vars in previous local
// scope?
}
global_scope_->SetLocalScope(local_scope_);
paddle::framework::interpreter::build_variable_scope(block_, global_scope_,
create_local_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
place_, block_, &op_func_nodes, global_scope_);
place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert(&op_func_nodes);
} else {
ExecuteInstructionList(vec_instruction_);
}
......@@ -300,7 +330,10 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace();
VLOG(4) << "Start run" << place << " " << op->DebugStringEx(global_scope_);
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(global_scope_);
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{
......@@ -325,13 +358,14 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
}
{
platform::RecordEvent compute_event("Compute");
if (op_with_kernel == nullptr)
instr_node.OpBase()->Run(*global_scope_->GetScope(), place_);
else
if (op_with_kernel == nullptr) {
instr_node.OpBase()->Run(*local_scope, place_);
} else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
}
}
VLOG(4) << "End run" << place << " " << op->DebugStringEx(global_scope_);
VLOG(4) << "End run " << place << " " << op->DebugStringEx(global_scope_);
/*For profiling/benchmark only*/
if (FLAGS_benchmark) {
......@@ -372,8 +406,8 @@ void InterpreterCore::ExecuteInstructionList(
}
}
auto event_id = main_thread_blocker_.WaitEvent();
VLOG(3) << "event_id " << event_id;
auto event_name = main_thread_blocker_.WaitEvent();
VLOG(3) << "event_name: " << event_name;
if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(4) << "Exception caught " << exception_holder_.Type();
......@@ -526,8 +560,9 @@ void InterpreterCore::Prepare(
VLOG(4) << "Feed inputs";
for (size_t i = 0; i < feed_names.size(); ++i) {
auto* feed_var = global_scope_->FindVar(feed_names[i]);
PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
"feed_var shall not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
feed_var, platform::errors::NotFound(
"Variable %s should not be nullptr.", feed_names[i]));
auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]);
......@@ -536,11 +571,12 @@ void InterpreterCore::Prepare(
};
if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
paddle::framework::interpreter::build_variable_scope(block_, global_scope_,
create_local_scope_);
FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
place_, block_, &op_func_nodes, global_scope_);
place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert(&op_func_nodes);
......@@ -556,6 +592,7 @@ void InterpreterCore::Prepare(
interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
global_scope_->SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info;
{
......
......@@ -55,6 +55,8 @@ class InterpreterCore {
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors);
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog);
private:
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
......@@ -85,7 +87,13 @@ class InterpreterCore {
bool is_build_;
const platform::Place& place_;
const BlockDesc& block_; // not owned
const BlockDesc& block_; // not owned
// NOTE(zhiqiu): when adding fetch ops in GetInterpreterCore, we
// copy a new program and block. copy_program_ is used to hold the
// program; otherwise block_ may become invalid once the new
// program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr};
VariableScope* global_scope_; // not owned
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
......@@ -102,6 +110,8 @@ class InterpreterCore {
std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
bool create_local_scope_{true};
Scope* local_scope_{nullptr}; // not owned
};
} // namespace framework
} // namespace paddle
......@@ -132,23 +132,35 @@ std::string get_memcpy_type(const platform::Place& src_place,
}
void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope) {
VariableScope* var_scope, bool use_local_scope) {
VLOG(3) << "Creating Variables";
auto inner_scope = var_scope->GetMutableScope();
// NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
// created in var_scope.scope_, and all other variables are created in the local scope.
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
for (auto& var_desc : block.AllVars()) {
auto var_name = var_desc->Name();
if (var_name == framework::kEmptyVarName) {
continue;
}
if (var_desc->Persistable()) {
auto* ptr = inner_scope->Var(var_name);
if (nullptr == var_scope->FindVar(var_name)) {
var_scope->AddVar(var_desc->Name(), var_desc);
VLOG(3) << "Initialize Variable " << var_name;
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType());
} else {
auto* var_desc_tmp = var_scope->VarDesc(var_name);
if (nullptr == var_desc_tmp) {
VLOG(3) << "update var:" << var_name << " desc from nullptr into "
<< var_desc;
var_scope->SetVarDesc(var_name, var_desc);
}
auto* ptr = local_scope->Var(var_name);
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
<< ptr << "Variable Type "
<< static_cast<int>(var_desc->GetType());
}
var_scope->SetVarDesc(var_name, var_desc);
}
}
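The branch above encodes the whole placement rule: placement depends only on VarDesc::Persistable(). A hedged restatement as a standalone helper (the function name is hypothetical, not part of this commit):

// Hypothetical helper summarizing build_variable_scope's placement rule:
// persistable variables live in the outer scope, all others in the local one.
Scope* PickScope(const paddle::framework::VarDesc& desc, Scope* outer,
                 Scope* local) {
  return desc.Persistable() ? outer : local;
}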
......@@ -237,14 +249,14 @@ void apply_device_guard(const OperatorBase* op_base,
void deal_operator_base(const platform::Place& place,
const VariableScope* var_scope,
std::shared_ptr<OperatorBase> op_base,
OpFuncNode* op_func_node) {
OpFuncNode* op_func_node, Scope* local_scope) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// input, output is prepared. set the other attributes.
op_func_node->operator_base_ = op_base;
op_func_node->type_ = OpFuncType::kQueueSync; // always Sync
op_func_node->kernel_func_ = nullptr;
op_base->Run(*var_scope->GetScope(), place); // Run without data transformer.
op_base->Run(*local_scope, place); // Run without data transformer.
std::unordered_set<int> no_data_transform_index;
for (auto& it : op_func_node->input_index) {
......@@ -288,12 +300,21 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key, const platform::Place& place,
const std::string& var_name, const std::string& outer_name,
const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope) {
const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope,
bool use_local_scope = true) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
std::string new_var_name =
var_name + "_copy_" + std::to_string(var_scope->VarSize() + 1);
var_scope->AddVar(new_var_name, nullptr);
auto* ptr = local_scope->Var(new_var_name);
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var->Type()));
VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
<< ptr << "Variable Type " << var->Type();
var_scope->SetVarDesc(var_name, nullptr);
VariableNameMap copy_in_map;
copy_in_map["X"] = {var_name};
......@@ -368,7 +389,8 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
const platform::Place& place,
VariableValueMap* ins_map_temp,
VariableScope* var_scope, OpFuncNode* op_func_node,
std::vector<OpFuncNode>* copy_func_nodes) {
std::vector<OpFuncNode>* copy_func_nodes,
bool use_local_scope = true) {
auto op_base = op_func_node->operator_base_.get();
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base is null, please pass a valid "
......@@ -402,9 +424,10 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
std::string new_var_name;
OpFuncNode copy_op_func_node;
std::tie(new_var_name, copy_op_func_node) =
apply_place_transform_for_var(
kernel_type_for_var, expected_kernel_key, place, var_name,
var_name_item.first, *op_func_node, var, var_scope);
apply_place_transform_for_var(kernel_type_for_var,
expected_kernel_key, place, var_name,
var_name_item.first, *op_func_node,
var, var_scope, use_local_scope);
op_func_node->input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name);
copy_func_nodes->emplace_back(copy_op_func_node);
......@@ -438,7 +461,9 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
void build_op_func_list(const platform::Place& place,
const framework::BlockDesc& block,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
VariableScope* var_scope, bool use_local_scope) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
std::vector<std::shared_ptr<OperatorBase>>
ops; // its elements will be moved to vec_func_list
......@@ -478,7 +503,7 @@ void build_op_func_list(const platform::Place& place,
if (dynamic_cast<const framework::OperatorWithKernel*>(op) == nullptr) {
// op is not an OperatorWithKernel, so directly run OperatorBase::Run()
deal_operator_base(place, var_scope, ops[i], &op_func_node);
deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
} else {
// construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {});
......@@ -520,7 +545,7 @@ void build_op_func_list(const platform::Place& place,
// apply_data_transform.
op_func_node.operator_base_ = ops[i];
apply_data_transform(expected_kernel_key, place, &ins_map_temp, var_scope,
&op_func_node, &copy_op_to_insert);
&op_func_node, &copy_op_to_insert, use_local_scope);
for (auto& item : copy_op_to_insert) {
vec_func_list->push_back(item);
}
......@@ -631,16 +656,16 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
}
void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
std::map<int, std::list<int>>& var2min_rw_op,
std::map<int, std::list<int>>* var2min_rw_op,
int cur_op, int rw_var) {
// rw_var is an input or output of cur_op
// this function updates the var2min_rw_op map.
if (var2min_rw_op.find(rw_var) == var2min_rw_op.end())
var2min_rw_op[rw_var] = std::list<int>();
if (var2min_rw_op->find(rw_var) == var2min_rw_op->end())
(*var2min_rw_op)[rw_var] = std::list<int>();
for (auto dep_op : op2dependences.at(cur_op)) {
var2min_rw_op[rw_var].remove(dep_op);
(*var2min_rw_op)[rw_var].remove(dep_op);
}
var2min_rw_op[rw_var].push_back(cur_op);
(*var2min_rw_op)[rw_var].push_back(cur_op);
}
std::map<int, std::list<int>> get_downstream_map(
......@@ -702,7 +727,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (auto& item :
vec_instruction[op_idx].Inputs()) { // for all inputs(read only)
for (auto var : item.second) {
update_var_min_rw_op(op2dependences, var2min_rw_op, op_idx, var);
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
remove_duplicate.insert(var);
}
}
......@@ -713,7 +738,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
var2recent_write_op[var] = op_idx;
if (remove_duplicate.count(var) ==
0) { // var in input list and in output list, so remove it.
update_var_min_rw_op(op2dependences, var2min_rw_op, op_idx, var);
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
}
}
}
......
......@@ -51,7 +51,7 @@ namespace framework {
namespace interpreter {
using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
static constexpr char kFetchVarName[] = "fetch_vars";
static constexpr char kFetchVarName[] = "fetch";
class AsyncWorkQueue {
public:
......@@ -98,12 +98,13 @@ std::string get_memcpy_type(const platform::Place& src_place,
const platform::Place& dst_place);
void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope);
VariableScope* var_scope,
bool use_local_scope = true);
void build_op_func_list(const platform::Place& place,
const framework::BlockDesc& block,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope);
VariableScope* var_scope, bool use_local_scope = true);
std::map<int, std::list<int>> build_op_downstream_map(
const std::vector<Instruction>& vec_instruction);
......
......@@ -497,7 +497,14 @@ VariableScope::~VariableScope() {
}
}
const Scope* VariableScope::GetScope() const { return scope_; }
Scope* VariableScope::GetMutableScope() const { return scope_; }
Scope* VariableScope::GetMutableLocalScope() const { return local_scope_; }
void VariableScope::SetLocalScope(Scope* local_scope) {
VLOG(4) << "Set local scope: " << local_scope;
local_scope_ = local_scope;
}
Variable* VariableScope::FindVar(const std::string& name) const {
auto it = name2id_.find(name);
......@@ -554,8 +561,9 @@ Variable* VariableScope::Var(const std::string& name) const {
size_t VariableScope::VarSize() const { return var_list_.size(); }
void VariableScope::AddVar(const std::string& name,
framework::VarDesc* var_desc) { // NOLINT
auto v = scope_->Var(name);
framework::VarDesc* var_desc,
bool local_scope) { // NOLINT
auto v = local_scope ? local_scope_->Var(name) : scope_->Var(name);
if (nullptr == var_desc) {
v->GetMutable<LoDTensor>();
} else {
......@@ -606,9 +614,9 @@ VariableScopeListener::VariableScopeListener(VariableScope* var_scope) {
var_scope_ = var_scope;
}
void VariableScopeListener::onCreateVariable(const std::string& name) {
auto v = var_scope_->scope_->GetVar(name); // must exsit in outer_scope_
if (!var_scope_->HasVar(name)) { // may exist in variable scope.
void VariableScopeListener::onCreateVariable(const std::string& name,
Variable* v) {
if (!var_scope_->HasVar(name)) { // may exist in variable scope.
VLOG(4) << "Calling VariableScope::onCreateVariable with var_name: "
<< name;
var_scope_->name2id_[name] = var_scope_->VarSize();
......
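For context: the listener now receives the created Variable* directly instead of re-fetching it from the outer scope, which would fail for variables created only in the local scope. A sketch of the notification flow, with names taken from this diff:

// Scope::Var(name) creates the variable, then notifies every listener:
//   for (auto l : listeners_) l->onCreateVariable(name, ret);
// VariableScopeListener mirrors the new variable into the VariableScope
// bookkeeping (name2id_, var_list_, vec_meta_info_), so a variable created
// in the local scope is still reachable via VariableScope::FindVar.
Variable* tmp = local_scope->Var("tmp_0");  // triggers onCreateVariable("tmp_0", tmp)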
......@@ -155,7 +155,7 @@ class VariableScope;
class VariableScopeListener : public ScopeListener {
public:
explicit VariableScopeListener(VariableScope* var_scope_);
void onCreateVariable(const std::string& name) override;
void onCreateVariable(const std::string& name, Variable* v) override;
void onDeleteVariable(const std::string& name) override;
void onRenameVariable(const std::string& old_name,
const std::string& new_name) override;
......@@ -177,7 +177,11 @@ class VariableScope : public ScopeBase {
public:
explicit VariableScope(Scope* scope);
const Scope* GetScope() const;
Scope* GetMutableScope() const;
Scope* GetMutableLocalScope() const;
void SetLocalScope(Scope* local_scope);
Variable* FindVar(const std::string& name) const;
......@@ -199,7 +203,8 @@ class VariableScope : public ScopeBase {
size_t VarSize() const;
void AddVar(const std::string& name, VarDesc* var_desc);
void AddVar(const std::string& name, VarDesc* var_desc,
bool local_scope = false);
void AddVar(const std::string& name, const Variable& var);
......@@ -219,15 +224,21 @@ class VariableScope : public ScopeBase {
return vec_meta_info_;
}
const std::shared_ptr<VariableScopeListener>& Listener() const {
return listener_;
}
friend class VariableScopeListener;
private:
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
Scope* scope_ = nullptr;
Scope* scope_{nullptr};
// TODO(zhiqiu): find a better way to support local scope.
Scope* local_scope_{nullptr};
// mutable RWLock vars_lock_;
std::shared_ptr<VariableScopeListener> listener_;
std::shared_ptr<VariableScopeListener> listener_{nullptr};
};
class NextInstruction {
......
......@@ -24,30 +24,36 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
startup_prog_(startup_prog),
main_prog_(main_prog),
global_scope_(VariableScope(scope)) {
// init scope
BuildVariableScope(startup_prog, &global_scope_);
if (scope != nullptr) {
auto name_list = scope->LocalVarNames();
for (auto name : name_list) {
auto v = scope->Var(name);
if (!global_scope_.HasVar(name)) {
global_scope_.AddVar(name, *v);
// NOTE(zhiqiu): for startup_program, initialize the scope and run it once;
// if startup_program is empty, the scope is initialized during the first run
if (startup_prog.Block(0).AllOps().size() > 0) {
VLOG(4) << "Run startup program";
// init scope
BuildVariableScope(startup_prog, &global_scope_);
if (scope != nullptr) {
auto name_list = scope->LocalVarNames();
for (auto name : name_list) {
auto v = scope->Var(name);
if (!global_scope_.HasVar(name)) {
global_scope_.AddVar(name, *v);
}
}
}
}
// run startup program
std::vector<paddle::framework::OpFuncNode> vec_func_list;
paddle::framework::interpreter::build_op_func_list(
place_, startup_prog.Block(0), &vec_func_list, &global_scope_);
std::vector<paddle::framework::OpFuncNode> vec_func_list;
// No need to set use_local_scope for startup_program; its variables are
// persistable
paddle::framework::interpreter::build_op_func_list(
place_, startup_prog.Block(0), &vec_func_list, &global_scope_, false);
}
}
paddle::framework::FetchList StandaloneExecutor::Run(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors,
const std::vector<std::string>& fetch_names) {
auto core = GetInterpreterCore(feed_names, fetch_names);
auto core = GetInterpreterCore(feed_names, fetch_names, true);
return core->Run(feed_names, feed_tensors);
}
......@@ -55,15 +61,15 @@ paddle::framework::FetchList StandaloneExecutor::Run(
paddle::framework::FetchList StandaloneExecutor::Run(
const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names) {
auto core = GetInterpreterCore(feed_names, fetch_names);
auto core = GetInterpreterCore(feed_names, fetch_names, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
return core->Run();
}
framework::interpreter::CostInfo StandaloneExecutor::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
auto core = GetInterpreterCore(feed_names, {});
auto core = GetInterpreterCore(feed_names, {}, true);
return core->DryRun(feed_names, feed_tensors);
}
......@@ -85,7 +91,7 @@ void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names) {
const std::vector<std::string>& fetch_names, bool add_fetch_op) {
std::ostringstream oss;
oss << "feed:";
for (auto& feedname : feed_names) {
......@@ -100,15 +106,22 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
if (iter == interpretercores_.end()) {
VLOG(3) << "create interpreter_core for " << oss.str();
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a
// new program.
auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog_);
auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block);
auto core =
std::make_shared<InterpreterCore>(place_, *block, &global_scope_);
programs_.emplace(oss.str(), new_prog);
VLOG(3) << "add fetch op: " << add_fetch_op;
std::shared_ptr<InterpreterCore> core = nullptr;
if (add_fetch_op) {
// NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy a
// new program.
auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog_);
auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block);
core = std::make_shared<InterpreterCore>(place_, *block, &global_scope_);
core->SetCopyProgram(new_prog);
} else {
core = std::make_shared<InterpreterCore>(place_, main_prog_.Block(0),
&global_scope_);
}
interpretercores_.emplace(oss.str(), core);
return core;
} else {
......
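A hedged usage sketch (surrounding setup assumed) showing how the two Run overloads map onto the new add_fetch_op flag:

paddle::framework::StandaloneExecutor exe(place, startup_prog, main_prog, scope);
// Feed-tensor Run: fetch ops are appended (add_fetch_op = true), so the core
// keeps a copied program alive via SetCopyProgram to keep block_ valid.
auto r1 = exe.Run({"x"}, {x_tensor}, {"out"});
// Name-only Run: the original main_prog block is reused (add_fetch_op = false)
// and no fetch ops are inserted; fetched vars stay in the local scope.
auto r2 = exe.Run({"x"}, {"out"});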
......@@ -61,14 +61,13 @@ class StandaloneExecutor : public ExecutorBase {
std::shared_ptr<InterpreterCore> GetInterpreterCore(
const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names);
const std::vector<std::string>& fetch_names, bool add_fetch_op);
const platform::Place& place_;
const ProgramDesc& startup_prog_;
const ProgramDesc& main_prog_;
VariableScope global_scope_;
std::unordered_map<std::string, std::shared_ptr<ProgramDesc>> programs_;
std::unordered_map<std::string, std::shared_ptr<InterpreterCore>>
interpretercores_;
};
......
......@@ -67,7 +67,7 @@ Variable* Scope::Var(const std::string& name) {
ret = VarInternal(name);
}
for (auto l : listeners_) {
l->onCreateVariable(name);
l->onCreateVariable(name, ret);
}
return ret;
}
......@@ -85,7 +85,7 @@ Variable* Scope::Var(std::string* name) {
ret = VarInternal(new_name);
}
for (auto l : listeners_) {
l->onCreateVariable(new_name);
l->onCreateVariable(new_name, ret);
}
return ret;
}
......
......@@ -58,7 +58,7 @@ class ScopeListener {
// in original Scope.
public:
virtual ~ScopeListener() {}
virtual void onCreateVariable(const std::string& name) {}
virtual void onCreateVariable(const std::string& name, Variable* v) {}
virtual void onDeleteVariable(const std::string& name) {}
virtual void onRenameVariable(const std::string& old_name,
const std::string& new_name) {}
......
......@@ -112,13 +112,6 @@ class FetchV2Op : public framework::OperatorWithKernel {
}
};
class FetchV2InferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
class FetchV2Kernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
......@@ -211,7 +204,6 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
fetch_v2, ops::FetchV2Op, ops::FetchV2OpProtoMaker,
ops::FetchV2InferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
......@@ -288,7 +288,10 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
return feed_count > 0
def has_fetch_operators(block, fetch_targets, fetch_holder_name):
def has_fetch_operators(block,
fetch_targets,
fetch_holder_name,
fetch_op='fetch'):
""" Check whether the block already has fetch operators.
Return false if the block does not have any fetch operators.
......@@ -303,6 +306,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
fetch_holder_name: the name of the variable that holds the data of
all fetch targets. The type of this fetch_holder variable is
FETCH_LIST, which is essentially vector<LoDTensor>.
fetch_op: the name of the operator used to fetch, e.g. 'fetch' or 'fetch_v2'
Return:
A boolean value that indicates whether a block has fetch operators
......@@ -311,7 +315,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
fetch_count = 0
for op in block.ops:
if op.desc.type() == 'fetch':
if op.desc.type() == fetch_op:
fetch_count += 1
assert op.desc.output('Out')[0] == fetch_holder_name
fetch_target_name = op.desc.input('X')[0]
......@@ -740,7 +744,7 @@ class Executor(object):
fetch_list,
feed_var_name,
fetch_var_name,
skip_fetch=False):
use_fetch_v2=False):
tmp_program = program.clone()
global_block = tmp_program.global_block()
......@@ -775,17 +779,21 @@ class Executor(object):
warnings.warn(
"The variable %s is not found in program. It is not declared or is pruned."
% name)
if skip_fetch:
return tmp_program
if use_fetch_v2:
fetch_op = 'fetch_v2'
else:
fetch_op = 'fetch'
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
if not has_fetch_operators(global_block, fetch_list, fetch_var_name,
fetch_op):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(
var, six.string_types), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
type=fetch_op,
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
......@@ -1345,7 +1353,13 @@ class Executor(object):
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name,
skip_fetch=True)
use_fetch_v2=True)
# NOTE(zhiqiu): Construct standalone_executor first, so
# the scope is bound to the variable_scope of standalone_executor
new_exe = self._executor_cache._get_exe_from_cache(program,
scope)
self._feed_data(program, feed, feed_var_name, scope)
if hasattr(program, 'lr_sheduler'):
from paddle.optimizer.lr import LRScheduler
......@@ -1360,9 +1374,7 @@ class Executor(object):
lr_sheduler._var_name)
tensor.set(data, self.place)
return self._executor_cache.run(program, scope,
list(feed.keys()), fetch_list,
return_numpy)
return new_exe.run(list(feed.keys()), fetch_list, return_numpy)
# use_prune can be overrided by putting optimize_ops in fetch_list
_origin_fetch_list = fetch_list
......
......@@ -309,7 +309,7 @@ class TestException(unittest.TestCase):
feed[1]['data'][0] = np.nan
self.assertRaises(RuntimeError, self.run_new_executor, feed)
def test_scope(self):
def test_scope_find_temp_var(self):
feed = [{
'id': np.array([1, 2, 3, 4, 5]).astype(np.int64),
'data': np.array([1, 2, 3]).astype(np.float32),
......@@ -318,7 +318,7 @@ class TestException(unittest.TestCase):
'data': np.array([2, 2, 2]).astype(np.float32),
}]
self.run_new_executor(feed)
self.assertIsNotNone(paddle.static.global_scope().find_var(
self.assertIsNone(paddle.static.global_scope().find_var(
self.fetch_vars.name))
......