未验证 提交 32c3e61b 编写于 作者: X xiongkun 提交者: GitHub

Add Sync Machanism for Scope and VaraibleScope. Fix test_fetch_var (#37085)

上级 d5df6bdf
...@@ -474,13 +474,12 @@ struct VariableMetaInfo { ...@@ -474,13 +474,12 @@ struct VariableMetaInfo {
// TODO(zhiqiu): Maybe we need to add rwlock for VariableScope? // TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?
// NOTE(xiongkun03): Use scope as a member of VariableScope, we don't need // NOTE(xiongkun03): Use scope as a member of VariableScope, we don't need
// ScopeBase. // ScopeBase. Scope manager the variables and VariableScope is just a quick
// Scope manager the variables and VariableScope is just a // access machanism. ScopeListener is the callback to sync changes in Original
// quick // Scope. We can make it a membership of VariableScope. Here we use inherent.
// access machanism. class VariableScope : public ScopeBase, public ScopeListener {
class VariableScope : public ScopeBase {
public: public:
VariableScope() { VariableScope(Scope* outer_scope) {
// for @EMPTY@ variable // for @EMPTY@ variable
var_list_.push_back(nullptr); var_list_.push_back(nullptr);
name2id_[kEmptyVarName] = 0; name2id_[kEmptyVarName] = 0;
...@@ -488,9 +487,20 @@ class VariableScope : public ScopeBase { ...@@ -488,9 +487,20 @@ class VariableScope : public ScopeBase {
info.var_ref_count_ = 0; info.var_ref_count_ = 0;
info.vardesc_ = nullptr; info.vardesc_ = nullptr;
vec_meta_info_.push_back(info); vec_meta_info_.push_back(info);
scope_ptr_.reset(new Scope()); outer_scope_ = outer_scope;
PADDLE_ENFORCE_NE(
outer_scope_, nullptr,
platform::errors::PreconditionNotMet(
"You have passed a nullptr to construct VariableScope."));
outer_scope->AddListener(this);
} }
const Scope* GetScope() const { return scope_ptr_.get(); }
~VariableScope() {
if (outer_scope_ != nullptr) outer_scope_->DelListener(this);
}
const Scope* GetScope() const { return outer_scope_; }
Variable* FindVar(const std::string& name) const { Variable* FindVar(const std::string& name) const {
auto it = name2id_.find(name); auto it = name2id_.find(name);
...@@ -548,8 +558,9 @@ class VariableScope : public ScopeBase { ...@@ -548,8 +558,9 @@ class VariableScope : public ScopeBase {
size_t VarSize() const { return var_list_.size(); } size_t VarSize() const { return var_list_.size(); }
void AddVar(const std::string& name, VarDesc* var_desc) { // NOLINT void AddVar(const std::string& name, VarDesc* var_desc) { // NOLINT
name2id_[name] = VarSize(); // AddVar -> Scope::Var -> onCreateVariable.
auto v = scope_ptr_->Var(name); VLOG(4) << "Add variable: " << name << " through AddVar()";
auto v = outer_scope_->Var(name);
if (nullptr == var_desc) { if (nullptr == var_desc) {
v->GetMutable<LoDTensor>(); v->GetMutable<LoDTensor>();
} else { } else {
...@@ -558,26 +569,13 @@ class VariableScope : public ScopeBase { ...@@ -558,26 +569,13 @@ class VariableScope : public ScopeBase {
var_desc var_desc
->GetType()); // Scope don't initialize variable recently created ->GetType()); // Scope don't initialize variable recently created
} }
var_list_.push_back(v); SetVarDesc(name, var_desc);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = var_desc;
vec_meta_info_.push_back(info);
} }
void AddVar(const std::string& name, Variable& var) { // NOLINT void AddVar(const std::string& name, Variable& var) { // NOLINT
// must copy. // Though name existed in outer_scope_, we need
VLOG(4) << "Add variable: " << name << " through AddVar()"; // add again to create name2id map.
auto v = scope_ptr_->Var(name); outer_scope_->Var(name);
*v = var;
name2id_[name] = VarSize();
var_list_.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr;
vec_meta_info_.push_back(info);
} }
void SetVarDesc(const std::string& name, framework::VarDesc* var_desc) { void SetVarDesc(const std::string& name, framework::VarDesc* var_desc) {
...@@ -607,6 +605,32 @@ class VariableScope : public ScopeBase { ...@@ -607,6 +605,32 @@ class VariableScope : public ScopeBase {
platform::errors::NotFound("%s not in VariableScope.", name)); platform::errors::NotFound("%s not in VariableScope.", name));
} }
public: // callbacks from ScopeListener class
void onCreateVariable(const std::string& name) override {
auto v = outer_scope_->GetVar(name); // must exsit in outer_scope_
if (!HasVar(name)) { // may exist in variable scope.
VLOG(4) << "Calling VariableScope::onCreateVariable with var_name: "
<< name;
name2id_[name] = VarSize();
var_list_.push_back(v);
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr; // set nullptr, then modifty it in AddVar()
vec_meta_info_.push_back(info);
}
}
void onDeleteVariable(const std::string& name) override {
if (HasVar(name)) {
VLOG(4) << "Calling VariableScope::onDeleteVariable with var_name: "
<< name;
}
}
void onRenameVariable(const std::string& old_name,
const std::string& new_name) override {}
void onCreateScope(Scope* Scope) override {}
void onDeleteScope(Scope* Scope) override {}
void onClear() override {}
std::vector<VariableMetaInfo>& MutableVecMetaInfo() { return vec_meta_info_; } std::vector<VariableMetaInfo>& MutableVecMetaInfo() { return vec_meta_info_; }
const std::vector<VariableMetaInfo>& VecMetaInfo() const { const std::vector<VariableMetaInfo>& VecMetaInfo() const {
...@@ -617,7 +641,7 @@ class VariableScope : public ScopeBase { ...@@ -617,7 +641,7 @@ class VariableScope : public ScopeBase {
std::vector<Variable*> var_list_; std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_; std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_; std::vector<VariableMetaInfo> vec_meta_info_;
std::unique_ptr<Scope> scope_ptr_; Scope* outer_scope_ = nullptr;
}; };
class NextInstruction { class NextInstruction {
......
...@@ -23,9 +23,9 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, ...@@ -23,9 +23,9 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
: place_(place), : place_(place),
startup_prog_(startup_prog), startup_prog_(startup_prog),
main_prog_(main_prog), main_prog_(main_prog),
outer_scope_(scope) { outer_scope_(scope),
global_scope_(scope) {
paddle::framework::InitDevices(); paddle::framework::InitDevices();
// init scope // init scope
BuildVariableOuterScope(startup_prog, &global_scope_, scope); BuildVariableOuterScope(startup_prog, &global_scope_, scope);
......
...@@ -59,18 +59,35 @@ std::unique_ptr<Scope> Scope::NewTmpScope() const { ...@@ -59,18 +59,35 @@ std::unique_ptr<Scope> Scope::NewTmpScope() const {
} }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK // NOTE(xiongkun03): add {} here to unlock. With {}, scope
return VarInternal(name); // will do callback after unlock.
Variable* ret = nullptr;
{
SCOPE_VARS_WRITER_LOCK
ret = VarInternal(name);
}
for (auto l : listeners_) {
l->onCreateVariable(name);
}
return ret;
} }
Variable* Scope::Var(std::string* name) { Variable* Scope::Var(std::string* name) {
SCOPE_VARS_WRITER_LOCK Variable* ret = nullptr;
auto new_name = std::to_string(reinterpret_cast<uintptr_t>(this)) + "." + std::string new_name;
std::to_string(vars_.size()); {
if (name != nullptr) { SCOPE_VARS_WRITER_LOCK
*name = new_name; new_name = std::to_string(reinterpret_cast<uintptr_t>(this)) + "." +
std::to_string(vars_.size());
if (name != nullptr) {
*name = new_name;
}
ret = VarInternal(new_name);
}
for (auto l : listeners_) {
l->onCreateVariable(new_name);
} }
return VarInternal(new_name); return ret;
} }
Variable* Scope::FindVar(const std::string& name) const { Variable* Scope::FindVar(const std::string& name) const {
...@@ -101,9 +118,14 @@ const Scope* Scope::FindScope(const std::string& name) const { ...@@ -101,9 +118,14 @@ const Scope* Scope::FindScope(const std::string& name) const {
} }
void Scope::DropKids() { void Scope::DropKids() {
SCOPE_KIDS_WRITER_LOCK {
for (Scope* s : kids_) delete s; SCOPE_KIDS_WRITER_LOCK
kids_.clear(); for (Scope* s : kids_) delete s;
kids_.clear();
}
for (auto l : listeners_) {
l->onClear();
}
} }
bool Scope::HasKid(const Scope* scope) const { bool Scope::HasKid(const Scope* scope) const {
...@@ -125,42 +147,64 @@ std::vector<std::string> Scope::LocalVarNames() const { ...@@ -125,42 +147,64 @@ std::vector<std::string> Scope::LocalVarNames() const {
} }
void Scope::DeleteScope(Scope* scope) const { void Scope::DeleteScope(Scope* scope) const {
SCOPE_KIDS_WRITER_LOCK {
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); SCOPE_KIDS_WRITER_LOCK
PADDLE_ENFORCE_NE(it, this->kids_.end(), auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
platform::errors::NotFound( PADDLE_ENFORCE_NE(it, this->kids_.end(),
"%p is not found in %p as kid scope", scope, this)); platform::errors::NotFound(
this->kids_.erase(it); "%p is not found in %p as kid scope", scope, this));
// When making memory benchmark on Fluid, we have to delete scope sync. this->kids_.erase(it);
if (FLAGS_benchmark || FLAGS_eager_delete_scope) { // When making memory benchmark on Fluid, we have to delete scope sync.
delete scope; if (FLAGS_benchmark || FLAGS_eager_delete_scope) {
} else { delete scope;
Async([scope] { delete scope; }); } else {
Async([scope] { delete scope; });
}
}
for (auto l : listeners_) {
l->onDeleteScope(scope);
} }
} }
void Scope::EraseVars(const std::vector<std::string>& var_names) { void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end()); {
SCOPE_VARS_WRITER_LOCK std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) { SCOPE_VARS_WRITER_LOCK
if (var_set.find(it->first) != var_set.end()) { for (auto it = vars_.begin(); it != vars_.end();) {
it = vars_.erase(it); if (var_set.find(it->first) != var_set.end()) {
} else { it = vars_.erase(it);
++it; } else {
++it;
}
}
}
for (auto l : listeners_) {
for (auto& var_name : var_names) {
l->onDeleteVariable(var_name);
} }
} }
} }
void Scope::Rename(const std::string& origin_name, void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const { const std::string& new_name) const {
SCOPE_VARS_WRITER_LOCK {
RenameInternal(origin_name, new_name); SCOPE_VARS_WRITER_LOCK
RenameInternal(origin_name, new_name);
}
for (auto l : listeners_) {
l->onRenameVariable(origin_name, new_name);
}
} }
std::string Scope::Rename(const std::string& origin_name) const { std::string Scope::Rename(const std::string& origin_name) const {
SCOPE_VARS_WRITER_LOCK
auto new_name = string::Sprintf("%p.%d", this, vars_.size()); auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name); {
SCOPE_VARS_WRITER_LOCK
RenameInternal(origin_name, new_name);
}
for (auto l : listeners_) {
l->onRenameVariable(origin_name, new_name);
}
return new_name; return new_name;
} }
...@@ -222,6 +266,17 @@ Variable* Scope::FindVarLocally(const std::string& name) const { ...@@ -222,6 +266,17 @@ Variable* Scope::FindVarLocally(const std::string& name) const {
return nullptr; return nullptr;
} }
void Scope::AddListener(ScopeListener* listener) {
auto it = std::find(listeners_.begin(), listeners_.end(), listener);
if (it == listeners_.end()) {
listeners_.push_back(listener);
}
}
void Scope::DelListener(ScopeListener* listener) {
listeners_.remove(listener);
}
void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) { void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
for (auto iter = vars_.begin(); iter != vars_.end();) { for (auto iter = vars_.begin(); iter != vars_.end();) {
......
...@@ -51,6 +51,22 @@ class ScopeBase { ...@@ -51,6 +51,22 @@ class ScopeBase {
class Scope; class Scope;
class ScopeListener {
// NOTE(xiongkun03) Abstract Class, doesn't have any attributes.
// Used by VariableScope. If we modify the original scope, we
// need synchronize changes to VariableScope. So we add listerer
// in original Scope.
public:
virtual ~ScopeListener() {}
virtual void onCreateVariable(const std::string& name) {}
virtual void onDeleteVariable(const std::string& name) {}
virtual void onRenameVariable(const std::string& old_name,
const std::string& new_name) {}
virtual void onCreateScope(Scope* Scope) {}
virtual void onDeleteScope(Scope* Scope) {}
virtual void onClear() {}
};
/** /**
* @brief Scope that manage all variables. * @brief Scope that manage all variables.
* *
...@@ -128,6 +144,10 @@ class Scope : public ScopeBase { ...@@ -128,6 +144,10 @@ class Scope : public ScopeBase {
// Rename variable to a new name and return the new name // Rename variable to a new name and return the new name
std::string Rename(const std::string& origin_name) const; std::string Rename(const std::string& origin_name) const;
void AddListener(ScopeListener* listener);
void DelListener(ScopeListener* listener);
protected: protected:
struct KeyHasher { struct KeyHasher {
std::size_t operator()(const std::string& key) const { std::size_t operator()(const std::string& key) const {
...@@ -164,6 +184,7 @@ class Scope : public ScopeBase { ...@@ -164,6 +184,7 @@ class Scope : public ScopeBase {
// Scope in `kids_` are owned by this class. // Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
const Scope* parent_{nullptr}; const Scope* parent_{nullptr};
std::list<ScopeListener*> listeners_;
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册