提交 b8d315fb 编写于 作者: T tensor-tang

make scope thread safe

上级 bfd42683
...@@ -43,49 +43,29 @@ Scope& Scope::NewScope() const { ...@@ -43,49 +43,29 @@ Scope& Scope::NewScope() const {
} }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
// acquire the lock when new var under this scope
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
auto* v = FindVarLocally(name); return VarInternal(name);
if (v != nullptr) return v;
v = new Variable();
vars_[name].reset(v);
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
} }
Variable* Scope::Var(std::string* name) { Variable* Scope::Var(std::string* name) {
auto var_name = string::Sprintf("%p.%d", this, vars_.size()); std::unique_lock<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) { if (name != nullptr) {
*name = var_name; *name = new_name;
} }
return Var(var_name); return VarInternal(new_name);
} }
Variable* Scope::FindVar(const std::string& name) const { Variable* Scope::FindVar(const std::string& name) const {
// acquire the lock when find var
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
return FindVarInternal(name); return FindVarInternal(name);
} }
Variable* Scope::FindVarInternal(const std::string& name) const {
auto var = FindVarLocally(name);
if (var != nullptr) {
return var;
}
return (parent_ == nullptr) ? nullptr : parent_->FindVarInternal(name);
}
const Scope* Scope::FindScope(const Variable* var) const { const Scope* Scope::FindScope(const Variable* var) const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
for (auto& kv : vars_) { return FindScopeInternal(var);
if (kv.second.get() == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
} }
void Scope::DropKids() { void Scope::DropKids() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
...@@ -93,6 +73,7 @@ void Scope::DropKids() { ...@@ -93,6 +73,7 @@ void Scope::DropKids() {
} }
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::unique_lock<std::mutex> lock(mutex_);
std::vector<std::string> known_vars; std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size()); known_vars.reserve(this->vars_.size());
for (auto& p : vars_) { for (auto& p : vars_) {
...@@ -129,6 +110,38 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) { ...@@ -129,6 +110,38 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name, void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const { const std::string& new_name) const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
RenameInternal(origin_name, new_name);
}
std::string Scope::Rename(const std::string& origin_name) const {
std::unique_lock<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name);
return new_name;
}
Variable* Scope::VarInternal(const std::string& name) {
auto* v = FindVarLocally(name);
if (v != nullptr) return v;
v = new Variable();
vars_[name].reset(v);
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
}
const Scope* Scope::FindScopeInternal(const Variable* var) const {
for (auto& kv : vars_) {
if (kv.second.get() == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}
void Scope::RenameInternal(const std::string& origin_name,
const std::string& new_name) const {
auto origin_it = vars_.find(origin_name); auto origin_it = vars_.find(origin_name);
PADDLE_ENFORCE(origin_it != vars_.end(), PADDLE_ENFORCE(origin_it != vars_.end(),
"Cannot find original variable with name %s", origin_name); "Cannot find original variable with name %s", origin_name);
...@@ -139,10 +152,12 @@ void Scope::Rename(const std::string& origin_name, ...@@ -139,10 +152,12 @@ void Scope::Rename(const std::string& origin_name,
vars_.erase(origin_it); vars_.erase(origin_it);
} }
std::string Scope::Rename(const std::string& origin_name) const { Variable* Scope::FindVarInternal(const std::string& name) const {
auto var_name = string::Sprintf("%p.%d", this, vars_.size()); auto var = FindVarLocally(name);
Rename(origin_name, var_name); if (var != nullptr) {
return var_name; return var;
}
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
} }
Variable* Scope::FindVarLocally(const std::string& name) const { Variable* Scope::FindVarLocally(const std::string& name) const {
......
...@@ -85,12 +85,20 @@ class Scope { ...@@ -85,12 +85,20 @@ class Scope {
// Call Scope::NewScope for a sub-scope. // Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {} explicit Scope(Scope const* parent) : parent_(parent) {}
// Called by Var.
Variable* VarInternal(const std::string& name);
// Called by FindScope.
const Scope* FindScopeInternal(const Variable* var) const;
// Called by Rename.
void RenameInternal(const std::string& origin_name,
const std::string& new_name) const;
// Called by FindVar recursively. // Called by FindVar recursively.
// Caller doesn't own the returned Variable.
Variable* FindVarInternal(const std::string& name) const; Variable* FindVarInternal(const std::string& name) const;
// Called by FindVarInternal and Var. // Called by FindVarInternal and Var.
// Caller doesn't own the returned Variable.
Variable* FindVarLocally(const std::string& name) const; Variable* FindVarLocally(const std::string& name) const;
mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_; mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册