提交 9958775b 编写于 作者: Q Qiao Longfei

add NewTmpScope to scope

上级 7021979b
...@@ -81,6 +81,8 @@ Scope& Scope::NewScope() const { ...@@ -81,6 +81,8 @@ Scope& Scope::NewScope() const {
return *child; return *child;
} }
Scope* Scope::NewTmpScope() const { return new Scope(this); }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
return VarInternal(name); return VarInternal(name);
......
...@@ -55,6 +55,8 @@ class Scope { ...@@ -55,6 +55,8 @@ class Scope {
/// Mark it to const because that new kid scope cannot change parent scope. /// Mark it to const because that new kid scope cannot change parent scope.
Scope& NewScope() const; Scope& NewScope() const;
Scope* NewTmpScope() const;
/// Create a variable with given name if it doesn't exist. /// Create a variable with given name if it doesn't exist.
/// Caller doesn't own the returned Variable. /// Caller doesn't own the returned Variable.
Variable* Var(const std::string& name); Variable* Var(const std::string& name);
......
...@@ -107,6 +107,9 @@ class RequestSend final : public RequestBase { ...@@ -107,6 +107,9 @@ class RequestSend final : public RequestBase {
int trainer_id = request_->GetTrainerId(); int trainer_id = request_->GetTrainerId();
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
if (!request_handler_->sync_mode()) {
request_->ReleaseOwnershipOfLocalScope();
}
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_); Finish(reply_, &responder_);
} }
......
...@@ -180,7 +180,7 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -180,7 +180,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const std::vector<int>& height_sections, const std::vector<int>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope) { const framework::Scope& scope) {
auto& local_scope = scope.NewScope(); framework::Scope* local_scope = scope.NewTmpScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace()); auto& cpu_ctx = *pool.Get(platform::CPUPlace());
...@@ -224,22 +224,22 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -224,22 +224,22 @@ void prefetch(const std::string& id_name, const std::string& out_name,
#endif #endif
} }
auto splited_ids = SplitIds(ids_vector, height_sections, &local_scope); auto splited_ids = SplitIds(ids_vector, height_sections, local_scope);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
&local_scope); local_scope);
// create output var in local scope // create output var in local scope
for (auto& name : out_var_names) { for (auto& name : out_var_names) {
local_scope.Var(name)->GetMutable<framework::LoDTensor>(); local_scope->Var(name)->GetMutable<framework::LoDTensor>();
} }
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) { for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(local_scope, in_var_names[i])) { if (NeedSend(*local_scope, in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i] VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back"; << " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar( rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], cpu_ctx, local_scope, in_var_names[i], out_var_names[i], epmap[i], cpu_ctx, *local_scope, in_var_names[i], out_var_names[i],
table_names[i])); table_names[i]));
} else { } else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i]; VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
...@@ -252,8 +252,8 @@ void prefetch(const std::string& id_name, const std::string& out_name, ...@@ -252,8 +252,8 @@ void prefetch(const std::string& id_name, const std::string& out_name,
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids, out_var_names, height_sections, splited_ids,
context, &local_scope, &actual_ctx); context, local_scope, &actual_ctx);
scope.DeleteScope(&local_scope); delete local_scope;
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -58,13 +58,15 @@ class VarHandle { ...@@ -58,13 +58,15 @@ class VarHandle {
VarHandle(const std::string ep, const std::string& method, VarHandle(const std::string ep, const std::string& method,
const std::string& name, const std::string& name,
const platform::DeviceContext* p_ctx = nullptr, const platform::DeviceContext* p_ctx = nullptr,
const framework::Scope* p_scope = nullptr) const framework::Scope* p_scope = nullptr,
bool delete_local_scope = false)
: status_(kDefaultState) { : status_(kDefaultState) {
ep_ = ep; ep_ = ep;
ctx_ = p_ctx; ctx_ = p_ctx;
scope_ = p_scope; scope_ = p_scope;
name_ = name; name_ = name;
method_ = method; method_ = method;
delete_local_scope_ = delete_local_scope;
} }
virtual ~VarHandle() {} virtual ~VarHandle() {}
...@@ -86,6 +88,7 @@ class VarHandle { ...@@ -86,6 +88,7 @@ class VarHandle {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
status_ = ok ? kFinishState : kErrorState; status_ = ok ? kFinishState : kErrorState;
} }
if (delete_local_scope_ && scope_) delete scope_;
VLOG(7) << "VarHandle finish:" << ok; VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all(); wait_cond_.notify_all();
} }
...@@ -112,6 +115,7 @@ class VarHandle { ...@@ -112,6 +115,7 @@ class VarHandle {
std::string name_; std::string name_;
// RPC method name. // RPC method name.
std::string method_; std::string method_;
bool delete_local_scope_;
protected: protected:
std::mutex sync_mutex_; std::mutex sync_mutex_;
......
...@@ -53,13 +53,9 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -53,13 +53,9 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Async // Async
if (!sync_mode_) { if (!sync_mode_) {
VLOG(3) << "async process var: " << varname; VLOG(3) << "async process var: " << varname;
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope); scope);
} catch (std::exception& e) { delete scope;
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true; return true;
} else { // sync } else { // sync
rpc_server_->WaitCond(kRequestSend); rpc_server_->WaitCond(kRequestSend);
......
...@@ -60,14 +60,12 @@ class VariableResponse { ...@@ -60,14 +60,12 @@ class VariableResponse {
bool create_scope = false) bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) { if (create_scope) {
local_scope_ = &scope->NewScope(); local_scope_ = scope->NewTmpScope();
} }
} }
virtual ~VariableResponse() { virtual ~VariableResponse() {
if (create_scope_) { if (local_scope_) delete local_scope_;
scope_->DeleteScope(local_scope_);
}
} }
int Parse(Source* source, const sendrecv::VariableMessage& meta) { int Parse(Source* source, const sendrecv::VariableMessage& meta) {
...@@ -86,6 +84,12 @@ class VariableResponse { ...@@ -86,6 +84,12 @@ class VariableResponse {
inline std::string Varname() const { return meta_.varname(); } inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); } inline std::string OutVarname() const { return meta_.out_varname(); }
inline std::string TableName() const { return meta_.table_name(); } inline std::string TableName() const { return meta_.table_name(); }
inline void ReleaseOwnershipOfLocalScope() {
PADDLE_ENFORCE(create_scope_,
"only when create_scope_ is true can you release the "
"ownership of local scope");
local_scope_ = nullptr;
}
// should call parse first. // should call parse first.
framework::Variable* GetVar() { framework::Variable* GetVar() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册