diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index fd23fdeab71270d3e399af02636d801270f91c62..50f374e3703a97f6c1fdb4b14fdeb0b603f9ac86 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -43,49 +43,29 @@ Scope& Scope::NewScope() const { } Variable* Scope::Var(const std::string& name) { - // acquire the lock when new var under this scope std::unique_lock lock(mutex_); - auto* v = FindVarLocally(name); - if (v != nullptr) return v; - - v = new Variable(); - vars_[name].reset(v); - VLOG(3) << "Create variable " << name; - v->name_ = &(vars_.find(name)->first); - return v; + return VarInternal(name); } Variable* Scope::Var(std::string* name) { - auto var_name = string::Sprintf("%p.%d", this, vars_.size()); + std::unique_lock lock(mutex_); + auto new_name = string::Sprintf("%p.%d", this, vars_.size()); if (name != nullptr) { - *name = var_name; + *name = new_name; } - return Var(var_name); + return VarInternal(new_name); } Variable* Scope::FindVar(const std::string& name) const { - // acquire the lock when find var std::unique_lock lock(mutex_); return FindVarInternal(name); } -Variable* Scope::FindVarInternal(const std::string& name) const { - auto var = FindVarLocally(name); - if (var != nullptr) { - return var; - } - return (parent_ == nullptr) ? nullptr : parent_->FindVarInternal(name); -} - const Scope* Scope::FindScope(const Variable* var) const { std::unique_lock lock(mutex_); - for (auto& kv : vars_) { - if (kv.second.get() == var) { - return this; - } - } - return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); + return FindScopeInternal(var); } + void Scope::DropKids() { std::unique_lock lock(mutex_); for (Scope* s : kids_) delete s; @@ -93,6 +73,7 @@ void Scope::DropKids() { } std::vector Scope::LocalVarNames() const { + std::unique_lock lock(mutex_); std::vector known_vars; known_vars.reserve(this->vars_.size()); for (auto& p : vars_) { @@ -129,6 +110,38 @@ void Scope::EraseVars(const std::vector& var_names) { void Scope::Rename(const std::string& origin_name, const std::string& new_name) const { std::unique_lock lock(mutex_); + RenameInternal(origin_name, new_name); +} + +std::string Scope::Rename(const std::string& origin_name) const { + std::unique_lock lock(mutex_); + auto new_name = string::Sprintf("%p.%d", this, vars_.size()); + RenameInternal(origin_name, new_name); + return new_name; +} + +Variable* Scope::VarInternal(const std::string& name) { + auto* v = FindVarLocally(name); + if (v != nullptr) return v; + + v = new Variable(); + vars_[name].reset(v); + VLOG(3) << "Create variable " << name; + v->name_ = &(vars_.find(name)->first); + return v; +} + +const Scope* Scope::FindScopeInternal(const Variable* var) const { + for (auto& kv : vars_) { + if (kv.second.get() == var) { + return this; + } + } + return (parent_ == nullptr) ? nullptr : parent_->FindScope(var); +} + +void Scope::RenameInternal(const std::string& origin_name, + const std::string& new_name) const { auto origin_it = vars_.find(origin_name); PADDLE_ENFORCE(origin_it != vars_.end(), "Cannot find original variable with name %s", origin_name); @@ -139,10 +152,12 @@ void Scope::Rename(const std::string& origin_name, vars_.erase(origin_it); } -std::string Scope::Rename(const std::string& origin_name) const { - auto var_name = string::Sprintf("%p.%d", this, vars_.size()); - Rename(origin_name, var_name); - return var_name; +Variable* Scope::FindVarInternal(const std::string& name) const { + auto var = FindVarLocally(name); + if (var != nullptr) { + return var; + } + return (parent_ == nullptr) ? nullptr : parent_->FindVar(name); } Variable* Scope::FindVarLocally(const std::string& name) const { diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 98d103d867987fc02dc66df5ac855a14b66b8f03..34687df3ab1d4c9e9c9180ec2cd3a4dfbe8c1ca7 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -85,12 +85,20 @@ class Scope { // Call Scope::NewScope for a sub-scope. explicit Scope(Scope const* parent) : parent_(parent) {} + // Called by Var. + Variable* VarInternal(const std::string& name); + + // Called by FindScope. + const Scope* FindScopeInternal(const Variable* var) const; + + // Called by Rename. + void RenameInternal(const std::string& origin_name, + const std::string& new_name) const; + // Called by FindVar recursively. - // Caller doesn't own the returned Variable. Variable* FindVarInternal(const std::string& name) const; // Called by FindVarInternal and Var. - // Caller doesn't own the returned Variable. Variable* FindVarLocally(const std::string& name) const; mutable std::unordered_map> vars_;