未验证 提交 9f74363f 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] remove variable scope, stage 1 (#43865)

* separate variable scope and scope

* hot fix for lod_tensor_blocking_queue

* fix bug that variable exists in global scope
上级 d1ac85e5
...@@ -30,9 +30,6 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, ...@@ -30,9 +30,6 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
bool is_transferred = false; bool is_transferred = false;
auto* src_var_name = &var_name; auto* src_var_name = &var_name;
Scope* local_scope = use_local_scope ? var_scope_->GetMutableLocalScope()
: var_scope_->GetMutableScope();
// 1. layout transform // 1. layout transform
if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) { if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
auto op = TransferLayout(*src_var_name, auto op = TransferLayout(*src_var_name,
...@@ -40,7 +37,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, ...@@ -40,7 +37,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
kernel_type_for_var.data_layout_, kernel_type_for_var.data_layout_,
expected_kernel_key.data_layout_, expected_kernel_key.data_layout_,
var_scope_, var_scope_,
local_scope, scope_,
is_fetch_v2); is_fetch_v2);
if (op) { if (op) {
RunAndConstructOpFuncNode( RunAndConstructOpFuncNode(
...@@ -57,7 +54,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, ...@@ -57,7 +54,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
kernel_type_for_var.data_type_, kernel_type_for_var.data_type_,
expected_kernel_key.data_type_, expected_kernel_key.data_type_,
var_scope_, var_scope_,
local_scope); scope_);
if (op) { if (op) {
RunAndConstructOpFuncNode( RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes); op, *src_var_name, *new_var_name, op_func_nodes);
...@@ -71,12 +68,8 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var, ...@@ -71,12 +68,8 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
auto src_place = kernel_type_for_var.place_; auto src_place = kernel_type_for_var.place_;
auto dst_place = expected_kernel_key.place_; auto dst_place = expected_kernel_key.place_;
auto op = TransferDevice(*src_var_name, auto op = TransferDevice(
new_var_name, *src_var_name, new_var_name, src_place, dst_place, var_scope_, scope_);
src_place,
dst_place,
var_scope_,
local_scope);
if (op) { if (op) {
RunAndConstructOpFuncNode( RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes); op, *src_var_name, *new_var_name, op_func_nodes);
...@@ -114,8 +107,8 @@ void DataTranferHelper::RunAndConstructOpFuncNode( ...@@ -114,8 +107,8 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
// 1. Construct RuntimeContext // 1. Construct RuntimeContext
RuntimeContext runtime_context({}, {}); RuntimeContext runtime_context({}, {});
runtime_context.inputs["X"] = {var_scope_->Var(var_name)}; runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
runtime_context.outputs["Out"] = {var_scope_->Var(new_var_name)}; runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
// 2. Execute infer shape and choose kernel // 2. Execute infer shape and choose kernel
...@@ -188,19 +181,19 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name, ...@@ -188,19 +181,19 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
std::to_string(static_cast<int>(out_layout)); std::to_string(static_cast<int>(out_layout));
if (var_scope->HasVar(*new_var_name) && if (var_scope->HasVar(*new_var_name) &&
IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) { IsTensorOfVarInitialized(local_scope->FindVar(*new_var_name))) {
// already has same var // already has same var
VLOG(4) << "Use cached variable: " << *new_var_name; VLOG(4) << "Use cached variable: " << *new_var_name;
return nullptr; return nullptr;
} }
auto* ptr = local_scope->Var(*new_var_name); auto* ptr = local_scope->Var(*new_var_name);
auto var_type = var_scope->Var(var_name)->Type(); auto var_type = local_scope->FindVar(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type)); InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
VLOG(3) << "Create Variable " << *new_var_name VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type " << " locally, which pointer is " << ptr << "Variable Type "
<< var_type; << var_type;
var_scope->SetVarDesc(*new_var_name, nullptr); var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap // 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}}; VariableNameMap in_name_map = {{"X", {var_name}}};
...@@ -227,27 +220,27 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name, ...@@ -227,27 +220,27 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
std::string* new_var_name, std::string* new_var_name,
proto::VarType::Type in_dtype, proto::VarType::Type in_dtype,
proto::VarType::Type out_dtype, proto::VarType::Type out_dtype,
VariableScope* var_scope, framework::VariableScope* var_scope,
framework::Scope* local_scope) { framework::Scope* local_scope) {
// 1. Generate new_var_name and Initialize it // 1. Generate new_var_name and Initialize it
*new_var_name = var_name + "_dtype_" + *new_var_name = var_name + "_dtype_" +
std::to_string(static_cast<int>(in_dtype)) + "_" + std::to_string(static_cast<int>(in_dtype)) + "_" +
std::to_string(static_cast<int>(out_dtype)); std::to_string(static_cast<int>(out_dtype));
if (var_scope->HasVar(*new_var_name) && if (var_scope->HasVar(*new_var_name) &&
IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) { IsTensorOfVarInitialized(local_scope->FindVar(*new_var_name))) {
// already has same var // already has same var
VLOG(4) << "Use cached variable: " << *new_var_name; VLOG(4) << "Use cached variable: " << *new_var_name;
return nullptr; return nullptr;
} }
auto* ptr = local_scope->Var(*new_var_name); auto* ptr = local_scope->Var(*new_var_name);
auto var_type = var_scope->Var(var_name)->Type(); auto var_type = local_scope->FindVar(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type)); InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
VLOG(3) << "Create Variable " << *new_var_name VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type " << " locally, which pointer is " << ptr << "Variable Type "
<< var_type; << var_type;
var_scope->SetVarDesc(*new_var_name, nullptr); var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap // 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}}; VariableNameMap in_name_map = {{"X", {var_name}}};
...@@ -283,20 +276,20 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name, ...@@ -283,20 +276,20 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
*new_var_name = var_name + "_device_" + src_place.DebugString() + "_" + *new_var_name = var_name + "_device_" + src_place.DebugString() + "_" +
dst_place.DebugString(); dst_place.DebugString();
if (var_scope->HasVar(*new_var_name) && if (local_scope->FindVar(*new_var_name) &&
IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) { IsTensorOfVarInitialized(local_scope->FindVar(*new_var_name))) {
// already has same var // already has same var
VLOG(4) << "Use cached variable: " << *new_var_name; VLOG(4) << "Use cached variable: " << *new_var_name;
return nullptr; return nullptr;
} }
auto* ptr = local_scope->Var(*new_var_name); auto* ptr = local_scope->Var(*new_var_name);
auto var_type = var_scope->Var(var_name)->Type(); auto var_type = local_scope->FindVar(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type)); InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
VLOG(3) << "Create Variable " << *new_var_name VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type " << " locally, which pointer is " << ptr << "Variable Type "
<< var_type; << var_type;
var_scope->SetVarDesc(*new_var_name, nullptr); var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap // 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}}; VariableNameMap in_name_map = {{"X", {var_name}}};
...@@ -350,6 +343,9 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -350,6 +343,9 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
OpFuncNode* op_func_node, OpFuncNode* op_func_node,
std::vector<OpFuncNode>* new_op_func_nodes, std::vector<OpFuncNode>* new_op_func_nodes,
bool use_local_scope) { bool use_local_scope) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
auto op_base = op_func_node->operator_base_.get(); auto op_base = op_func_node->operator_base_.get();
PADDLE_ENFORCE_NOT_NULL(op_base, PADDLE_ENFORCE_NOT_NULL(op_base,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
...@@ -372,7 +368,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -372,7 +368,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
} }
bool transfered = false; bool transfered = false;
DataTranferHelper data_transfer_helper(place, var_scope); DataTranferHelper data_transfer_helper(place, var_scope, local_scope);
for (auto& var_name_item : *ins_map_temp) { for (auto& var_name_item : *ins_map_temp) {
bool should_skip_input = bool should_skip_input =
no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0; no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0;
...@@ -414,9 +410,6 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -414,9 +410,6 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
"but kNHWC layout" "but kNHWC layout"
<< var_name_item.first << " in Operator " << var_name_item.first << " in Operator "
<< op_base->Type(); << op_base->Type();
Scope* local_scope = use_local_scope
? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
auto op = TransferLayout(var_name, auto op = TransferLayout(var_name,
&new_var_name, &new_var_name,
tensor_in->layout(), tensor_in->layout(),
...@@ -458,7 +451,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -458,7 +451,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
// update RuntimeContext.inputs and original op_func_node inputs // update RuntimeContext.inputs and original op_func_node inputs
op_func_node->input_index[var_name_item.first][i] = op_func_node->input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name); var_scope->VarId(new_var_name);
var_name_item.second[i] = var_scope->Var(new_var_name); var_name_item.second[i] = local_scope->FindVar(new_var_name);
new_ins[var_name_item.first][i] = new_var_name; new_ins[var_name_item.first][i] = new_var_name;
for (auto& pair : new_outs) { for (auto& pair : new_outs) {
for (size_t j = 0; j < pair.second.size(); ++j) { for (size_t j = 0; j < pair.second.size(); ++j) {
...@@ -467,7 +460,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -467,7 +460,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
VLOG(4) << "Found inplace between input(" << var_name_item.first VLOG(4) << "Found inplace between input(" << var_name_item.first
<< ") and output(" << pair.first << ") and output(" << pair.first
<< "), the variable name is " << var_name; << "), the variable name is " << var_name;
(*outs_map_temp)[pair.first][j] = var_scope->Var(new_var_name); (*outs_map_temp)[pair.first][j] =
local_scope->FindVar(new_var_name);
new_outs[pair.first][j] = new_var_name; new_outs[pair.first][j] = new_var_name;
op_func_node op_func_node
->inplace_back_map[var_scope->GetIdByName(new_var_name)] = ->inplace_back_map[var_scope->GetIdByName(new_var_name)] =
...@@ -508,7 +502,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, ...@@ -508,7 +502,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
VariableScope* var_scope, VariableScope* var_scope,
std::vector<OpFuncNode>* op_func_nodes, std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope) { framework::Scope* local_scope) {
DataTranferHelper data_transfer_helper(place, var_scope); DataTranferHelper data_transfer_helper(place, var_scope, local_scope);
for (auto& var_name_item : out_names) { for (auto& var_name_item : out_names) {
std::vector<Variable*>& vars = out_vars->at(var_name_item.first); std::vector<Variable*>& vars = out_vars->at(var_name_item.first);
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
...@@ -548,7 +542,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, ...@@ -548,7 +542,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
} }
// 2. find forward var & check whether need to cast // 2. find forward var & check whether need to cast
auto* var = var_scope->FindVar(orig_var_name); auto* var = local_scope->FindVar(orig_var_name);
// if forward var not exists, do nothing // if forward var not exists, do nothing
if (var == nullptr) { if (var == nullptr) {
VLOG(3) << "skip " << orig_var_name << " with not found in var_scope"; VLOG(3) << "skip " << orig_var_name << " with not found in var_scope";
......
...@@ -29,8 +29,10 @@ namespace interpreter { ...@@ -29,8 +29,10 @@ namespace interpreter {
*/ */
class DataTranferHelper { class DataTranferHelper {
public: public:
DataTranferHelper(const platform::Place& place, VariableScope* var_scope) DataTranferHelper(const platform::Place& place,
: place_(place), var_scope_(var_scope) {} VariableScope* var_scope,
Scope* local_scope)
: place_(place), var_scope_(var_scope), scope_(local_scope) {}
bool apply(const OpKernelType& kernel_type_for_var, bool apply(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key, const OpKernelType& expected_kernel_key,
...@@ -52,6 +54,7 @@ class DataTranferHelper { ...@@ -52,6 +54,7 @@ class DataTranferHelper {
private: private:
platform::Place place_; platform::Place place_;
VariableScope* var_scope_; VariableScope* var_scope_;
Scope* scope_;
}; };
void ApplyDataTransform(const OpKernelType& expected_kernel_key, void ApplyDataTransform(const OpKernelType& expected_kernel_key,
......
...@@ -60,11 +60,11 @@ bool IsInterpretercoreFastGCEnabled() { ...@@ -60,11 +60,11 @@ bool IsInterpretercoreFastGCEnabled() {
InterpreterCore::InterpreterCore(const platform::Place& place, InterpreterCore::InterpreterCore(const platform::Place& place,
const BlockDesc& block, const BlockDesc& block,
const std::set<std::string>& skip_gc_vars, const std::set<std::string>& skip_gc_vars,
VariableScope* global_scope) framework::Scope* scope)
: place_(place), : place_(place),
block_(block), block_(block),
skip_gc_vars_(skip_gc_vars), skip_gc_vars_(skip_gc_vars),
global_scope_(global_scope), var_scope_(scope),
stream_analyzer_(place) { stream_analyzer_(place) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place_; VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
...@@ -84,12 +84,12 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -84,12 +84,12 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion); completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
create_local_scope_ = FLAGS_new_executor_use_local_scope; create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (FLAGS_new_executor_use_local_scope) { VLOG(4) << "create_local_scope_ is " << create_local_scope_;
auto local_scope = &global_scope->GetMutableScope()->NewScope();
local_scope->AddListener(global_scope->Listener()); if (create_local_scope_) {
auto local_scope = &var_scope_.GetMutableScope()->NewScope();
local_scope_ = local_scope; local_scope_ = local_scope;
} }
VLOG(4) << "create_local_scope_ is " << create_local_scope_;
// prune // prune
...@@ -115,7 +115,7 @@ InterpreterCore::~InterpreterCore() { ...@@ -115,7 +115,7 @@ InterpreterCore::~InterpreterCore() {
interpreter::CostInfo InterpreterCore::DryRun( interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
global_scope_->SetLocalScope(local_scope_); var_scope_.SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, true); Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info; interpreter::CostInfo cost_info;
{ {
...@@ -144,13 +144,10 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -144,13 +144,10 @@ paddle::framework::FetchList InterpreterCore::Run(
platform::AttachPointerHashToMKLDNNKey(this, place_); platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif #endif
bool is_build = is_build_; bool is_build = is_build_;
global_scope_->SetLocalScope(local_scope_); var_scope_.SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, is_build); Prepare(feed_names, feed_tensors, is_build);
if (is_build) { if (is_build) {
// add listener before run and is_build=true
global_scope_->ResetListener();
// For the program that only run once, it is no need to // For the program that only run once, it is no need to
// create work_queue, so the async_work_queue_ is created // create work_queue, so the async_work_queue_ is created
// until the second step run. // until the second step run.
...@@ -162,12 +159,13 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -162,12 +159,13 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope(); ClearLoDTensorArrayInLocalScope();
} }
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>()); if (fetch_var) {
return std::move(*fetch_var->GetMutable<framework::FetchList>());
} else {
return {};
}
} }
paddle::framework::FetchList InterpreterCore::Run( paddle::framework::FetchList InterpreterCore::Run(
...@@ -176,26 +174,15 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -176,26 +174,15 @@ paddle::framework::FetchList InterpreterCore::Run(
platform::AttachPointerHashToMKLDNNKey(this, place_); platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif #endif
if (!is_build_) { if (!is_build_) {
if (create_local_scope_ && var_scope_.SetLocalScope(local_scope_);
global_scope_->GetMutableLocalScope() != paddle::framework::interpreter::build_variable_scope(block_, &var_scope_);
global_scope_->GetMutableScope() &&
global_scope_->GetMutableLocalScope()) {
VLOG(4) << "Clear previous local scope before run";
VLOG(4) << global_scope_->GetMutableScope() << " "
<< global_scope_->GetMutableLocalScope();
platform::DeviceContextPool::Instance().Get(place_)->Wait();
// TODO(zhiqiu): clear the tensor holder of all vars in previous local
// scope?
}
global_scope_->SetLocalScope(local_scope_);
paddle::framework::interpreter::build_variable_scope(
block_, global_scope_, create_local_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(place_, paddle::framework::interpreter::build_op_func_list(place_,
block_, block_,
skip_gc_vars_, skip_gc_vars_,
&op_func_nodes, &op_func_nodes,
global_scope_, &var_scope_,
create_local_scope_); create_local_scope_);
is_build_ = true; is_build_ = true;
SetFeedVarsInplaceSkip(feed_names); SetFeedVarsInplaceSkip(feed_names);
...@@ -203,9 +190,6 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -203,9 +190,6 @@ paddle::framework::FetchList InterpreterCore::Run(
Convert(&op_func_nodes); Convert(&op_func_nodes);
} else { } else {
// add listener before run and is_build=true
global_scope_->ResetListener();
// For the program that only run once, it is no need to // For the program that only run once, it is no need to
// create work_queue, so the async_work_queue_ is created // create work_queue, so the async_work_queue_ is created
// until the second step run. // until the second step run.
...@@ -218,12 +202,13 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -218,12 +202,13 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope(); ClearLoDTensorArrayInLocalScope();
} }
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>()); if (fetch_var) {
return std::move(*fetch_var->GetMutable<framework::FetchList>());
} else {
return {};
}
} }
void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) { void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
...@@ -237,14 +222,14 @@ void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) { ...@@ -237,14 +222,14 @@ void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
} }
bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) { bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
if (!global_scope_->VarDesc(var_index)) { if (!var_scope_.VarDesc(var_index)) {
return input_var2op_info_.at(var_index).size() == 1; return input_var2op_info_.at(var_index).size() == 1;
} else { } else {
int is_input_cnt = 0; int is_input_cnt = 0;
for (auto inst_id : input_var2op_info_.at(var_index)) { for (auto inst_id : input_var2op_info_.at(var_index)) {
OpInOutInfo info; OpInOutInfo info;
info.Build(vec_instruction_.at(inst_id).OpBase()); info.Build(vec_instruction_.at(inst_id).OpBase());
if (info.IsInArgBufferNeeded(global_scope_->VarDesc(var_index)->Name())) { if (info.IsInArgBufferNeeded(var_scope_.VarDesc(var_index)->Name())) {
is_input_cnt++; is_input_cnt++;
} }
} }
...@@ -267,7 +252,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { ...@@ -267,7 +252,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
input_vars.reserve(var_name_item.second.size()); input_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) { for (auto& id : var_name_item.second) {
input_vars.emplace_back(global_scope_->Var(id)); input_vars.emplace_back(
local_scope_->FindVar(var_scope_.GetNameById(id)));
} }
ins_map.emplace(var_name_item.first, std::move(input_vars)); ins_map.emplace(var_name_item.first, std::move(input_vars));
} }
...@@ -278,7 +264,7 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { ...@@ -278,7 +264,7 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
out_vars.reserve(var_name_item.second.size()); out_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) { for (auto& id : var_name_item.second) {
out_vars.emplace_back(global_scope_->Var(id)); out_vars.emplace_back(local_scope_->FindVar(var_scope_.GetNameById(id)));
} }
outs_map.emplace(var_name_item.first, std::move(out_vars)); outs_map.emplace(var_name_item.first, std::move(out_vars));
} }
...@@ -286,9 +272,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) { ...@@ -286,9 +272,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
// set runtime_ctx and infershape_ctx_ // set runtime_ctx and infershape_ctx_
if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in
// kernel // kernel
Scope* local_scope = create_local_scope_ Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
? global_scope_->GetMutableLocalScope() : var_scope_.GetMutableScope();
: global_scope_->GetMutableScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope); instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else { } else {
instr_node->ResetContext(ins_map, outs_map); instr_node->ResetContext(ins_map, outs_map);
...@@ -311,25 +296,26 @@ void InterpreterCore::BuildInplace() { ...@@ -311,25 +296,26 @@ void InterpreterCore::BuildInplace() {
for (auto& pair : in_to_outs) { for (auto& pair : in_to_outs) {
auto iter = inputs.find(pair.first); auto iter = inputs.find(pair.first);
if (iter != inputs.end() && !iter->second.empty()) { if (iter != inputs.end() && !iter->second.empty()) {
auto in_var_desc = global_scope_->VarDesc(iter->second[0]); auto in_var_desc = var_scope_.VarDesc(iter->second[0]);
if (in_var_desc && in_var_desc->Persistable()) { if (in_var_desc && in_var_desc->Persistable()) {
continue; continue;
} }
if (global_scope_->GetVarSikpInplace(iter->second[0])) { if (var_scope_.GetVarSikpInplace(iter->second[0])) {
continue; continue;
} }
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = outputs.find(pair.second); auto iterout = outputs.find(pair.second);
if (iterout != outputs.end() && !iterout->second.empty()) { if (iterout != outputs.end() && !iterout->second.empty()) {
auto invar = global_scope_->Var(iter->second[0]); auto invar =
auto outvar = global_scope_->Var(iterout->second[0]); local_scope_->FindVar(var_scope_.GetNameById(iter->second[0]));
auto outvar = local_scope_->FindVar(
var_scope_.GetNameById(iterout->second[0]));
if (invar && outvar && invar->IsType<LoDTensor>() && if (invar && outvar && invar->IsType<LoDTensor>() &&
outvar->IsType<LoDTensor>()) { outvar->IsType<LoDTensor>()) {
instr.AddInplace(invar, outvar); instr.AddInplace(invar, outvar);
VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type() VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
<< " " << global_scope_->GetNameById(iter->second[0]) << " " << var_scope_.GetNameById(iter->second[0])
<< " -> " << " -> " << var_scope_.GetNameById(iterout->second[0])
<< global_scope_->GetNameById(iterout->second[0])
<< std::endl; << std::endl;
} }
} }
...@@ -372,8 +358,8 @@ void InterpreterCore::ClearLoDTensorArrayInLocalScope() { ...@@ -372,8 +358,8 @@ void InterpreterCore::ClearLoDTensorArrayInLocalScope() {
void InterpreterCore::Convert( void InterpreterCore::Convert(
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) { std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
auto& vec_meta_info = global_scope_->MutableVecMetaInfo(); auto& vec_meta_info = var_scope_.MutableVecMetaInfo();
auto var_nums = global_scope_->VarSize(); auto var_nums = var_scope_.VarSize();
input_var2op_info_.resize(var_nums); input_var2op_info_.resize(var_nums);
auto nodes = *op_func_nodes; auto nodes = *op_func_nodes;
...@@ -403,7 +389,7 @@ void InterpreterCore::Convert( ...@@ -403,7 +389,7 @@ void InterpreterCore::Convert(
if (!info.IsBuilt()) { if (!info.IsBuilt()) {
info.Build(instr.OpBase()); info.Build(instr.OpBase());
} }
auto* var_desc = global_scope_->VarDesc(id); auto* var_desc = var_scope_.VarDesc(id);
if (var_desc) { if (var_desc) {
if (info.IsInArgBufferNeeded(var_desc->Name())) { if (info.IsInArgBufferNeeded(var_desc->Name())) {
gc_check_inputs.insert(id); gc_check_inputs.insert(id);
...@@ -415,14 +401,14 @@ void InterpreterCore::Convert( ...@@ -415,14 +401,14 @@ void InterpreterCore::Convert(
} }
for (auto var_id : gc_check_inputs) { for (auto var_id : gc_check_inputs) {
paddle::framework::Variable* var = global_scope_->Var(var_id); paddle::framework::Variable* var =
local_scope_->FindVar(var_scope_.GetNameById(var_id));
if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() || if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>()) { var->IsType<LoDTensorArray>()) {
last_live_ops_[var_id].insert(op_idx); last_live_ops_[var_id].insert(op_idx);
} else { } else {
VLOG(4) << "not clear " << global_scope_->GetNameById(var_id) VLOG(4) << "not clear " << var_scope_.GetNameById(var_id) << " after "
<< " after " << instr.OpBase()->Type() << instr.OpBase()->Type() << " because its type is "
<< " because its type is "
<< framework::ToTypeName(var->Type()); << framework::ToTypeName(var->Type());
} }
} }
...@@ -441,7 +427,7 @@ void InterpreterCore::Convert( ...@@ -441,7 +427,7 @@ void InterpreterCore::Convert(
// clear the last_live_ops list for all vars in skip_gc_vars // clear the last_live_ops list for all vars in skip_gc_vars
for (const std::string& skip_gc_var : skip_gc_vars_) { for (const std::string& skip_gc_var : skip_gc_vars_) {
int var_id = global_scope_->GetIdByName(skip_gc_var); int var_id = var_scope_.GetIdByName(skip_gc_var);
if (var_id != -1) { if (var_id != -1) {
last_live_ops_[var_id].clear(); last_live_ops_[var_id].clear();
VLOG(8) << "Skip gc for var: " << skip_gc_var; VLOG(8) << "Skip gc for var: " << skip_gc_var;
...@@ -470,7 +456,7 @@ void InterpreterCore::Convert( ...@@ -470,7 +456,7 @@ void InterpreterCore::Convert(
} }
if (not_before_any) { if (not_before_any) {
VLOG(8) << "last live op of var " << i << " " VLOG(8) << "last live op of var " << i << " "
<< global_scope_->GetNameById(i) << " : " << item << " " << var_scope_.GetNameById(i) << " : " << item << " "
<< vec_instruction_[item].OpBase()->Type(); << vec_instruction_[item].OpBase()->Type();
minumum_last_live_ops.insert(item); minumum_last_live_ops.insert(item);
vec_instruction_[item].AddGCCheckVar(i); vec_instruction_[item].AddGCCheckVar(i);
...@@ -513,7 +499,7 @@ void InterpreterCore::Convert( ...@@ -513,7 +499,7 @@ void InterpreterCore::Convert(
std::promise<std::unique_ptr<AtomicVectorSizeT>>(); std::promise<std::unique_ptr<AtomicVectorSizeT>>();
atomic_var_ref_ = var_ref_promise.get_future(); atomic_var_ref_ = var_ref_promise.get_future();
var_ref_promise.set_value( var_ref_promise.set_value(
interpreter::PrepareAtomicVarRef(global_scope_->VecMetaInfo())); interpreter::PrepareAtomicVarRef(var_scope_.VecMetaInfo()));
} }
void InterpreterCore::BuildSkipShareLoDInfo() { void InterpreterCore::BuildSkipShareLoDInfo() {
...@@ -539,10 +525,9 @@ void InterpreterCore::BuildSkipShareLoDInfo() { ...@@ -539,10 +525,9 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
void InterpreterCore::RunInstruction(const Instruction& instr_node) { void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto* op = instr_node.OpBase(); auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace(); auto place = instr_node.DeviceContext().GetPlace();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(global_scope_); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
Scope* local_scope = create_local_scope_ Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
? global_scope_->GetMutableLocalScope() : var_scope_.GetMutableScope();
: global_scope_->GetMutableScope();
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op); auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{ {
// If it is OperatorBase, InferShape do nothing. // If it is OperatorBase, InferShape do nothing.
...@@ -607,7 +592,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -607,7 +592,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
} }
} }
VLOG(4) << "End run " << place << " " << op->DebugStringEx(global_scope_); VLOG(4) << "End run " << place << " " << op->DebugStringEx(local_scope_);
if (!instr_node.InplaceBackMap().empty()) { if (!instr_node.InplaceBackMap().empty()) {
platform::RecordEvent inplaceback_event( platform::RecordEvent inplaceback_event(
...@@ -616,13 +601,13 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -616,13 +601,13 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
// NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
for (auto& p : m) { for (auto& p : m) {
auto* transformed_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar( auto* transformed_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
global_scope_->Var(p.first)); var_scope_.VarRef(p.first));
auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar( auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
global_scope_->Var(p.second)); var_scope_.VarRef(p.second));
original_tensor->ShareDataWith(*transformed_tensor); original_tensor->ShareDataWith(*transformed_tensor);
VLOG(4) << "Transfer inplace variable back form " VLOG(4) << "Transfer inplace variable back form "
<< global_scope_->GetNameById(p.first) << " to " << var_scope_.GetNameById(p.first) << " to "
<< global_scope_->GetNameById(p.second); << var_scope_.GetNameById(p.second);
} }
} }
...@@ -641,7 +626,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -641,7 +626,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG(4) << "Check nan/inf"; VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf( framework::details::CheckOpHasNanOrInf(
*op, *op,
*global_scope_, *local_scope_,
place); // TODO(xiongkun03) change it to inner scope. place); // TODO(xiongkun03) change it to inner scope.
} }
} }
...@@ -663,7 +648,7 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -663,7 +648,7 @@ void InterpreterCore::ExecuteInstructionList(
atomic_deps_ = async_work_queue_->PrepareAtomicDeps(dependecy_count_); atomic_deps_ = async_work_queue_->PrepareAtomicDeps(dependecy_count_);
atomic_var_ref_ = atomic_var_ref_ =
async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo()); async_work_queue_->PrepareAtomicVarRef(var_scope_.VecMetaInfo());
record_prepare.End(); record_prepare.End();
exception_holder_.Clear(); exception_holder_.Clear();
...@@ -898,16 +883,16 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) { ...@@ -898,16 +883,16 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
* supported later. * supported later.
*/ */
for (int var_id : instr.GCCheckVars()) { for (int var_id : instr.GCCheckVars()) {
VLOG(4) << "GC sync " << global_scope_->GetNameById(var_id) << " " VLOG(4) << "GC sync " << var_scope_.GetNameById(var_id) << " "
<< global_scope_->VarDesc(var_id); << var_scope_.VarDesc(var_id);
// persistable var will be ignore while GC // persistable var will be ignore while GC
if (global_scope_->VarDesc(var_id) && if (var_scope_.VarDesc(var_id) &&
global_scope_->VarDesc(var_id)->Persistable()) { var_scope_.VarDesc(var_id)->Persistable()) {
continue; continue;
} }
paddle::framework::Variable* var = global_scope_->Var(var_id); paddle::framework::Variable* var = var_scope_.VarRef(var_id);
if (var == nullptr) { if (var == nullptr) {
continue; continue;
} }
...@@ -943,10 +928,10 @@ void InterpreterCore::CheckGC( ...@@ -943,10 +928,10 @@ void InterpreterCore::CheckGC(
platform::RecordEvent record( platform::RecordEvent record(
"CheckGC", platform::TracerEventType::UserDefined, 10); "CheckGC", platform::TracerEventType::UserDefined, 10);
size_t instr_id = instr.Id(); size_t instr_id = instr.Id();
auto& var_scope = *global_scope_; auto& var_scope = var_scope_;
for (auto var_id : instr.GCCheckVars()) { for (auto var_id : instr.GCCheckVars()) {
VLOG(4) << "GC " << global_scope_->GetNameById(var_id) << " " VLOG(4) << "GC " << var_scope_.GetNameById(var_id) << " "
<< var_scope.VarDesc(var_id); << var_scope.VarDesc(var_id);
VLOG(4) << "atomic:" << atomic_var_ref << " " << &(*atomic_var_ref)[var_id] VLOG(4) << "atomic:" << atomic_var_ref << " " << &(*atomic_var_ref)[var_id]
<< " " << var_id; << " " << var_id;
...@@ -962,17 +947,17 @@ void InterpreterCore::CheckGC( ...@@ -962,17 +947,17 @@ void InterpreterCore::CheckGC(
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (IsInterpretercoreFastGCEnabled()) { if (IsInterpretercoreFastGCEnabled()) {
static_cast<InterpreterCoreFastGarbageCollector*>(gc_.get())->Add( static_cast<InterpreterCoreFastGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id)); var_scope_.VarRef(var_id));
} else { } else {
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add( static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id), var_scope_.VarRef(var_id),
&gc_event_.at(instr_id), &gc_event_.at(instr_id),
&instr.DeviceContext()); &instr.DeviceContext());
} }
#else #else
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add( static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id), var_scope_.VarRef(var_id),
&gc_event_.at(instr_id), &gc_event_.at(instr_id),
&instr.DeviceContext()); &instr.DeviceContext());
#endif #endif
...@@ -995,7 +980,7 @@ void InterpreterCore::Prepare( ...@@ -995,7 +980,7 @@ void InterpreterCore::Prepare(
auto FeedInput = [&] { auto FeedInput = [&] {
VLOG(4) << "Feed inputs"; VLOG(4) << "Feed inputs";
for (size_t i = 0; i < feed_names.size(); ++i) { for (size_t i = 0; i < feed_names.size(); ++i) {
auto* feed_var = global_scope_->FindVar(feed_names[i]); auto* feed_var = local_scope_->FindVar(feed_names[i]);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
feed_var, feed_var,
platform::errors::NotFound("Variable %s should not be nullptr.", platform::errors::NotFound("Variable %s should not be nullptr.",
...@@ -1009,14 +994,14 @@ void InterpreterCore::Prepare( ...@@ -1009,14 +994,14 @@ void InterpreterCore::Prepare(
if (!is_build_) { if (!is_build_) {
paddle::framework::interpreter::build_variable_scope( paddle::framework::interpreter::build_variable_scope(
block_, global_scope_, create_local_scope_); block_, &var_scope_, create_local_scope_);
FeedInput(); FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(place_, paddle::framework::interpreter::build_op_func_list(place_,
block_, block_,
skip_gc_vars_, skip_gc_vars_,
&op_func_nodes, &op_func_nodes,
global_scope_, &var_scope_,
create_local_scope_); create_local_scope_);
is_build_ = true; is_build_ = true;
SetFeedVarsInplaceSkip(feed_names); SetFeedVarsInplaceSkip(feed_names);
...@@ -1034,14 +1019,14 @@ void InterpreterCore::Prepare( ...@@ -1034,14 +1019,14 @@ void InterpreterCore::Prepare(
void InterpreterCore::SetFeedVarsInplaceSkip( void InterpreterCore::SetFeedVarsInplaceSkip(
const std::vector<std::string>& feed_names) { const std::vector<std::string>& feed_names) {
for (auto& feed_name : feed_names) { for (auto& feed_name : feed_names) {
global_scope_->SetVarSikpInplace(feed_name, true); var_scope_.SetVarSikpInplace(feed_name, true);
} }
} }
std::shared_ptr<InterpreterCore> CreateInterpreterCore( std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place, const platform::Place& place,
const ProgramDesc& prog, const ProgramDesc& prog,
VariableScope* global_scope, Scope* scope,
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_names,
const std::set<std::string>& skip_gc_vars) { const std::set<std::string>& skip_gc_vars) {
std::shared_ptr<InterpreterCore> core = nullptr; std::shared_ptr<InterpreterCore> core = nullptr;
...@@ -1051,8 +1036,7 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCore( ...@@ -1051,8 +1036,7 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCore(
auto* block = new_prog->MutableBlock(0); auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block); interpreter::add_fetch(fetch_names, block);
core = std::make_shared<InterpreterCore>( core = std::make_shared<InterpreterCore>(place, *block, skip_gc_vars, scope);
place, *block, skip_gc_vars, global_scope);
core->SetCopyProgram(new_prog); core->SetCopyProgram(new_prog);
return core; return core;
} }
......
...@@ -40,7 +40,7 @@ class InterpreterCore { ...@@ -40,7 +40,7 @@ class InterpreterCore {
InterpreterCore(const platform::Place& place, InterpreterCore(const platform::Place& place,
const BlockDesc& block, const BlockDesc& block,
const std::set<std::string>& skip_gc_vars, const std::set<std::string>& skip_gc_vars,
VariableScope* global_scope); Scope* scope);
~InterpreterCore(); ~InterpreterCore();
...@@ -112,7 +112,10 @@ class InterpreterCore { ...@@ -112,7 +112,10 @@ class InterpreterCore {
// new program is deleted. // new program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr}; std::shared_ptr<ProgramDesc> copy_program_{nullptr};
VariableScope* global_scope_; // not owned // from variable scope
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
...@@ -125,6 +128,10 @@ class InterpreterCore { ...@@ -125,6 +128,10 @@ class InterpreterCore {
std::atomic<size_t> unfinished_op_numer_{0}; std::atomic<size_t> unfinished_op_numer_{0};
std::vector<std::vector<size_t>> input_var2op_info_; std::vector<std::vector<size_t>> input_var2op_info_;
VariableScope var_scope_;
bool create_local_scope_{true};
Scope* local_scope_{nullptr}; // not owned
StreamAnalyzer stream_analyzer_; StreamAnalyzer stream_analyzer_;
EventsWaiter main_thread_blocker_; EventsWaiter main_thread_blocker_;
std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_; std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
...@@ -134,8 +141,6 @@ class InterpreterCore { ...@@ -134,8 +141,6 @@ class InterpreterCore {
std::unique_ptr<InterpreterCoreGarbageCollector> gc_; std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_; std::vector<paddle::platform::DeviceEvent> gc_event_;
bool create_local_scope_{true};
Scope* local_scope_{nullptr}; // not owned
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_; std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_; std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
...@@ -144,7 +149,7 @@ class InterpreterCore { ...@@ -144,7 +149,7 @@ class InterpreterCore {
std::shared_ptr<InterpreterCore> CreateInterpreterCore( std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place, const platform::Place& place,
const ProgramDesc& prog, const ProgramDesc& prog,
VariableScope* global_scope, Scope* global_scope,
const std::vector<std::string>& fetch_names = {}, const std::vector<std::string>& fetch_names = {},
const std::set<std::string>& skip_gc_vars = {}); const std::set<std::string>& skip_gc_vars = {});
......
...@@ -50,6 +50,7 @@ namespace interpreter { ...@@ -50,6 +50,7 @@ namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>; using VariableIdMap = std::map<std::string, std::vector<int>>;
constexpr size_t kPrepareWorkQueueIdx = 2; constexpr size_t kPrepareWorkQueueIdx = 2;
const char blocking_queue_prefix[] = "lod_tensor_blocking_queue";
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions( const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) { size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
...@@ -225,6 +226,7 @@ void build_variable_scope(const framework::BlockDesc& block, ...@@ -225,6 +226,7 @@ void build_variable_scope(const framework::BlockDesc& block,
if (var_name == framework::kEmptyVarName) { if (var_name == framework::kEmptyVarName) {
continue; continue;
} }
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
auto* ptr = inner_scope->Var(var_name); auto* ptr = inner_scope->Var(var_name);
...@@ -241,7 +243,7 @@ void build_variable_scope(const framework::BlockDesc& block, ...@@ -241,7 +243,7 @@ void build_variable_scope(const framework::BlockDesc& block,
<< ptr << "Variable Type " << ptr << "Variable Type "
<< static_cast<int>(var_desc->GetType()); << static_cast<int>(var_desc->GetType());
} }
var_scope->SetVarDesc(var_name, var_desc); var_scope->AddVar(var_name, var_desc);
} }
} }
...@@ -279,6 +281,7 @@ void create_all_ops(const framework::BlockDesc& block, ...@@ -279,6 +281,7 @@ void create_all_ops(const framework::BlockDesc& block,
std::tuple<VariableValueMap, VariableIdMap> build_variable_map( std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
const VariableNameMap& var_name_map, const VariableNameMap& var_name_map,
VariableScope* var_scope, VariableScope* var_scope,
Scope* local_scope,
bool enforce_exist = true) { bool enforce_exist = true) {
VariableValueMap name2var; VariableValueMap name2var;
VariableIdMap name2id; VariableIdMap name2id;
...@@ -288,14 +291,22 @@ std::tuple<VariableValueMap, VariableIdMap> build_variable_map( ...@@ -288,14 +291,22 @@ std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
vars.reserve(item.second.size()); vars.reserve(item.second.size());
for (auto& var_name : item.second) { for (auto& var_name : item.second) {
if (!enforce_exist && !var_scope->HasVar(var_name)) { if (!var_scope->HasVar(var_name)) {
// skip the non-exist variable: such as recurrent_grad // Hot fix for variables used in dataloader, like
VLOG(4) << var_name << " don't exist in variable scope, skip it!"; // 'lod_tensor_blocking_queue_0' These variables may be created in
continue; // scope, and it is not existed as variable in program.
if (var_name.find(blocking_queue_prefix) != std::string::npos &&
local_scope->FindVar(var_name)) {
var_scope->AddVar(var_name, nullptr);
} else if (!enforce_exist) {
// skip the non-exist variable: such as recurrent_grad
VLOG(4) << var_name << " don't exist in variable scope, skip it!";
continue;
}
} }
auto* var = local_scope->FindVar(var_name);
auto var_id = var_scope->VarId(var_name); auto var_id = var_scope->VarId(var_name);
auto* in_var = var_scope->Var(var_id); vars.push_back(var);
vars.push_back(in_var);
ids.push_back(var_id); ids.push_back(var_id);
} }
name2var[item.first] = std::move(vars); name2var[item.first] = std::move(vars);
...@@ -421,12 +432,12 @@ void build_op_func_list(const platform::Place& place, ...@@ -421,12 +432,12 @@ void build_op_func_list(const platform::Place& place,
enforce_exist = false; enforce_exist = false;
} }
std::tie(ins_map, ins_name2id) = std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope, enforce_exist); build_variable_map(inputs_names, var_scope, local_scope, enforce_exist);
VariableValueMap outs_map; VariableValueMap outs_map;
VariableIdMap outs_name2id; VariableIdMap outs_name2id;
std::tie(outs_map, outs_name2id) = std::tie(outs_map, outs_name2id) = build_variable_map(
build_variable_map(outputs_names, var_scope, enforce_exist); outputs_names, var_scope, local_scope, enforce_exist);
// step 1: build OpFuncNode // step 1: build OpFuncNode
OpFuncNode op_func_node; OpFuncNode op_func_node;
...@@ -573,9 +584,9 @@ void build_op_func_list(const platform::Place& place, ...@@ -573,9 +584,9 @@ void build_op_func_list(const platform::Place& place,
for (auto& p : m) { for (auto& p : m) {
auto* transformed_tensor = auto* transformed_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar( GetMutableLoDTensorOrSelectedRowsValueFromVar(
var_scope->Var(p.first)); local_scope->FindVar(var_scope->GetNameById(p.first)));
auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar( auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
var_scope->Var(p.second)); local_scope->FindVar(var_scope->GetNameById(p.second)));
original_tensor->ShareDataWith(*transformed_tensor); original_tensor->ShareDataWith(*transformed_tensor);
VLOG(4) << "Transfer inplace variable back form " VLOG(4) << "Transfer inplace variable back form "
<< var_scope->GetNameById(p.first) << " to " << var_scope->GetNameById(p.first) << " to "
...@@ -600,7 +611,7 @@ void build_op_func_list(const platform::Place& place, ...@@ -600,7 +611,7 @@ void build_op_func_list(const platform::Place& place,
new std::deque<std::shared_ptr<memory::Allocation>>(); new std::deque<std::shared_ptr<memory::Allocation>>();
for (auto& var_name : delete_vars) { for (auto& var_name : delete_vars) {
auto* var = var_scope->FindVar(var_name); auto* var = local_scope->FindVar(var_name);
if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) { if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) {
continue; continue;
} }
......
...@@ -86,7 +86,7 @@ void build_op_func_list(const platform::Place& place, ...@@ -86,7 +86,7 @@ void build_op_func_list(const platform::Place& place,
const framework::BlockDesc& block, const framework::BlockDesc& block,
const std::set<std::string>& skip_gc_vars, const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list, std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope, VariableScope* scope,
bool use_local_scope = true); bool use_local_scope = true);
std::map<int, std::list<int>> build_op_downstream_map( std::map<int, std::list<int>> build_op_downstream_map(
......
...@@ -560,8 +560,8 @@ const std::vector<Variable*>& InterpretercoreInferShapeContext::OutputVars( ...@@ -560,8 +560,8 @@ const std::vector<Variable*>& InterpretercoreInferShapeContext::OutputVars(
VariableScope::VariableScope(Scope* scope) { VariableScope::VariableScope(Scope* scope) {
// for @EMPTY@ variable // for @EMPTY@ variable
var_list_.push_back(nullptr);
name2id_[kEmptyVarName] = kEmptyVarIndex; name2id_[kEmptyVarName] = kEmptyVarIndex;
var_list_.push_back(nullptr);
vec_meta_info_.emplace_back(0, nullptr); vec_meta_info_.emplace_back(0, nullptr);
scope_ = scope; scope_ = scope;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
...@@ -569,15 +569,9 @@ VariableScope::VariableScope(Scope* scope) { ...@@ -569,15 +569,9 @@ VariableScope::VariableScope(Scope* scope) {
nullptr, nullptr,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"You have passed a nullptr to construct VariableScope.")); "You have passed a nullptr to construct VariableScope."));
listener_ = std::make_shared<VariableScopeListener>(this);
scope->AddListener(listener_);
} }
VariableScope::~VariableScope() { VariableScope::~VariableScope() {}
if (scope_ && listener_) {
scope_->DelListener(listener_);
}
}
Scope* VariableScope::GetMutableScope() const { return scope_; } Scope* VariableScope::GetMutableScope() const { return scope_; }
...@@ -588,22 +582,6 @@ void VariableScope::SetLocalScope(Scope* local_scope) { ...@@ -588,22 +582,6 @@ void VariableScope::SetLocalScope(Scope* local_scope) {
local_scope_ = local_scope; local_scope_ = local_scope;
} }
Variable* VariableScope::FindVar(const std::string& name) const {
auto it = name2id_.find(name);
if (it != name2id_.end()) {
PADDLE_ENFORCE_LT(it->second,
var_list_.size(),
platform::errors::NotFound(
"The id(%d) of variable(%s) should not be larger "
"than the size of variable list(%d).",
it->second,
name,
var_list_.size()));
return var_list_[it->second];
}
return nullptr;
}
// Get variable id by name, return -1 if not found // Get variable id by name, return -1 if not found
int VariableScope::GetIdByName(const std::string& name) const { int VariableScope::GetIdByName(const std::string& name) const {
auto it = name2id_.find(name); auto it = name2id_.find(name);
...@@ -638,34 +616,23 @@ int VariableScope::VarId(const std::string& name) const { ...@@ -638,34 +616,23 @@ int VariableScope::VarId(const std::string& name) const {
return name2id_.at(name); return name2id_.at(name);
} }
Variable* VariableScope::Var(int id) const { return var_list_.at(id); } Variable* VariableScope::VarRef(int id) const { return var_list_[id]; }
Variable* VariableScope::Var(const std::string& name) const {
return var_list_.at(VarId(name));
}
size_t VariableScope::VarSize() const { return var_list_.size(); } size_t VariableScope::VarSize() const { return name2id_.size(); }
void VariableScope::AddVar(const std::string& name, void VariableScope::AddVar(const std::string& name,
framework::VarDesc* var_desc, framework::VarDesc* var_desc) {
bool local_scope) { // NOLINT if (!HasVar(name)) {
auto v = local_scope ? local_scope_->Var(name) : scope_->Var(name); auto id = VarSize();
if (nullptr == var_desc) { name2id_[name] = id;
v->GetMutable<LoDTensor>(); vec_meta_info_.emplace_back(0, var_desc);
} else { var_list_.push_back(local_scope_->FindVar(name));
InitializeVariable( PADDLE_ENFORCE_EQ(
v, var_list_.size(),
var_desc name2id_.size(),
->GetType()); // Scope don't initialize variable recently created platform::errors::InvalidArgument(
"The size of var_list and name2id map should be equal"));
} }
SetVarDesc(name, var_desc);
}
void VariableScope::AddVar(const std::string& name,
const Variable& var) { // NOLINT
// Though name existed in outer_scope_, we need
// add again to create name2id map.
scope_->Var(name);
} }
void VariableScope::SetVarDesc(const std::string& name, void VariableScope::SetVarDesc(const std::string& name,
...@@ -696,10 +663,10 @@ bool VariableScope::GetVarSikpInplace(int id) const { ...@@ -696,10 +663,10 @@ bool VariableScope::GetVarSikpInplace(int id) const {
void VariableScope::CheckExist(int id) const { void VariableScope::CheckExist(int id) const {
PADDLE_ENFORCE_LT(id, PADDLE_ENFORCE_LT(id,
var_list_.size(), name2id_.size(),
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Required var_id < %d, but received var_id = %d.", "Required var_id < %d, but received var_id = %d.",
var_list_.size(), name2id_.size(),
id)); id));
} }
...@@ -710,55 +677,6 @@ void VariableScope::CheckExist(const std::string& name) const { ...@@ -710,55 +677,6 @@ void VariableScope::CheckExist(const std::string& name) const {
platform::errors::NotFound("%s not in VariableScope.", name)); platform::errors::NotFound("%s not in VariableScope.", name));
} }
void VariableScope::ClearListener() {
if (scope_ && listener_ && scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << scope_;
scope_->DelListener(listener_);
}
if (local_scope_ && listener_ && local_scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << local_scope_;
local_scope_->DelListener(listener_);
}
}
void VariableScope::ResetListener() {
if (scope_ && listener_ && !scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << scope_;
scope_->AddListener(listener_);
}
if (local_scope_ && listener_ && !local_scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << local_scope_;
local_scope_->AddListener(listener_);
}
}
VariableScopeListener::VariableScopeListener(VariableScope* var_scope) {
var_scope_ = var_scope;
}
void VariableScopeListener::onCreateVariable(const std::string& name,
Variable* v) {
if (!var_scope_->HasVar(name)) { // may exist in variable scope.
VLOG(4) << "Calling VariableScope::onCreateVariable with var_name: "
<< name;
var_scope_->name2id_[name] = var_scope_->VarSize();
var_scope_->var_list_.emplace_back(v);
var_scope_->vec_meta_info_.emplace_back(0, nullptr);
}
}
void VariableScopeListener::onDeleteVariable(const std::string& name) {
if (var_scope_->HasVar(name)) {
VLOG(4) << "Calling VariableScope::onDeleteVariable with var_name: "
<< name;
}
}
void VariableScopeListener::onRenameVariable(const std::string& old_name,
const std::string& new_name) {}
void VariableScopeListener::onCreateScope(Scope* Scope) {}
void VariableScopeListener::onDeleteScope(Scope* Scope) {}
void VariableScopeListener::onClear() {}
Instruction::Instruction(size_t id, Instruction::Instruction(size_t id,
OpFuncNode&& op_func_node, OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx) const platform::DeviceContext& dev_ctx)
......
...@@ -168,29 +168,7 @@ struct VariableMetaInfo { ...@@ -168,29 +168,7 @@ struct VariableMetaInfo {
: var_ref_count_(var_ref_count), var_desc_(var_desc) {} : var_ref_count_(var_ref_count), var_desc_(var_desc) {}
}; };
class VariableScope; class VariableScope {
class VariableScopeListener : public ScopeListener {
public:
explicit VariableScopeListener(VariableScope* var_scope_);
void onCreateVariable(const std::string& name, Variable* v) override;
void onDeleteVariable(const std::string& name) override;
void onRenameVariable(const std::string& old_name,
const std::string& new_name) override;
void onCreateScope(Scope* Scope) override;
void onDeleteScope(Scope* Scope) override;
void onClear() override;
private:
VariableScope* var_scope_; // not owned
};
// TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?
// NOTE(xiongkun03): Use scope as a member of VariableScope, we don't need
// ScopeBase. Scope manager the variables and VariableScope is just a quick
// access machanism. ScopeListener is the callback to sync changes in Original
// Scope. We can make it a membership of VariableScope. Here we use inherent.
class VariableScope : public ScopeBase {
public: public:
explicit VariableScope(Scope* scope); explicit VariableScope(Scope* scope);
...@@ -200,8 +178,6 @@ class VariableScope : public ScopeBase { ...@@ -200,8 +178,6 @@ class VariableScope : public ScopeBase {
void SetLocalScope(Scope* local_scope); void SetLocalScope(Scope* local_scope);
Variable* FindVar(const std::string& name) const;
~VariableScope(); ~VariableScope();
// Get variable id by name, return -1 if not found // Get variable id by name, return -1 if not found
...@@ -214,17 +190,11 @@ class VariableScope : public ScopeBase { ...@@ -214,17 +190,11 @@ class VariableScope : public ScopeBase {
int VarId(const std::string& name) const; int VarId(const std::string& name) const;
Variable* Var(int id) const;
Variable* Var(const std::string& name) const;
size_t VarSize() const; size_t VarSize() const;
void AddVar(const std::string& name, void AddVar(const std::string& name, VarDesc* var_desc);
VarDesc* var_desc,
bool local_scope = false);
void AddVar(const std::string& name, const Variable& var); Variable* VarRef(int id) const;
void SetVarDesc(const std::string& name, framework::VarDesc* var_desc); void SetVarDesc(const std::string& name, framework::VarDesc* var_desc);
...@@ -242,29 +212,22 @@ class VariableScope : public ScopeBase { ...@@ -242,29 +212,22 @@ class VariableScope : public ScopeBase {
return vec_meta_info_; return vec_meta_info_;
} }
const std::shared_ptr<VariableScopeListener>& Listener() const {
return listener_;
}
void SetVarSikpInplace(const std::string& name, bool skip); void SetVarSikpInplace(const std::string& name, bool skip);
bool GetVarSikpInplace(int id) const; bool GetVarSikpInplace(int id) const;
void ClearListener();
void ResetListener();
friend class VariableScopeListener;
private: private:
// not owned, better remove it since all vars should be
// accessed by Scope instead of VariableScope
std::vector<Variable*> var_list_; std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_; std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_; std::vector<VariableMetaInfo> vec_meta_info_;
Scope* scope_{nullptr}; Scope* scope_{nullptr};
// TODO(zhiqiu): find a better way to support local scope. // TODO(zhiqiu): find a better way to support local scope.
Scope* local_scope_{nullptr}; Scope* local_scope_{nullptr};
// mutable RWLock vars_lock_; // mutable RWLock vars_lock_;
std::shared_ptr<VariableScopeListener> listener_{nullptr};
}; };
class NextInstruction { class NextInstruction {
......
...@@ -25,41 +25,12 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -25,41 +25,12 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
: place_(place), : place_(place),
startup_prog_(startup_prog), startup_prog_(startup_prog),
main_prog_(main_prog), main_prog_(main_prog),
global_scope_(VariableScope(scope)) { scope_(scope) {
// NOTE(zhiqiu): it is needed to sync the variables in scope to // NOTE(zhiqiu): for startup_program, run once ?
// variable_scope, since the some variable only exists in scope.
// For example, 'lod_tensor_blocking_queue_0' used in dataloader.
// These variables may be created in scope, and it is not existed as
// variable in program.
if (scope) {
const std::string blocking_queue_prefix = "lod_tensor_blocking_queue";
auto vars = scope->LocalVarNames();
for (const auto& name : vars) {
if (name.find(blocking_queue_prefix) != std::string::npos) {
if (!global_scope_.HasVar(name)) {
auto* v = scope->Var(name);
VLOG(4) << "Sync Variable from scope to variable scope: " << name;
global_scope_.AddVar(name, *v);
}
}
}
}
// NOTE(zhiqiu): for startup_program, initialize scope and run once
// if startup_program is empty, the scope is initialize during first run
if (startup_prog.Block(0).AllOps().size() > 0) { if (startup_prog.Block(0).AllOps().size() > 0) {
VLOG(4) << "Run startup program"; auto core = GetInterpreterCore(startup_prog, {}, {}, false);
// init scope VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
BuildVariableScope(startup_prog, &global_scope_); core->Run({});
std::vector<paddle::framework::OpFuncNode> vec_func_list;
// No need to use_local_scope for startup_program, its variables are
// persistable
paddle::framework::interpreter::build_op_func_list(place_,
startup_prog.Block(0),
{},
&vec_func_list,
&global_scope_,
false);
} }
} }
...@@ -70,7 +41,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( ...@@ -70,7 +41,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
platform::RecordEvent record_event( platform::RecordEvent record_event(
"StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1); "StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
auto core = GetInterpreterCore(feed_names, fetch_names, true); auto core = GetInterpreterCore(main_prog_, feed_names, fetch_names, true);
return core->Run(feed_names, feed_tensors); return core->Run(feed_names, feed_tensors);
} }
...@@ -81,7 +52,7 @@ paddle::framework::FetchList StandaloneExecutor::Run( ...@@ -81,7 +52,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
platform::RecordEvent record_event( platform::RecordEvent record_event(
"StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1); "StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
auto core = GetInterpreterCore(feed_names, fetch_names, false); auto core = GetInterpreterCore(main_prog_, feed_names, fetch_names, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core; VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
return core->Run(feed_names); return core->Run(feed_names);
} }
...@@ -89,28 +60,13 @@ paddle::framework::FetchList StandaloneExecutor::Run( ...@@ -89,28 +60,13 @@ paddle::framework::FetchList StandaloneExecutor::Run(
framework::interpreter::CostInfo StandaloneExecutor::DryRun( framework::interpreter::CostInfo StandaloneExecutor::DryRun(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
auto core = GetInterpreterCore(feed_names, {}, true); auto core = GetInterpreterCore(main_prog_, feed_names, {}, true);
return core->DryRun(feed_names, feed_tensors); return core->DryRun(feed_names, feed_tensors);
} }
void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (!var_scope->HasVar(var->Name())) {
VLOG(4) << "Create variable from startup_prog: "
<< var->Proto()->SerializeAsString();
var_scope->AddVar(var->Name(), var);
}
}
}
std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore( std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
const ProgramDesc& prog,
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_names,
bool add_fetch_op) { bool add_fetch_op) {
...@@ -133,14 +89,13 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore( ...@@ -133,14 +89,13 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
std::shared_ptr<InterpreterCore> core = nullptr; std::shared_ptr<InterpreterCore> core = nullptr;
if (add_fetch_op) { if (add_fetch_op) {
core = CreateInterpreterCore( core = CreateInterpreterCore(place_, prog, scope_, fetch_names);
place_, main_prog_, &global_scope_, fetch_names);
} else { } else {
core = std::make_shared<InterpreterCore>( core = std::make_shared<InterpreterCore>(
place_, place_,
main_prog_.Block(0), prog.Block(0),
/*skip_gc_vars=*/std::set<std::string>(), /*skip_gc_vars=*/std::set<std::string>(),
&global_scope_); scope_);
} }
interpretercores_.emplace(oss.str(), core); interpretercores_.emplace(oss.str(), core);
return core; return core;
......
...@@ -54,10 +54,8 @@ class StandaloneExecutor { ...@@ -54,10 +54,8 @@ class StandaloneExecutor {
const std::vector<framework::LoDTensor>& feed_tensors); const std::vector<framework::LoDTensor>& feed_tensors);
private: private:
void BuildVariableScope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope);
std::shared_ptr<InterpreterCore> GetInterpreterCore( std::shared_ptr<InterpreterCore> GetInterpreterCore(
const ProgramDesc& prog,
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_names,
bool add_fetch_op); bool add_fetch_op);
...@@ -65,7 +63,7 @@ class StandaloneExecutor { ...@@ -65,7 +63,7 @@ class StandaloneExecutor {
platform::Place place_; platform::Place place_;
const ProgramDesc& startup_prog_; const ProgramDesc& startup_prog_;
const ProgramDesc& main_prog_; const ProgramDesc& main_prog_;
VariableScope global_scope_; Scope* scope_; // not owned
std::unordered_map<std::string, std::shared_ptr<InterpreterCore>> std::unordered_map<std::string, std::shared_ptr<InterpreterCore>>
interpretercores_; interpretercores_;
......
...@@ -137,29 +137,29 @@ ProgramDesc GetLmMainProgram() { ...@@ -137,29 +137,29 @@ ProgramDesc GetLmMainProgram() {
return main_prog; return main_prog;
} }
TEST(StandaloneExecutor, run) { // TEST(StandaloneExecutor, run) {
auto place = platform::CUDAPlace(0); // auto place = platform::CUDAPlace(0);
ProgramDesc test_prog = load_from_file("lm_startup_program"); // ProgramDesc test_prog = load_from_file("lm_startup_program");
ProgramDesc main_prog = GetLmMainProgram(); // ProgramDesc main_prog = GetLmMainProgram();
Scope scope; // Scope scope;
StandaloneExecutor exec(place, test_prog, main_prog, &scope); // StandaloneExecutor exec(place, test_prog, main_prog, &scope);
exec.Run({}, {}, {}); // exec.Run({}, {}, {});
auto start = std::chrono::steady_clock::now(); // auto start = std::chrono::steady_clock::now();
for (size_t i = 0; i < 10; ++i) { // for (size_t i = 0; i < 10; ++i) {
if (i % 200 == 0) { // if (i % 200 == 0) {
std::cout << i << std::endl; // std::cout << i << std::endl;
} // }
exec.Run({}, {}, {}); // exec.Run({}, {}, {});
} // }
auto end = std::chrono::steady_clock::now(); // auto end = std::chrono::steady_clock::now();
std::chrono::duration<double> diff = end - start; // std::chrono::duration<double> diff = end - start;
std::cout << "time cost " << diff.count() << std::endl; // std::cout << "time cost " << diff.count() << std::endl;
} // }
TEST(InterpreterCore, skip_gc_vars) { TEST(InterpreterCore, skip_gc_vars) {
auto place = platform::CUDAPlace(0); auto place = platform::CUDAPlace(0);
...@@ -168,9 +168,8 @@ TEST(InterpreterCore, skip_gc_vars) { ...@@ -168,9 +168,8 @@ TEST(InterpreterCore, skip_gc_vars) {
Scope scope; Scope scope;
VariableScope startup_scope(&scope);
std::shared_ptr<InterpreterCore> startup_core = std::shared_ptr<InterpreterCore> startup_core =
CreateInterpreterCore(place, startup_prog, &startup_scope); CreateInterpreterCore(place, startup_prog, &scope);
startup_core->Run({}, {}); startup_core->Run({}, {});
std::set<std::string> skip_gc_vars = {"uniform_0.tmp_0", std::set<std::string> skip_gc_vars = {"uniform_0.tmp_0",
...@@ -183,26 +182,31 @@ TEST(InterpreterCore, skip_gc_vars) { ...@@ -183,26 +182,31 @@ TEST(InterpreterCore, skip_gc_vars) {
"split_0.tmp_0", "split_0.tmp_0",
"elementwise_add_0.tmp_0", "elementwise_add_0.tmp_0",
"tmp_0"}; "tmp_0"};
std::shared_ptr<InterpreterCore> main_core =
CreateInterpreterCore(place, main_prog, &scope, {}, skip_gc_vars);
auto check_gc_result = auto check_gc_result =
[](VariableScope& scope, std::set<std::string>& vars, bool is_skip_gc) { [](Scope& scope, std::set<std::string>& vars, bool is_skip_gc) {
// the first local scope is created in startup_core
// the second local scope is created in main_core
ASSERT_EQ(scope.kids().size(), 2UL);
auto* local_scope = scope.kids().back();
for (const std::string& var_name : vars) { for (const std::string& var_name : vars) {
ASSERT_EQ( ASSERT_EQ(local_scope->FindVar(var_name)
scope.FindVar(var_name)->GetMutable<LoDTensor>()->IsInitialized(), ->GetMutable<LoDTensor>()
is_skip_gc); ->IsInitialized(),
is_skip_gc);
} }
}; };
VariableScope main_scope(&scope);
std::shared_ptr<InterpreterCore> main_core =
CreateInterpreterCore(place, main_prog, &main_scope, {}, skip_gc_vars);
main_core->Run({}, {}); main_core->Run({}, {});
check_gc_result(main_scope, skip_gc_vars, true); check_gc_result(scope, skip_gc_vars, true);
check_gc_result(main_scope, gc_vars, false); check_gc_result(scope, gc_vars, false);
main_core->Run({}, {}); main_core->Run({}, {});
check_gc_result(main_scope, skip_gc_vars, true); check_gc_result(scope, skip_gc_vars, true);
check_gc_result(main_scope, gc_vars, false); check_gc_result(scope, gc_vars, false);
} }
void TestShareWorkQueue(const ProgramDesc& prog, void TestShareWorkQueue(const ProgramDesc& prog,
...@@ -213,11 +217,10 @@ void TestShareWorkQueue(const ProgramDesc& prog, ...@@ -213,11 +217,10 @@ void TestShareWorkQueue(const ProgramDesc& prog,
const platform::CPUPlace place = platform::CPUPlace(); const platform::CPUPlace place = platform::CPUPlace();
Scope scope; Scope scope;
VariableScope variable_scope(&scope);
std::shared_ptr<InterpreterCore> core1 = std::shared_ptr<InterpreterCore> core1 =
CreateInterpreterCore(place, prog, &variable_scope, fetch_names); CreateInterpreterCore(place, prog, &scope, fetch_names);
std::shared_ptr<InterpreterCore> core2 = std::shared_ptr<InterpreterCore> core2 =
CreateInterpreterCore(place, prog, &variable_scope, fetch_names); CreateInterpreterCore(place, prog, &scope, fetch_names);
core2->ShareWorkQueueFrom(core1); core2->ShareWorkQueueFrom(core1);
auto run_and_check = [&feed_names, &feed_tensors, &fetch_results]( auto run_and_check = [&feed_names, &feed_tensors, &fetch_results](
......
...@@ -67,9 +67,6 @@ Variable* Scope::Var(const std::string& name) { ...@@ -67,9 +67,6 @@ Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
ret = VarInternal(name); ret = VarInternal(name);
} }
for (auto l : listeners_) {
l->onCreateVariable(name, ret);
}
return ret; return ret;
} }
...@@ -85,9 +82,6 @@ Variable* Scope::Var(std::string* name) { ...@@ -85,9 +82,6 @@ Variable* Scope::Var(std::string* name) {
} }
ret = VarInternal(new_name); ret = VarInternal(new_name);
} }
for (auto l : listeners_) {
l->onCreateVariable(new_name, ret);
}
return ret; return ret;
} }
...@@ -124,9 +118,6 @@ void Scope::DropKids() { ...@@ -124,9 +118,6 @@ void Scope::DropKids() {
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
kids_.clear(); kids_.clear();
} }
for (auto l : listeners_) {
l->onClear();
}
} }
bool Scope::HasKid(const Scope* scope) const { bool Scope::HasKid(const Scope* scope) const {
...@@ -175,9 +166,6 @@ void Scope::DeleteScope(Scope* scope) const { ...@@ -175,9 +166,6 @@ void Scope::DeleteScope(Scope* scope) const {
Async([scope] { delete scope; }); Async([scope] { delete scope; });
} }
} }
for (auto l : listeners_) {
l->onDeleteScope(scope);
}
} }
void Scope::EraseVars(const std::vector<std::string>& var_names) { void Scope::EraseVars(const std::vector<std::string>& var_names) {
...@@ -192,11 +180,6 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) { ...@@ -192,11 +180,6 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
} }
} }
} }
for (auto l : listeners_) {
for (auto& var_name : var_names) {
l->onDeleteVariable(var_name);
}
}
} }
void Scope::Rename(const std::string& origin_name, void Scope::Rename(const std::string& origin_name,
...@@ -205,9 +188,6 @@ void Scope::Rename(const std::string& origin_name, ...@@ -205,9 +188,6 @@ void Scope::Rename(const std::string& origin_name,
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
RenameInternal(origin_name, new_name); RenameInternal(origin_name, new_name);
} }
for (auto l : listeners_) {
l->onRenameVariable(origin_name, new_name);
}
} }
std::string Scope::Rename(const std::string& origin_name) const { std::string Scope::Rename(const std::string& origin_name) const {
...@@ -216,9 +196,6 @@ std::string Scope::Rename(const std::string& origin_name) const { ...@@ -216,9 +196,6 @@ std::string Scope::Rename(const std::string& origin_name) const {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
RenameInternal(origin_name, new_name); RenameInternal(origin_name, new_name);
} }
for (auto l : listeners_) {
l->onRenameVariable(origin_name, new_name);
}
return new_name; return new_name;
} }
...@@ -282,22 +259,6 @@ Variable* Scope::FindVarLocally(const std::string& name) const { ...@@ -282,22 +259,6 @@ Variable* Scope::FindVarLocally(const std::string& name) const {
return nullptr; return nullptr;
} }
void Scope::AddListener(const std::shared_ptr<ScopeListener>& listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener);
if (it == listeners_.end()) {
listeners_.push_back(listener);
}
}
void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) {
listeners_.remove(listener);
}
bool Scope::HasListener(const std::shared_ptr<ScopeListener>& listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener);
return it != listeners_.end();
}
void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) { void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
for (auto iter = vars_.begin(); iter != vars_.end();) { for (auto iter = vars_.begin(); iter != vars_.end();) {
......
...@@ -51,22 +51,6 @@ class ScopeBase { ...@@ -51,22 +51,6 @@ class ScopeBase {
class Scope; class Scope;
class ScopeListener {
// NOTE(xiongkun03) Abstract Class, doesn't have any attributes.
// Used by VariableScope. If we modify the original scope, we
// need synchronize changes to VariableScope. So we add listerer
// in original Scope.
public:
virtual ~ScopeListener() {}
virtual void onCreateVariable(const std::string& name, Variable* v) {}
virtual void onDeleteVariable(const std::string& name) {}
virtual void onRenameVariable(const std::string& old_name,
const std::string& new_name) {}
virtual void onCreateScope(Scope* Scope) {}
virtual void onDeleteScope(Scope* Scope) {}
virtual void onClear() {}
};
/** /**
* @brief Scope that manage all variables. * @brief Scope that manage all variables.
* *
...@@ -150,12 +134,6 @@ class Scope : public ScopeBase { ...@@ -150,12 +134,6 @@ class Scope : public ScopeBase {
// Rename variable to a new name and return the new name // Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const; std::string Rename(const std::string& origin_name) const;
void AddListener(const std::shared_ptr<ScopeListener>& listener);
void DelListener(const std::shared_ptr<ScopeListener>& listener);
bool HasListener(const std::shared_ptr<ScopeListener>& listener);
protected: protected:
struct KeyHasher { struct KeyHasher {
std::size_t operator()(const std::string& key) const { std::size_t operator()(const std::string& key) const {
...@@ -192,7 +170,6 @@ class Scope : public ScopeBase { ...@@ -192,7 +170,6 @@ class Scope : public ScopeBase {
// Scope in `kids_` are owned by this class. // Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
const Scope* parent_{nullptr}; const Scope* parent_{nullptr};
std::list<std::shared_ptr<ScopeListener>> listeners_;
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册