提交 f4a49cb0 编写于 作者: D dongzhihong

Merge remote-tracking branch 'origin/develop' into doc/api1

...@@ -43,48 +43,29 @@ Scope& Scope::NewScope() const { ...@@ -43,48 +43,29 @@ Scope& Scope::NewScope() const {
} }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
// acquire the lock when new var under this scope
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
auto* v = FindVarLocally(name); return VarInternal(name);
if (v != nullptr) return v;
v = new Variable();
vars_[name].reset(v);
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
} }
Variable* Scope::Var(std::string* name) { Variable* Scope::Var(std::string* name) {
auto var_name = string::Sprintf("%p.%d", this, vars_.size()); std::unique_lock<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) { if (name != nullptr) {
*name = var_name; *name = new_name;
} }
return Var(var_name); return VarInternal(new_name);
} }
Variable* Scope::FindVar(const std::string& name) const { Variable* Scope::FindVar(const std::string& name) const {
// acquire the lock when find var
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
return FindVarInternal(name); return FindVarInternal(name);
} }
Variable* Scope::FindVarInternal(const std::string& name) const {
auto var = FindVarLocally(name);
if (var != nullptr) {
return var;
}
return (parent_ == nullptr) ? nullptr : parent_->FindVarInternal(name);
}
const Scope* Scope::FindScope(const Variable* var) const { const Scope* Scope::FindScope(const Variable* var) const {
for (auto& kv : vars_) { std::unique_lock<std::mutex> lock(mutex_);
if (kv.second.get() == var) { return FindScopeInternal(var);
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
} }
void Scope::DropKids() { void Scope::DropKids() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
...@@ -92,6 +73,7 @@ void Scope::DropKids() { ...@@ -92,6 +73,7 @@ void Scope::DropKids() {
} }
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::unique_lock<std::mutex> lock(mutex_);
std::vector<std::string> known_vars; std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size()); known_vars.reserve(this->vars_.size());
for (auto& p : vars_) { for (auto& p : vars_) {
...@@ -127,6 +109,39 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) { ...@@ -127,6 +109,39 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name, void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const { const std::string& new_name) const {
std::unique_lock<std::mutex> lock(mutex_);
RenameInternal(origin_name, new_name);
}
std::string Scope::Rename(const std::string& origin_name) const {
std::unique_lock<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name);
return new_name;
}
Variable* Scope::VarInternal(const std::string& name) {
auto* v = FindVarLocally(name);
if (v != nullptr) return v;
v = new Variable();
vars_[name].reset(v);
VLOG(3) << "Create variable " << name;
v->name_ = &(vars_.find(name)->first);
return v;
}
const Scope* Scope::FindScopeInternal(const Variable* var) const {
for (auto& kv : vars_) {
if (kv.second.get() == var) {
return this;
}
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}
void Scope::RenameInternal(const std::string& origin_name,
const std::string& new_name) const {
auto origin_it = vars_.find(origin_name); auto origin_it = vars_.find(origin_name);
PADDLE_ENFORCE(origin_it != vars_.end(), PADDLE_ENFORCE(origin_it != vars_.end(),
"Cannot find original variable with name %s", origin_name); "Cannot find original variable with name %s", origin_name);
...@@ -137,10 +152,12 @@ void Scope::Rename(const std::string& origin_name, ...@@ -137,10 +152,12 @@ void Scope::Rename(const std::string& origin_name,
vars_.erase(origin_it); vars_.erase(origin_it);
} }
std::string Scope::Rename(const std::string& origin_name) const { Variable* Scope::FindVarInternal(const std::string& name) const {
auto var_name = string::Sprintf("%p.%d", this, vars_.size()); auto var = FindVarLocally(name);
Rename(origin_name, var_name); if (var != nullptr) {
return var_name; return var;
}
return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
} }
Variable* Scope::FindVarLocally(const std::string& name) const { Variable* Scope::FindVarLocally(const std::string& name) const {
......
...@@ -88,12 +88,20 @@ class Scope { ...@@ -88,12 +88,20 @@ class Scope {
// Call Scope::NewScope for a sub-scope. // Call Scope::NewScope for a sub-scope.
explicit Scope(Scope const* parent) : parent_(parent) {} explicit Scope(Scope const* parent) : parent_(parent) {}
// Called by Var.
Variable* VarInternal(const std::string& name);
// Called by FindScope.
const Scope* FindScopeInternal(const Variable* var) const;
// Called by Rename.
void RenameInternal(const std::string& origin_name,
const std::string& new_name) const;
// Called by FindVar recursively. // Called by FindVar recursively.
// Caller doesn't own the returned Variable.
Variable* FindVarInternal(const std::string& name) const; Variable* FindVarInternal(const std::string& name) const;
// Called by FindVarInternal and Var. // Called by FindVarInternal and Var.
// Caller doesn't own the returned Variable.
Variable* FindVarLocally(const std::string& name) const; Variable* FindVarLocally(const std::string& name) const;
// Scope in `kids_` are owned by this class. // Scope in `kids_` are owned by this class.
......
...@@ -103,9 +103,9 @@ void ThreadRunInfer( ...@@ -103,9 +103,9 @@ void ThreadRunInfer(
const int tid, paddle::framework::Scope* scope, const int tid, paddle::framework::Scope* scope,
const std::vector<std::vector<const paddle::framework::LoDTensor*>>& jobs) { const std::vector<std::vector<const paddle::framework::LoDTensor*>>& jobs) {
// maybe framework:ProgramDesc is not thread-safe // maybe framework:ProgramDesc is not thread-safe
paddle::platform::CPUPlace place;
paddle::framework::Executor executor(place);
auto& sub_scope = scope->NewScope(); auto& sub_scope = scope->NewScope();
auto place = paddle::platform::CPUPlace();
auto executor = paddle::framework::Executor(place);
auto inference_program = auto inference_program =
paddle::inference::Load(&executor, scope, FLAGS_model_path); paddle::inference::Load(&executor, scope, FLAGS_model_path);
...@@ -182,8 +182,8 @@ TEST(inference, nlp) { ...@@ -182,8 +182,8 @@ TEST(inference, nlp) {
stop_ms = GetCurrentMs(); stop_ms = GetCurrentMs();
} else { } else {
// 1. Define place, executor, scope // 1. Define place, executor, scope
auto place = paddle::platform::CPUPlace(); paddle::platform::CPUPlace place;
auto executor = paddle::framework::Executor(place); paddle::framework::Executor executor(place);
// 2. Initialize the inference_program and load parameters // 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program; std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册