Unverified commit 9f74363f authored by Leo Chen, committed by GitHub

[new-exec] remove variable scope, stage 1 (#43865)

* separate variable scope and scope

* hot fix for lod_tensor_blocking_queue

* fix bug that variable exists in global scope
Parent d1ac85e5
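All of the hunks below serve one refactor: Scope becomes the single owner of variables, while VariableScope shrinks to a name/id index with cached pointers, and the ScopeListener machinery is deleted. The following is a minimal standalone sketch of that relationship; the types and method bodies are simplified stand-ins for the real Paddle classes, not the actual implementation (the real VariableScope also reserves an @EMPTY@ slot and tracks VariableMetaInfo, omitted here).

#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Variable {};  // stand-in for paddle::framework::Variable

// Scope owns the variables and is the single source of truth.
class Scope {
 public:
  Variable* Var(const std::string& name) {
    auto& slot = vars_[name];
    if (!slot) slot = std::make_unique<Variable>();
    return slot.get();
  }
  Variable* FindVar(const std::string& name) const {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : it->second.get();
  }

 private:
  std::map<std::string, std::unique_ptr<Variable>> vars_;
};

// After this commit, VariableScope is only a quick-access index:
// a name->id map plus Variable* cached from Scope, with no listener.
class VariableScope {
 public:
  explicit VariableScope(Scope* scope) : scope_(scope) {}
  bool HasVar(const std::string& name) const {
    return name2id_.count(name) > 0;
  }
  void AddVar(const std::string& name) {
    if (HasVar(name)) return;
    const int id = static_cast<int>(var_list_.size());
    name2id_[name] = id;
    var_list_.push_back(scope_->FindVar(name));  // cached pointer, not owned
    assert(var_list_.size() == name2id_.size());
  }
  int VarId(const std::string& name) const { return name2id_.at(name); }
  Variable* VarRef(int id) const { return var_list_[id]; }

 private:
  Scope* scope_;  // not owned
  std::map<std::string, int> name2id_;
  std::vector<Variable*> var_list_;
};

int main() {
  Scope scope;
  scope.Var("x");
  VariableScope var_scope(&scope);
  var_scope.AddVar("x");
  // The index resolves to the very object the Scope owns.
  assert(var_scope.VarRef(var_scope.VarId("x")) == scope.FindVar("x"));
  return 0;
}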
......@@ -30,9 +30,6 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
bool is_transferred = false;
auto* src_var_name = &var_name;
Scope* local_scope = use_local_scope ? var_scope_->GetMutableLocalScope()
: var_scope_->GetMutableScope();
// 1. layout transform
if (need_layout_transform(kernel_type_for_var, expected_kernel_key)) {
auto op = TransferLayout(*src_var_name,
......@@ -40,7 +37,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
kernel_type_for_var.data_layout_,
expected_kernel_key.data_layout_,
var_scope_,
local_scope,
scope_,
is_fetch_v2);
if (op) {
RunAndConstructOpFuncNode(
......@@ -57,7 +54,7 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
kernel_type_for_var.data_type_,
expected_kernel_key.data_type_,
var_scope_,
local_scope);
scope_);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes);
......@@ -71,12 +68,8 @@ bool DataTranferHelper::apply(const OpKernelType& kernel_type_for_var,
auto src_place = kernel_type_for_var.place_;
auto dst_place = expected_kernel_key.place_;
auto op = TransferDevice(*src_var_name,
new_var_name,
src_place,
dst_place,
var_scope_,
local_scope);
auto op = TransferDevice(
*src_var_name, new_var_name, src_place, dst_place, var_scope_, scope_);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes);
......@@ -114,8 +107,8 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
// 1. Construct RuntimeContext
RuntimeContext runtime_context({}, {});
runtime_context.inputs["X"] = {var_scope_->Var(var_name)};
runtime_context.outputs["Out"] = {var_scope_->Var(new_var_name)};
runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
// 2. Execute infer shape and choose kernel
......@@ -188,19 +181,19 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
std::to_string(static_cast<int>(out_layout));
if (var_scope->HasVar(*new_var_name) &&
IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) {
IsTensorOfVarInitialized(local_scope->FindVar(*new_var_name))) {
// already has same var
VLOG(4) << "Use cached variable: " << *new_var_name;
return nullptr;
}
auto* ptr = local_scope->Var(*new_var_name);
auto var_type = var_scope->Var(var_name)->Type();
auto var_type = local_scope->FindVar(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type "
<< var_type;
var_scope->SetVarDesc(*new_var_name, nullptr);
var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}};
......@@ -227,27 +220,27 @@ std::shared_ptr<OperatorBase> TransferDtype(const std::string& var_name,
std::string* new_var_name,
proto::VarType::Type in_dtype,
proto::VarType::Type out_dtype,
VariableScope* var_scope,
framework::VariableScope* var_scope,
framework::Scope* local_scope) {
// 1. Generate new_var_name and Initialize it
*new_var_name = var_name + "_dtype_" +
std::to_string(static_cast<int>(in_dtype)) + "_" +
std::to_string(static_cast<int>(out_dtype));
if (var_scope->HasVar(*new_var_name) &&
IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) {
IsTensorOfVarInitialized(local_scope->FindVar(*new_var_name))) {
// already has same var
VLOG(4) << "Use cached variable: " << *new_var_name;
return nullptr;
}
auto* ptr = local_scope->Var(*new_var_name);
auto var_type = var_scope->Var(var_name)->Type();
auto var_type = local_scope->FindVar(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type "
<< var_type;
var_scope->SetVarDesc(*new_var_name, nullptr);
var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}};
......@@ -283,20 +276,20 @@ std::shared_ptr<OperatorBase> TransferDevice(const std::string& var_name,
*new_var_name = var_name + "_device_" + src_place.DebugString() + "_" +
dst_place.DebugString();
if (var_scope->HasVar(*new_var_name) &&
IsTensorOfVarInitialized(var_scope->Var(*new_var_name))) {
if (local_scope->FindVar(*new_var_name) &&
IsTensorOfVarInitialized(local_scope->FindVar(*new_var_name))) {
// already has same var
VLOG(4) << "Use cached variable: " << *new_var_name;
return nullptr;
}
auto* ptr = local_scope->Var(*new_var_name);
auto var_type = var_scope->Var(var_name)->Type();
auto var_type = local_scope->FindVar(var_name)->Type();
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var_type));
VLOG(3) << "Create Variable " << *new_var_name
<< " locally, which pointer is " << ptr << "Variable Type "
<< var_type;
var_scope->SetVarDesc(*new_var_name, nullptr);
var_scope->AddVar(*new_var_name, nullptr);
// 2. Construct VariableNameMap
VariableNameMap in_name_map = {{"X", {var_name}}};
......@@ -350,6 +343,9 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
OpFuncNode* op_func_node,
std::vector<OpFuncNode>* new_op_func_nodes,
bool use_local_scope) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
auto op_base = op_func_node->operator_base_.get();
PADDLE_ENFORCE_NOT_NULL(op_base,
platform::errors::PreconditionNotMet(
......@@ -372,7 +368,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
}
bool transfered = false;
DataTranferHelper data_transfer_helper(place, var_scope);
DataTranferHelper data_transfer_helper(place, var_scope, local_scope);
for (auto& var_name_item : *ins_map_temp) {
bool should_skip_input =
no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0;
......@@ -414,9 +410,6 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
"but kNHWC layout"
<< var_name_item.first << " in Operator "
<< op_base->Type();
Scope* local_scope = use_local_scope
? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
auto op = TransferLayout(var_name,
&new_var_name,
tensor_in->layout(),
......@@ -458,7 +451,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
// update RuntimeContext.inputs and original op_func_node inputs
op_func_node->input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name);
var_name_item.second[i] = var_scope->Var(new_var_name);
var_name_item.second[i] = local_scope->FindVar(new_var_name);
new_ins[var_name_item.first][i] = new_var_name;
for (auto& pair : new_outs) {
for (size_t j = 0; j < pair.second.size(); ++j) {
......@@ -467,7 +460,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
VLOG(4) << "Found inplace between input(" << var_name_item.first
<< ") and output(" << pair.first
<< "), the variable name is " << var_name;
(*outs_map_temp)[pair.first][j] = var_scope->Var(new_var_name);
(*outs_map_temp)[pair.first][j] =
local_scope->FindVar(new_var_name);
new_outs[pair.first][j] = new_var_name;
op_func_node
->inplace_back_map[var_scope->GetIdByName(new_var_name)] =
......@@ -508,7 +502,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
VariableScope* var_scope,
std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope) {
DataTranferHelper data_transfer_helper(place, var_scope);
DataTranferHelper data_transfer_helper(place, var_scope, local_scope);
for (auto& var_name_item : out_names) {
std::vector<Variable*>& vars = out_vars->at(var_name_item.first);
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
......@@ -548,7 +542,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
}
// 2. find forward var & check whether need to cast
auto* var = var_scope->FindVar(orig_var_name);
auto* var = local_scope->FindVar(orig_var_name);
// if forward var not exists, do nothing
if (var == nullptr) {
VLOG(3) << "skip " << orig_var_name << " with not found in var_scope";
......
......@@ -29,8 +29,10 @@ namespace interpreter {
*/
class DataTranferHelper {
public:
DataTranferHelper(const platform::Place& place, VariableScope* var_scope)
: place_(place), var_scope_(var_scope) {}
DataTranferHelper(const platform::Place& place,
VariableScope* var_scope,
Scope* local_scope)
: place_(place), var_scope_(var_scope), scope_(local_scope) {}
bool apply(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key,
......@@ -52,6 +54,7 @@ class DataTranferHelper {
private:
platform::Place place_;
VariableScope* var_scope_;
Scope* scope_;
};
void ApplyDataTransform(const OpKernelType& expected_kernel_key,
......
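As a usage note on the header change above: the helper no longer derives a scope from VariableScope inside apply(); the caller picks the scope once (local or global) and passes it in. Below is a condensed, hypothetical sketch of the creation pattern now shared by TransferLayout/TransferDtype/TransferDevice, reusing the Scope/VariableScope stand-ins from the first sketch; CreateTransferVar is an illustrative name, not a real Paddle function, and the class name keeps the source's spelling.

class DataTranferHelper {
 public:
  DataTranferHelper(VariableScope* var_scope, Scope* local_scope)
      : var_scope_(var_scope), scope_(local_scope) {}

  // Transfer vars are created in the concrete scope_ and merely indexed
  // in var_scope_, mirroring local_scope->Var() + var_scope->AddVar()
  // in the hunks above.
  Variable* CreateTransferVar(const std::string& new_var_name) {
    Variable* ptr = scope_->Var(new_var_name);
    var_scope_->AddVar(new_var_name);
    return ptr;
  }

 private:
  VariableScope* var_scope_;  // not owned
  Scope* scope_;              // not owned
};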
......@@ -60,11 +60,11 @@ bool IsInterpretercoreFastGCEnabled() {
InterpreterCore::InterpreterCore(const platform::Place& place,
const BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
VariableScope* global_scope)
framework::Scope* scope)
: place_(place),
block_(block),
skip_gc_vars_(skip_gc_vars),
global_scope_(global_scope),
var_scope_(scope),
stream_analyzer_(place) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
......@@ -84,12 +84,12 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (FLAGS_new_executor_use_local_scope) {
auto local_scope = &global_scope->GetMutableScope()->NewScope();
local_scope->AddListener(global_scope->Listener());
VLOG(4) << "create_local_scope_ is " << create_local_scope_;
if (create_local_scope_) {
auto local_scope = &var_scope_.GetMutableScope()->NewScope();
local_scope_ = local_scope;
}
VLOG(4) << "create_local_scope_ is " << create_local_scope_;
// prune
......@@ -115,7 +115,7 @@ InterpreterCore::~InterpreterCore() {
interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
global_scope_->SetLocalScope(local_scope_);
var_scope_.SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info;
{
......@@ -144,13 +144,10 @@ paddle::framework::FetchList InterpreterCore::Run(
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
bool is_build = is_build_;
global_scope_->SetLocalScope(local_scope_);
var_scope_.SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, is_build);
if (is_build) {
// add listener before run and is_build=true
global_scope_->ResetListener();
// For a program that only runs once, there is no need to
// create the work_queue, so the async_work_queue_ is not created
// until the second run.
......@@ -162,12 +159,13 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope();
}
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>());
auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
if (fetch_var) {
return std::move(*fetch_var->GetMutable<framework::FetchList>());
} else {
return {};
}
}
paddle::framework::FetchList InterpreterCore::Run(
......@@ -176,26 +174,15 @@ paddle::framework::FetchList InterpreterCore::Run(
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
if (!is_build_) {
if (create_local_scope_ &&
global_scope_->GetMutableLocalScope() !=
global_scope_->GetMutableScope() &&
global_scope_->GetMutableLocalScope()) {
VLOG(4) << "Clear previous local scope before run";
VLOG(4) << global_scope_->GetMutableScope() << " "
<< global_scope_->GetMutableLocalScope();
platform::DeviceContextPool::Instance().Get(place_)->Wait();
// TODO(zhiqiu): clear the tensor holder of all vars in previous local
// scope?
}
global_scope_->SetLocalScope(local_scope_);
paddle::framework::interpreter::build_variable_scope(
block_, global_scope_, create_local_scope_);
var_scope_.SetLocalScope(local_scope_);
paddle::framework::interpreter::build_variable_scope(block_, &var_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(place_,
block_,
skip_gc_vars_,
&op_func_nodes,
global_scope_,
&var_scope_,
create_local_scope_);
is_build_ = true;
SetFeedVarsInplaceSkip(feed_names);
......@@ -203,9 +190,6 @@ paddle::framework::FetchList InterpreterCore::Run(
Convert(&op_func_nodes);
} else {
// add listener before run and is_build=true
global_scope_->ResetListener();
// For a program that only runs once, there is no need to
// create the work_queue, so the async_work_queue_ is not created
// until the second run.
......@@ -218,12 +202,13 @@ paddle::framework::FetchList InterpreterCore::Run(
ClearLoDTensorArrayInLocalScope();
}
// clear the listener after run
global_scope_->ClearListener();
// return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>());
auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
if (fetch_var) {
return std::move(*fetch_var->GetMutable<framework::FetchList>());
} else {
return {};
}
}
void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
......@@ -237,14 +222,14 @@ void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
}
bool InterpreterCore::BuildInplaceCheckVarIsOnlyInput(size_t var_index) {
if (!global_scope_->VarDesc(var_index)) {
if (!var_scope_.VarDesc(var_index)) {
return input_var2op_info_.at(var_index).size() == 1;
} else {
int is_input_cnt = 0;
for (auto inst_id : input_var2op_info_.at(var_index)) {
OpInOutInfo info;
info.Build(vec_instruction_.at(inst_id).OpBase());
if (info.IsInArgBufferNeeded(global_scope_->VarDesc(var_index)->Name())) {
if (info.IsInArgBufferNeeded(var_scope_.VarDesc(var_index)->Name())) {
is_input_cnt++;
}
}
......@@ -267,7 +252,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
input_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
input_vars.emplace_back(global_scope_->Var(id));
input_vars.emplace_back(
local_scope_->FindVar(var_scope_.GetNameById(id)));
}
ins_map.emplace(var_name_item.first, std::move(input_vars));
}
......@@ -278,7 +264,7 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
out_vars.reserve(var_name_item.second.size());
for (auto& id : var_name_item.second) {
out_vars.emplace_back(global_scope_->Var(id));
out_vars.emplace_back(local_scope_->FindVar(var_scope_.GetNameById(id)));
}
outs_map.emplace(var_name_item.first, std::move(out_vars));
}
......@@ -286,9 +272,8 @@ void InterpreterCore::BuildAndCacheInstructionCtx(Instruction* instr_node) {
// set runtime_ctx and infershape_ctx_
if (instr_node->OpBase()->Type() == "cinn_launch") { // OP use scope in
// kernel
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();
Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
} else {
instr_node->ResetContext(ins_map, outs_map);
......@@ -311,25 +296,26 @@ void InterpreterCore::BuildInplace() {
for (auto& pair : in_to_outs) {
auto iter = inputs.find(pair.first);
if (iter != inputs.end() && !iter->second.empty()) {
auto in_var_desc = global_scope_->VarDesc(iter->second[0]);
auto in_var_desc = var_scope_.VarDesc(iter->second[0]);
if (in_var_desc && in_var_desc->Persistable()) {
continue;
}
if (global_scope_->GetVarSikpInplace(iter->second[0])) {
if (var_scope_.GetVarSikpInplace(iter->second[0])) {
continue;
}
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = outputs.find(pair.second);
if (iterout != outputs.end() && !iterout->second.empty()) {
auto invar = global_scope_->Var(iter->second[0]);
auto outvar = global_scope_->Var(iterout->second[0]);
auto invar =
local_scope_->FindVar(var_scope_.GetNameById(iter->second[0]));
auto outvar = local_scope_->FindVar(
var_scope_.GetNameById(iterout->second[0]));
if (invar && outvar && invar->IsType<LoDTensor>() &&
outvar->IsType<LoDTensor>()) {
instr.AddInplace(invar, outvar);
VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
<< " " << global_scope_->GetNameById(iter->second[0])
<< " -> "
<< global_scope_->GetNameById(iterout->second[0])
<< " " << var_scope_.GetNameById(iter->second[0])
<< " -> " << var_scope_.GetNameById(iterout->second[0])
<< std::endl;
}
}
......@@ -372,8 +358,8 @@ void InterpreterCore::ClearLoDTensorArrayInLocalScope() {
void InterpreterCore::Convert(
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
auto& vec_meta_info = global_scope_->MutableVecMetaInfo();
auto var_nums = global_scope_->VarSize();
auto& vec_meta_info = var_scope_.MutableVecMetaInfo();
auto var_nums = var_scope_.VarSize();
input_var2op_info_.resize(var_nums);
auto nodes = *op_func_nodes;
......@@ -403,7 +389,7 @@ void InterpreterCore::Convert(
if (!info.IsBuilt()) {
info.Build(instr.OpBase());
}
auto* var_desc = global_scope_->VarDesc(id);
auto* var_desc = var_scope_.VarDesc(id);
if (var_desc) {
if (info.IsInArgBufferNeeded(var_desc->Name())) {
gc_check_inputs.insert(id);
......@@ -415,14 +401,14 @@ void InterpreterCore::Convert(
}
for (auto var_id : gc_check_inputs) {
paddle::framework::Variable* var = global_scope_->Var(var_id);
paddle::framework::Variable* var =
local_scope_->FindVar(var_scope_.GetNameById(var_id));
if (var->IsType<LoDTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>()) {
last_live_ops_[var_id].insert(op_idx);
} else {
VLOG(4) << "not clear " << global_scope_->GetNameById(var_id)
<< " after " << instr.OpBase()->Type()
<< " because its type is "
VLOG(4) << "not clear " << var_scope_.GetNameById(var_id) << " after "
<< instr.OpBase()->Type() << " because its type is "
<< framework::ToTypeName(var->Type());
}
}
......@@ -441,7 +427,7 @@ void InterpreterCore::Convert(
// clear the last_live_ops list for all vars in skip_gc_vars
for (const std::string& skip_gc_var : skip_gc_vars_) {
int var_id = global_scope_->GetIdByName(skip_gc_var);
int var_id = var_scope_.GetIdByName(skip_gc_var);
if (var_id != -1) {
last_live_ops_[var_id].clear();
VLOG(8) << "Skip gc for var: " << skip_gc_var;
......@@ -470,7 +456,7 @@ void InterpreterCore::Convert(
}
if (not_before_any) {
VLOG(8) << "last live op of var " << i << " "
<< global_scope_->GetNameById(i) << " : " << item << " "
<< var_scope_.GetNameById(i) << " : " << item << " "
<< vec_instruction_[item].OpBase()->Type();
minumum_last_live_ops.insert(item);
vec_instruction_[item].AddGCCheckVar(i);
......@@ -513,7 +499,7 @@ void InterpreterCore::Convert(
std::promise<std::unique_ptr<AtomicVectorSizeT>>();
atomic_var_ref_ = var_ref_promise.get_future();
var_ref_promise.set_value(
interpreter::PrepareAtomicVarRef(global_scope_->VecMetaInfo()));
interpreter::PrepareAtomicVarRef(var_scope_.VecMetaInfo()));
}
void InterpreterCore::BuildSkipShareLoDInfo() {
......@@ -539,10 +525,9 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(global_scope_);
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{
// If it is OperatorBase, InferShape do nothing.
......@@ -607,7 +592,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
}
}
VLOG(4) << "End run " << place << " " << op->DebugStringEx(global_scope_);
VLOG(4) << "End run " << place << " " << op->DebugStringEx(local_scope_);
if (!instr_node.InplaceBackMap().empty()) {
platform::RecordEvent inplaceback_event(
......@@ -616,13 +601,13 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
// NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc
for (auto& p : m) {
auto* transformed_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
global_scope_->Var(p.first));
var_scope_.VarRef(p.first));
auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
global_scope_->Var(p.second));
var_scope_.VarRef(p.second));
original_tensor->ShareDataWith(*transformed_tensor);
VLOG(4) << "Transfer inplace variable back form "
<< global_scope_->GetNameById(p.first) << " to "
<< global_scope_->GetNameById(p.second);
<< var_scope_.GetNameById(p.first) << " to "
<< var_scope_.GetNameById(p.second);
}
}
......@@ -641,7 +626,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf(
*op,
*global_scope_,
*local_scope_,
place); // TODO(xiongkun03) change it to inner scope.
}
}
......@@ -663,7 +648,7 @@ void InterpreterCore::ExecuteInstructionList(
atomic_deps_ = async_work_queue_->PrepareAtomicDeps(dependecy_count_);
atomic_var_ref_ =
async_work_queue_->PrepareAtomicVarRef(global_scope_->VecMetaInfo());
async_work_queue_->PrepareAtomicVarRef(var_scope_.VecMetaInfo());
record_prepare.End();
exception_holder_.Clear();
......@@ -898,16 +883,16 @@ void InterpreterCore::RecordStreamForGC(const Instruction& instr) {
* supported later.
*/
for (int var_id : instr.GCCheckVars()) {
VLOG(4) << "GC sync " << global_scope_->GetNameById(var_id) << " "
<< global_scope_->VarDesc(var_id);
VLOG(4) << "GC sync " << var_scope_.GetNameById(var_id) << " "
<< var_scope_.VarDesc(var_id);
// persistable vars are ignored during GC
if (global_scope_->VarDesc(var_id) &&
global_scope_->VarDesc(var_id)->Persistable()) {
if (var_scope_.VarDesc(var_id) &&
var_scope_.VarDesc(var_id)->Persistable()) {
continue;
}
paddle::framework::Variable* var = global_scope_->Var(var_id);
paddle::framework::Variable* var = var_scope_.VarRef(var_id);
if (var == nullptr) {
continue;
}
......@@ -943,10 +928,10 @@ void InterpreterCore::CheckGC(
platform::RecordEvent record(
"CheckGC", platform::TracerEventType::UserDefined, 10);
size_t instr_id = instr.Id();
auto& var_scope = *global_scope_;
auto& var_scope = var_scope_;
for (auto var_id : instr.GCCheckVars()) {
VLOG(4) << "GC " << global_scope_->GetNameById(var_id) << " "
VLOG(4) << "GC " << var_scope_.GetNameById(var_id) << " "
<< var_scope.VarDesc(var_id);
VLOG(4) << "atomic:" << atomic_var_ref << " " << &(*atomic_var_ref)[var_id]
<< " " << var_id;
......@@ -962,17 +947,17 @@ void InterpreterCore::CheckGC(
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (IsInterpretercoreFastGCEnabled()) {
static_cast<InterpreterCoreFastGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id));
var_scope_.VarRef(var_id));
} else {
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id),
var_scope_.VarRef(var_id),
&gc_event_.at(instr_id),
&instr.DeviceContext());
}
#else
static_cast<InterpreterCoreEventGarbageCollector*>(gc_.get())->Add(
var_scope.Var(var_id),
var_scope_.VarRef(var_id),
&gc_event_.at(instr_id),
&instr.DeviceContext());
#endif
......@@ -995,7 +980,7 @@ void InterpreterCore::Prepare(
auto FeedInput = [&] {
VLOG(4) << "Feed inputs";
for (size_t i = 0; i < feed_names.size(); ++i) {
auto* feed_var = global_scope_->FindVar(feed_names[i]);
auto* feed_var = local_scope_->FindVar(feed_names[i]);
PADDLE_ENFORCE_NOT_NULL(
feed_var,
platform::errors::NotFound("Variable %s should not be nullptr.",
......@@ -1009,14 +994,14 @@ void InterpreterCore::Prepare(
if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(
block_, global_scope_, create_local_scope_);
block_, &var_scope_, create_local_scope_);
FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(place_,
block_,
skip_gc_vars_,
&op_func_nodes,
global_scope_,
&var_scope_,
create_local_scope_);
is_build_ = true;
SetFeedVarsInplaceSkip(feed_names);
......@@ -1034,14 +1019,14 @@ void InterpreterCore::Prepare(
void InterpreterCore::SetFeedVarsInplaceSkip(
const std::vector<std::string>& feed_names) {
for (auto& feed_name : feed_names) {
global_scope_->SetVarSikpInplace(feed_name, true);
var_scope_.SetVarSikpInplace(feed_name, true);
}
}
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place,
const ProgramDesc& prog,
VariableScope* global_scope,
Scope* scope,
const std::vector<std::string>& fetch_names,
const std::set<std::string>& skip_gc_vars) {
std::shared_ptr<InterpreterCore> core = nullptr;
......@@ -1051,8 +1036,7 @@ std::shared_ptr<InterpreterCore> CreateInterpreterCore(
auto* block = new_prog->MutableBlock(0);
interpreter::add_fetch(fetch_names, block);
core = std::make_shared<InterpreterCore>(
place, *block, skip_gc_vars, global_scope);
core = std::make_shared<InterpreterCore>(place, *block, skip_gc_vars, scope);
core->SetCopyProgram(new_prog);
return core;
}
......
......@@ -40,7 +40,7 @@ class InterpreterCore {
InterpreterCore(const platform::Place& place,
const BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
VariableScope* global_scope);
Scope* scope);
~InterpreterCore();
......@@ -112,7 +112,10 @@ class InterpreterCore {
// new program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr};
VariableScope* global_scope_; // not owned
// from variable scope
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
......@@ -125,6 +128,10 @@ class InterpreterCore {
std::atomic<size_t> unfinished_op_numer_{0};
std::vector<std::vector<size_t>> input_var2op_info_;
VariableScope var_scope_;
bool create_local_scope_{true};
Scope* local_scope_{nullptr}; // not owned
StreamAnalyzer stream_analyzer_;
EventsWaiter main_thread_blocker_;
std::shared_ptr<interpreter::AsyncWorkQueue> async_work_queue_;
......@@ -134,8 +141,6 @@ class InterpreterCore {
std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
bool create_local_scope_{true};
Scope* local_scope_{nullptr}; // not owned
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_deps_;
std::future<std::unique_ptr<AtomicVectorSizeT>> atomic_var_ref_;
......@@ -144,7 +149,7 @@ class InterpreterCore {
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
const platform::Place& place,
const ProgramDesc& prog,
VariableScope* global_scope,
Scope* global_scope,
const std::vector<std::string>& fetch_names = {},
const std::set<std::string>& skip_gc_vars = {});
......
......@@ -50,6 +50,7 @@ namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>;
constexpr size_t kPrepareWorkQueueIdx = 2;
const char blocking_queue_prefix[] = "lod_tensor_blocking_queue";
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
......@@ -225,6 +226,7 @@ void build_variable_scope(const framework::BlockDesc& block,
if (var_name == framework::kEmptyVarName) {
continue;
}
if (var_desc->Persistable()) {
auto* ptr = inner_scope->Var(var_name);
......@@ -241,7 +243,7 @@ void build_variable_scope(const framework::BlockDesc& block,
<< ptr << "Variable Type "
<< static_cast<int>(var_desc->GetType());
}
var_scope->SetVarDesc(var_name, var_desc);
var_scope->AddVar(var_name, var_desc);
}
}
......@@ -279,6 +281,7 @@ void create_all_ops(const framework::BlockDesc& block,
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
const VariableNameMap& var_name_map,
VariableScope* var_scope,
Scope* local_scope,
bool enforce_exist = true) {
VariableValueMap name2var;
VariableIdMap name2id;
......@@ -288,14 +291,22 @@ std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
vars.reserve(item.second.size());
for (auto& var_name : item.second) {
if (!enforce_exist && !var_scope->HasVar(var_name)) {
// skip the non-exist variable: such as recurrent_grad
VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
continue;
if (!var_scope->HasVar(var_name)) {
// Hot fix for variables used in dataloader, like
// 'lod_tensor_blocking_queue_0' These variables may be created in
// scope, and it is not existed as variable in program.
if (var_name.find(blocking_queue_prefix) != std::string::npos &&
local_scope->FindVar(var_name)) {
var_scope->AddVar(var_name, nullptr);
} else if (!enforce_exist) {
// skip the non-exist variable: such as recurrent_grad
VLOG(4) << var_name << " doesn't exist in variable scope, skip it!";
continue;
}
}
auto* var = local_scope->FindVar(var_name);
auto var_id = var_scope->VarId(var_name);
auto* in_var = var_scope->Var(var_id);
vars.push_back(in_var);
vars.push_back(var);
ids.push_back(var_id);
}
name2var[item.first] = std::move(vars);
......@@ -421,12 +432,12 @@ void build_op_func_list(const platform::Place& place,
enforce_exist = false;
}
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope, enforce_exist);
build_variable_map(inputs_names, var_scope, local_scope, enforce_exist);
VariableValueMap outs_map;
VariableIdMap outs_name2id;
std::tie(outs_map, outs_name2id) =
build_variable_map(outputs_names, var_scope, enforce_exist);
std::tie(outs_map, outs_name2id) = build_variable_map(
outputs_names, var_scope, local_scope, enforce_exist);
// step 1: build OpFuncNode
OpFuncNode op_func_node;
......@@ -573,9 +584,9 @@ void build_op_func_list(const platform::Place& place,
for (auto& p : m) {
auto* transformed_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(
var_scope->Var(p.first));
local_scope->FindVar(var_scope->GetNameById(p.first)));
auto* original_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar(
var_scope->Var(p.second));
local_scope->FindVar(var_scope->GetNameById(p.second)));
original_tensor->ShareDataWith(*transformed_tensor);
VLOG(4) << "Transfer inplace variable back form "
<< var_scope->GetNameById(p.first) << " to "
......@@ -600,7 +611,7 @@ void build_op_func_list(const platform::Place& place,
new std::deque<std::shared_ptr<memory::Allocation>>();
for (auto& var_name : delete_vars) {
auto* var = var_scope->FindVar(var_name);
auto* var = local_scope->FindVar(var_name);
if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) {
continue;
}
......
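The blocking-queue hot fix earlier in this file handles variables such as lod_tensor_blocking_queue_0 that the dataloader creates directly in the Scope without declaring them in the program. Below is a hedged sketch of that lookup order in build_variable_map, written against the stand-in types from the first sketch (same translation unit); ResolveVar is an illustrative helper name, not a real Paddle function.

// Returns false when the variable should be skipped entirely.
bool ResolveVar(const std::string& var_name,
                VariableScope* var_scope,
                Scope* local_scope,
                bool enforce_exist) {
  static const char kQueuePrefix[] = "lod_tensor_blocking_queue";
  if (!var_scope->HasVar(var_name)) {
    if (var_name.find(kQueuePrefix) != std::string::npos &&
        local_scope->FindVar(var_name) != nullptr) {
      // Present in the Scope but absent from the program: index it on the
      // fly so the VarId lookup that follows succeeds.
      var_scope->AddVar(var_name);
    } else if (!enforce_exist) {
      return false;  // skip non-existent vars, e.g. recurrent_grad
    }
  }
  return true;
}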
......@@ -86,7 +86,7 @@ void build_op_func_list(const platform::Place& place,
const framework::BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope,
VariableScope* scope,
bool use_local_scope = true);
std::map<int, std::list<int>> build_op_downstream_map(
......
......@@ -560,8 +560,8 @@ const std::vector<Variable*>& InterpretercoreInferShapeContext::OutputVars(
VariableScope::VariableScope(Scope* scope) {
// for @EMPTY@ variable
var_list_.push_back(nullptr);
name2id_[kEmptyVarName] = kEmptyVarIndex;
var_list_.push_back(nullptr);
vec_meta_info_.emplace_back(0, nullptr);
scope_ = scope;
PADDLE_ENFORCE_NE(
......@@ -569,15 +569,9 @@ VariableScope::VariableScope(Scope* scope) {
nullptr,
platform::errors::PreconditionNotMet(
"You have passed a nullptr to construct VariableScope."));
listener_ = std::make_shared<VariableScopeListener>(this);
scope->AddListener(listener_);
}
VariableScope::~VariableScope() {
if (scope_ && listener_) {
scope_->DelListener(listener_);
}
}
VariableScope::~VariableScope() {}
Scope* VariableScope::GetMutableScope() const { return scope_; }
......@@ -588,22 +582,6 @@ void VariableScope::SetLocalScope(Scope* local_scope) {
local_scope_ = local_scope;
}
Variable* VariableScope::FindVar(const std::string& name) const {
auto it = name2id_.find(name);
if (it != name2id_.end()) {
PADDLE_ENFORCE_LT(it->second,
var_list_.size(),
platform::errors::NotFound(
"The id(%d) of variable(%s) should not be larger "
"than the size of variable list(%d).",
it->second,
name,
var_list_.size()));
return var_list_[it->second];
}
return nullptr;
}
// Get variable id by name, return -1 if not found
int VariableScope::GetIdByName(const std::string& name) const {
auto it = name2id_.find(name);
......@@ -638,34 +616,23 @@ int VariableScope::VarId(const std::string& name) const {
return name2id_.at(name);
}
Variable* VariableScope::Var(int id) const { return var_list_.at(id); }
Variable* VariableScope::Var(const std::string& name) const {
return var_list_.at(VarId(name));
}
Variable* VariableScope::VarRef(int id) const { return var_list_[id]; }
size_t VariableScope::VarSize() const { return var_list_.size(); }
size_t VariableScope::VarSize() const { return name2id_.size(); }
void VariableScope::AddVar(const std::string& name,
framework::VarDesc* var_desc,
bool local_scope) { // NOLINT
auto v = local_scope ? local_scope_->Var(name) : scope_->Var(name);
if (nullptr == var_desc) {
v->GetMutable<LoDTensor>();
} else {
InitializeVariable(
v,
var_desc
->GetType()); // Scope doesn't initialize recently created variables
framework::VarDesc* var_desc) {
if (!HasVar(name)) {
auto id = VarSize();
name2id_[name] = id;
vec_meta_info_.emplace_back(0, var_desc);
var_list_.push_back(local_scope_->FindVar(name));
PADDLE_ENFORCE_EQ(
var_list_.size(),
name2id_.size(),
platform::errors::InvalidArgument(
"The size of var_list and name2id map should be equal"));
}
SetVarDesc(name, var_desc);
}
void VariableScope::AddVar(const std::string& name,
const Variable& var) { // NOLINT
// Though name existed in outer_scope_, we need
// add again to create name2id map.
scope_->Var(name);
}
void VariableScope::SetVarDesc(const std::string& name,
......@@ -696,10 +663,10 @@ bool VariableScope::GetVarSikpInplace(int id) const {
void VariableScope::CheckExist(int id) const {
PADDLE_ENFORCE_LT(id,
var_list_.size(),
name2id_.size(),
platform::errors::PreconditionNotMet(
"Required var_id < %d, but received var_id = %d.",
var_list_.size(),
name2id_.size(),
id));
}
......@@ -710,55 +677,6 @@ void VariableScope::CheckExist(const std::string& name) const {
platform::errors::NotFound("%s not in VariableScope.", name));
}
void VariableScope::ClearListener() {
if (scope_ && listener_ && scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << scope_;
scope_->DelListener(listener_);
}
if (local_scope_ && listener_ && local_scope_->HasListener(listener_)) {
VLOG(4) << "Clear listener " << listener_ << " for " << local_scope_;
local_scope_->DelListener(listener_);
}
}
void VariableScope::ResetListener() {
if (scope_ && listener_ && !scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << scope_;
scope_->AddListener(listener_);
}
if (local_scope_ && listener_ && !local_scope_->HasListener(listener_)) {
VLOG(4) << "Add listener " << listener_ << " for " << local_scope_;
local_scope_->AddListener(listener_);
}
}
VariableScopeListener::VariableScopeListener(VariableScope* var_scope) {
var_scope_ = var_scope;
}
void VariableScopeListener::onCreateVariable(const std::string& name,
Variable* v) {
if (!var_scope_->HasVar(name)) { // may exist in variable scope.
VLOG(4) << "Calling VariableScope::onCreateVariable with var_name: "
<< name;
var_scope_->name2id_[name] = var_scope_->VarSize();
var_scope_->var_list_.emplace_back(v);
var_scope_->vec_meta_info_.emplace_back(0, nullptr);
}
}
void VariableScopeListener::onDeleteVariable(const std::string& name) {
if (var_scope_->HasVar(name)) {
VLOG(4) << "Calling VariableScope::onDeleteVariable with var_name: "
<< name;
}
}
void VariableScopeListener::onRenameVariable(const std::string& old_name,
const std::string& new_name) {}
void VariableScopeListener::onCreateScope(Scope* Scope) {}
void VariableScopeListener::onDeleteScope(Scope* Scope) {}
void VariableScopeListener::onClear() {}
Instruction::Instruction(size_t id,
OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx)
......
......@@ -168,29 +168,7 @@ struct VariableMetaInfo {
: var_ref_count_(var_ref_count), var_desc_(var_desc) {}
};
class VariableScope;
class VariableScopeListener : public ScopeListener {
public:
explicit VariableScopeListener(VariableScope* var_scope_);
void onCreateVariable(const std::string& name, Variable* v) override;
void onDeleteVariable(const std::string& name) override;
void onRenameVariable(const std::string& old_name,
const std::string& new_name) override;
void onCreateScope(Scope* Scope) override;
void onDeleteScope(Scope* Scope) override;
void onClear() override;
private:
VariableScope* var_scope_; // not owned
};
// TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?
// NOTE(xiongkun03): Use scope as a member of VariableScope, we don't need
// ScopeBase. Scope manager the variables and VariableScope is just a quick
// access machanism. ScopeListener is the callback to sync changes in Original
// Scope. We can make it a membership of VariableScope. Here we use inherent.
class VariableScope : public ScopeBase {
class VariableScope {
public:
explicit VariableScope(Scope* scope);
......@@ -200,8 +178,6 @@ class VariableScope : public ScopeBase {
void SetLocalScope(Scope* local_scope);
Variable* FindVar(const std::string& name) const;
~VariableScope();
// Get variable id by name, return -1 if not found
......@@ -214,17 +190,11 @@ class VariableScope : public ScopeBase {
int VarId(const std::string& name) const;
Variable* Var(int id) const;
Variable* Var(const std::string& name) const;
size_t VarSize() const;
void AddVar(const std::string& name,
VarDesc* var_desc,
bool local_scope = false);
void AddVar(const std::string& name, VarDesc* var_desc);
void AddVar(const std::string& name, const Variable& var);
Variable* VarRef(int id) const;
void SetVarDesc(const std::string& name, framework::VarDesc* var_desc);
......@@ -242,29 +212,22 @@ class VariableScope : public ScopeBase {
return vec_meta_info_;
}
const std::shared_ptr<VariableScopeListener>& Listener() const {
return listener_;
}
void SetVarSikpInplace(const std::string& name, bool skip);
bool GetVarSikpInplace(int id) const;
void ClearListener();
void ResetListener();
friend class VariableScopeListener;
private:
// not owned, better remove it since all vars should be
// accessed by Scope instead of VariableScope
std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_;
Scope* scope_{nullptr};
// TODO(zhiqiu): find a better way to support local scope.
Scope* local_scope_{nullptr};
// mutable RWLock vars_lock_;
std::shared_ptr<VariableScopeListener> listener_{nullptr};
};
class NextInstruction {
......
......@@ -25,41 +25,12 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
: place_(place),
startup_prog_(startup_prog),
main_prog_(main_prog),
global_scope_(VariableScope(scope)) {
// NOTE(zhiqiu): it is needed to sync the variables in scope to
// variable_scope, since the some variable only exists in scope.
// For example, 'lod_tensor_blocking_queue_0' used in dataloader.
// These variables may be created in scope, and it is not existed as
// variable in program.
if (scope) {
const std::string blocking_queue_prefix = "lod_tensor_blocking_queue";
auto vars = scope->LocalVarNames();
for (const auto& name : vars) {
if (name.find(blocking_queue_prefix) != std::string::npos) {
if (!global_scope_.HasVar(name)) {
auto* v = scope->Var(name);
VLOG(4) << "Sync Variable from scope to variable scope: " << name;
global_scope_.AddVar(name, *v);
}
}
}
}
// NOTE(zhiqiu): for startup_program, initialize scope and run once
// if startup_program is empty, the scope is initialize during first run
scope_(scope) {
// NOTE(zhiqiu): for startup_program, run once ?
if (startup_prog.Block(0).AllOps().size() > 0) {
VLOG(4) << "Run startup program";
// init scope
BuildVariableScope(startup_prog, &global_scope_);
std::vector<paddle::framework::OpFuncNode> vec_func_list;
// No need to use_local_scope for startup_program, its variables are
// persistable
paddle::framework::interpreter::build_op_func_list(place_,
startup_prog.Block(0),
{},
&vec_func_list,
&global_scope_,
false);
auto core = GetInterpreterCore(startup_prog, {}, {}, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
core->Run({});
}
}
......@@ -70,7 +41,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
platform::RecordEvent record_event(
"StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
auto core = GetInterpreterCore(feed_names, fetch_names, true);
auto core = GetInterpreterCore(main_prog_, feed_names, fetch_names, true);
return core->Run(feed_names, feed_tensors);
}
......@@ -81,7 +52,7 @@ paddle::framework::FetchList StandaloneExecutor::Run(
platform::RecordEvent record_event(
"StandaloneExecutor::run", platform::TracerEventType::UserDefined, 1);
auto core = GetInterpreterCore(feed_names, fetch_names, false);
auto core = GetInterpreterCore(main_prog_, feed_names, fetch_names, false);
VLOG(4) << "StandaloneExecutor: " << this << ", InterpreterCore: " << core;
return core->Run(feed_names);
}
......@@ -89,28 +60,13 @@ paddle::framework::FetchList StandaloneExecutor::Run(
framework::interpreter::CostInfo StandaloneExecutor::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
auto core = GetInterpreterCore(feed_names, {}, true);
auto core = GetInterpreterCore(main_prog_, feed_names, {}, true);
return core->DryRun(feed_names, feed_tensors);
}
void StandaloneExecutor::BuildVariableScope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (!var_scope->HasVar(var->Name())) {
VLOG(4) << "Create variable from startup_prog: "
<< var->Proto()->SerializeAsString();
var_scope->AddVar(var->Name(), var);
}
}
}
std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
const ProgramDesc& prog,
const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names,
bool add_fetch_op) {
......@@ -133,14 +89,13 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
std::shared_ptr<InterpreterCore> core = nullptr;
if (add_fetch_op) {
core = CreateInterpreterCore(
place_, main_prog_, &global_scope_, fetch_names);
core = CreateInterpreterCore(place_, prog, scope_, fetch_names);
} else {
core = std::make_shared<InterpreterCore>(
place_,
main_prog_.Block(0),
prog.Block(0),
/*skip_gc_vars=*/std::set<std::string>(),
&global_scope_);
scope_);
}
interpretercores_.emplace(oss.str(), core);
return core;
......
......@@ -54,10 +54,8 @@ class StandaloneExecutor {
const std::vector<framework::LoDTensor>& feed_tensors);
private:
void BuildVariableScope(const framework::ProgramDesc& pdesc,
VariableScope* var_scope);
std::shared_ptr<InterpreterCore> GetInterpreterCore(
const ProgramDesc& prog,
const std::vector<std::string>& feed_names,
const std::vector<std::string>& fetch_names,
bool add_fetch_op);
......@@ -65,7 +63,7 @@ class StandaloneExecutor {
platform::Place place_;
const ProgramDesc& startup_prog_;
const ProgramDesc& main_prog_;
VariableScope global_scope_;
Scope* scope_; // not owned
std::unordered_map<std::string, std::shared_ptr<InterpreterCore>>
interpretercores_;
......
......@@ -137,29 +137,29 @@ ProgramDesc GetLmMainProgram() {
return main_prog;
}
TEST(StandaloneExecutor, run) {
auto place = platform::CUDAPlace(0);
ProgramDesc test_prog = load_from_file("lm_startup_program");
ProgramDesc main_prog = GetLmMainProgram();
// TEST(StandaloneExecutor, run) {
// auto place = platform::CUDAPlace(0);
// ProgramDesc test_prog = load_from_file("lm_startup_program");
// ProgramDesc main_prog = GetLmMainProgram();
Scope scope;
StandaloneExecutor exec(place, test_prog, main_prog, &scope);
exec.Run({}, {}, {});
auto start = std::chrono::steady_clock::now();
// Scope scope;
// StandaloneExecutor exec(place, test_prog, main_prog, &scope);
// exec.Run({}, {}, {});
// auto start = std::chrono::steady_clock::now();
for (size_t i = 0; i < 10; ++i) {
if (i % 200 == 0) {
std::cout << i << std::endl;
}
// for (size_t i = 0; i < 10; ++i) {
// if (i % 200 == 0) {
// std::cout << i << std::endl;
// }
exec.Run({}, {}, {});
}
// exec.Run({}, {}, {});
// }
auto end = std::chrono::steady_clock::now();
std::chrono::duration<double> diff = end - start;
// auto end = std::chrono::steady_clock::now();
// std::chrono::duration<double> diff = end - start;
std::cout << "time cost " << diff.count() << std::endl;
}
// std::cout << "time cost " << diff.count() << std::endl;
// }
TEST(InterpreterCore, skip_gc_vars) {
auto place = platform::CUDAPlace(0);
......@@ -168,9 +168,8 @@ TEST(InterpreterCore, skip_gc_vars) {
Scope scope;
VariableScope startup_scope(&scope);
std::shared_ptr<InterpreterCore> startup_core =
CreateInterpreterCore(place, startup_prog, &startup_scope);
CreateInterpreterCore(place, startup_prog, &scope);
startup_core->Run({}, {});
std::set<std::string> skip_gc_vars = {"uniform_0.tmp_0",
......@@ -183,26 +182,31 @@ TEST(InterpreterCore, skip_gc_vars) {
"split_0.tmp_0",
"elementwise_add_0.tmp_0",
"tmp_0"};
std::shared_ptr<InterpreterCore> main_core =
CreateInterpreterCore(place, main_prog, &scope, {}, skip_gc_vars);
auto check_gc_result =
[](VariableScope& scope, std::set<std::string>& vars, bool is_skip_gc) {
[](Scope& scope, std::set<std::string>& vars, bool is_skip_gc) {
// the first local scope is created in startup_core
// the second local scope is created in main_core
ASSERT_EQ(scope.kids().size(), 2UL);
auto* local_scope = scope.kids().back();
for (const std::string& var_name : vars) {
ASSERT_EQ(
scope.FindVar(var_name)->GetMutable<LoDTensor>()->IsInitialized(),
is_skip_gc);
ASSERT_EQ(local_scope->FindVar(var_name)
->GetMutable<LoDTensor>()
->IsInitialized(),
is_skip_gc);
}
};
VariableScope main_scope(&scope);
std::shared_ptr<InterpreterCore> main_core =
CreateInterpreterCore(place, main_prog, &main_scope, {}, skip_gc_vars);
main_core->Run({}, {});
check_gc_result(main_scope, skip_gc_vars, true);
check_gc_result(main_scope, gc_vars, false);
check_gc_result(scope, skip_gc_vars, true);
check_gc_result(scope, gc_vars, false);
main_core->Run({}, {});
check_gc_result(main_scope, skip_gc_vars, true);
check_gc_result(main_scope, gc_vars, false);
check_gc_result(scope, skip_gc_vars, true);
check_gc_result(scope, gc_vars, false);
}
void TestShareWorkQueue(const ProgramDesc& prog,
......@@ -213,11 +217,10 @@ void TestShareWorkQueue(const ProgramDesc& prog,
const platform::CPUPlace place = platform::CPUPlace();
Scope scope;
VariableScope variable_scope(&scope);
std::shared_ptr<InterpreterCore> core1 =
CreateInterpreterCore(place, prog, &variable_scope, fetch_names);
CreateInterpreterCore(place, prog, &scope, fetch_names);
std::shared_ptr<InterpreterCore> core2 =
CreateInterpreterCore(place, prog, &variable_scope, fetch_names);
CreateInterpreterCore(place, prog, &scope, fetch_names);
core2->ShareWorkQueueFrom(core1);
auto run_and_check = [&feed_names, &feed_tensors, &fetch_results](
......
......@@ -67,9 +67,6 @@ Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK
ret = VarInternal(name);
}
for (auto l : listeners_) {
l->onCreateVariable(name, ret);
}
return ret;
}
......@@ -85,9 +82,6 @@ Variable* Scope::Var(std::string* name) {
}
ret = VarInternal(new_name);
}
for (auto l : listeners_) {
l->onCreateVariable(new_name, ret);
}
return ret;
}
......@@ -124,9 +118,6 @@ void Scope::DropKids() {
for (Scope* s : kids_) delete s;
kids_.clear();
}
for (auto l : listeners_) {
l->onClear();
}
}
bool Scope::HasKid(const Scope* scope) const {
......@@ -175,9 +166,6 @@ void Scope::DeleteScope(Scope* scope) const {
Async([scope] { delete scope; });
}
}
for (auto l : listeners_) {
l->onDeleteScope(scope);
}
}
void Scope::EraseVars(const std::vector<std::string>& var_names) {
......@@ -192,11 +180,6 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
}
}
}
for (auto l : listeners_) {
for (auto& var_name : var_names) {
l->onDeleteVariable(var_name);
}
}
}
void Scope::Rename(const std::string& origin_name,
......@@ -205,9 +188,6 @@ void Scope::Rename(const std::string& origin_name,
SCOPE_VARS_WRITER_LOCK
RenameInternal(origin_name, new_name);
}
for (auto l : listeners_) {
l->onRenameVariable(origin_name, new_name);
}
}
std::string Scope::Rename(const std::string& origin_name) const {
......@@ -216,9 +196,6 @@ std::string Scope::Rename(const std::string& origin_name) const {
SCOPE_VARS_WRITER_LOCK
RenameInternal(origin_name, new_name);
}
for (auto l : listeners_) {
l->onRenameVariable(origin_name, new_name);
}
return new_name;
}
......@@ -282,22 +259,6 @@ Variable* Scope::FindVarLocally(const std::string& name) const {
return nullptr;
}
void Scope::AddListener(const std::shared_ptr<ScopeListener>& listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener);
if (it == listeners_.end()) {
listeners_.push_back(listener);
}
}
void Scope::DelListener(const std::shared_ptr<ScopeListener>& listener) {
listeners_.remove(listener);
}
bool Scope::HasListener(const std::shared_ptr<ScopeListener>& listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener);
return it != listeners_.end();
}
void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) {
SCOPE_VARS_WRITER_LOCK
for (auto iter = vars_.begin(); iter != vars_.end();) {
......
......@@ -51,22 +51,6 @@ class ScopeBase {
class Scope;
class ScopeListener {
// NOTE(xiongkun03) Abstract Class, doesn't have any attributes.
// Used by VariableScope. If we modify the original scope, we
// need synchronize changes to VariableScope. So we add listerer
// in original Scope.
public:
virtual ~ScopeListener() {}
virtual void onCreateVariable(const std::string& name, Variable* v) {}
virtual void onDeleteVariable(const std::string& name) {}
virtual void onRenameVariable(const std::string& old_name,
const std::string& new_name) {}
virtual void onCreateScope(Scope* Scope) {}
virtual void onDeleteScope(Scope* Scope) {}
virtual void onClear() {}
};
/**
* @brief Scope that manage all variables.
*
......@@ -150,12 +134,6 @@ class Scope : public ScopeBase {
// Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const;
void AddListener(const std::shared_ptr<ScopeListener>& listener);
void DelListener(const std::shared_ptr<ScopeListener>& listener);
bool HasListener(const std::shared_ptr<ScopeListener>& listener);
protected:
struct KeyHasher {
std::size_t operator()(const std::string& key) const {
......@@ -192,7 +170,6 @@ class Scope : public ScopeBase {
// Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_;
const Scope* parent_{nullptr};
std::list<std::shared_ptr<ScopeListener>> listeners_;
DISABLE_COPY_AND_ASSIGN(Scope);
......