未验证 提交 0ddc5d86 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #11258 from tensor-tang/refine

Refine test and scope lock
......@@ -43,48 +43,29 @@ Scope& Scope::NewScope() const {
}
Variable* Scope::Var(const std::string& name) {
// acquire the lock when new var under this scope
std::unique_lock<std::mutex> lock(mutex_);
auto* v = FindVarLocally(name);
if (v != nullptr) return v;
v = new Variable();
vars_[name].reset(v);
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
return VarInternal(name);
}
Variable* Scope::Var(std::string* name) {
auto var_name = string::Sprintf("%p.%d", this, vars_.size());
std::unique_lock<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) {
*name = var_name;
*name = new_name;
}
return Var(var_name);
return VarInternal(new_name);
}
Variable* Scope::FindVar(const std::string& name) const {
// acquire the lock when find var
std::unique_lock<std::mutex> lock(mutex_);
return FindVarInternal(name);
}
Variable* Scope::FindVarInternal(const std::string& name) const {
auto var = FindVarLocally(name);
if (var != nullptr) {
return var;
}
return (parent_ == nullptr) ? nullptr : parent_->FindVarInternal(name);
}
const Scope* Scope::FindScope(const Variable* var) const {
for (auto& kv : vars_) {
if (kv.second.get() == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
std::unique_lock<std::mutex> lock(mutex_);
return FindScopeInternal(var);
}
void Scope::DropKids() {
std::unique_lock<std::mutex> lock(mutex_);
for (Scope* s : kids_) delete s;
......@@ -92,6 +73,7 @@ void Scope::DropKids() {
}
std::vector<std::string> Scope::LocalVarNames() const {
std::unique_lock<std::mutex> lock(mutex_);
std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
......@@ -127,6 +109,39 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const {
std::unique_lock<std::mutex> lock(mutex_);
RenameInternal(origin_name, new_name);
}
std::string Scope::Rename(const std::string& origin_name) const {
std::unique_lock<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name);
return new_name;
}
Variable* Scope::VarInternal(const std::string& name) {
auto* v = FindVarLocally(name);
if (v != nullptr) return v;
v = new Variable();
vars_[name].reset(v);
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
}
const Scope* Scope::FindScopeInternal(const Variable* var) const {
for (auto& kv : vars_) {
if (kv.second.get() == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}
void Scope::RenameInternal(const std::string& origin_name,
const std::string& new_name) const {
auto origin_it = vars_.find(origin_name);
PADDLE_ENFORCE(origin_it != vars_.end(),
"Cannot find original variable with name %s", origin_name);
......@@ -137,10 +152,12 @@ void Scope::Rename(const std::string& origin_name,
vars_.erase(origin_it);
}
std::string Scope::Rename(const std::string& origin_name) const {
auto var_name = string::Sprintf("%p.%d", this, vars_.size());
Rename(origin_name, var_name);
return var_name;
Variable* Scope::FindVarInternal(const std::string& name) const {
auto var = FindVarLocally(name);
if (var != nullptr) {
return var;
}
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
}
Variable* Scope::FindVarLocally(const std::string& name) const {
......
......@@ -88,12 +88,20 @@ class Scope {
// Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {}
// Called by Var.
Variable* VarInternal(const std::string& name);
// Called by FindScope.
const Scope* FindScopeInternal(const Variable* var) const;
// Called by Rename.
void RenameInternal(const std::string& origin_name,
const std::string& new_name) const;
// Called by FindVar recursively.
// Caller doesn't own the returned Variable.
Variable* FindVarInternal(const std::string& name) const;
// Called by FindVarInternal and Var.
// Caller doesn't own the returned Variable.
Variable* FindVarLocally(const std::string& name) const;
// Scope in `kids_` are owned by this class.
......
......@@ -103,9 +103,9 @@ void ThreadRunInfer(
const int tid, paddle::framework::Scope* scope,
const std::vector<std::vector<const paddle::framework::LoDTensor*>>& jobs) {
// maybe framework:ProgramDesc is not thread-safe
paddle::platform::CPUPlace place;
paddle::framework::Executor executor(place);
auto& sub_scope = scope->NewScope();
auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place);
auto inference_program =
paddle::inference::Load(&executor, scope, FLAGS_model_path);
......@@ -182,8 +182,8 @@ TEST(inference, nlp) {
stop_ms = GetCurrentMs();
} else {
// 1. Define place, executor, scope
auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place);
paddle::platform::CPUPlace place;
paddle::framework::Executor executor(place);
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册