提交 065b68b6 编写于 作者: Q Qiao Longfei

clean code

上级 347178bd
...@@ -106,12 +106,6 @@ class RequestSend final : public RequestBase { ...@@ -106,12 +106,6 @@ class RequestSend final : public RequestBase {
auto invar = request_->GetVar(); auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId(); int trainer_id = request_->GetTrainerId();
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
/*
if (!request_handler_->sync_mode()) {
request_->ReleaseOwnershipOfLocalScope();
}
*/
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id); request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_); Finish(reply_, &responder_);
} }
......
...@@ -80,7 +80,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -80,7 +80,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto &send_slr = send_var->Get<framework::SelectedRows>(); auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections); auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
auto send_rows = send_slr.rows(); auto &send_rows = send_slr.rows();
std::vector<std::vector<int>> outs_rows_idx; std::vector<std::vector<int>> outs_rows_idx;
std::vector<std::vector<int>> outs_dense_idx; std::vector<std::vector<int>> outs_dense_idx;
...@@ -88,7 +88,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -88,7 +88,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
outs_dense_idx.resize(out_num); outs_dense_idx.resize(out_num);
auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0]; auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
auto src = send_slr.value().data<T>(); auto *src = send_slr.value().data<T>();
// create output var in local scope // create output var in local scope
std::vector<framework::SelectedRows *> outs; std::vector<framework::SelectedRows *> outs;
...@@ -110,8 +110,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -110,8 +110,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
outs[i]->set_height(rpc_ctx.height_sections[i]); outs[i]->set_height(rpc_ctx.height_sections[i]);
auto dims = send_slr.GetCompleteDims(); auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size(); dims[0] = rows_idx.size();
outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
outs[i]->mutable_rows()->clear(); outs[i]->mutable_rows()->clear();
outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) { if (rows_idx.size() > 0) {
for (auto idx : rows_idx) { for (auto idx : rows_idx) {
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]); outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
......
...@@ -71,15 +71,13 @@ class VarHandle { ...@@ -71,15 +71,13 @@ class VarHandle {
VarHandle(const std::string ep, const std::string& method, VarHandle(const std::string ep, const std::string& method,
const std::string& name, const std::string& name,
const platform::DeviceContext* p_ctx = nullptr, const platform::DeviceContext* p_ctx = nullptr,
const framework::Scope* p_scope = nullptr, const framework::Scope* p_scope = nullptr)
bool delete_local_scope = false)
: status_(kDefaultState) { : status_(kDefaultState) {
ep_ = ep; ep_ = ep;
ctx_ = p_ctx; ctx_ = p_ctx;
scope_ = p_scope; scope_ = p_scope;
name_ = name; name_ = name;
method_ = method; method_ = method;
delete_local_scope_ = delete_local_scope;
} }
virtual ~VarHandle() {} virtual ~VarHandle() {}
...@@ -101,7 +99,6 @@ class VarHandle { ...@@ -101,7 +99,6 @@ class VarHandle {
std::unique_lock<std::mutex> lk(sync_mutex_); std::unique_lock<std::mutex> lk(sync_mutex_);
status_ = ok ? kFinishState : kErrorState; status_ = ok ? kFinishState : kErrorState;
} }
if (delete_local_scope_ && scope_) delete scope_;
VLOG(7) << "VarHandle finish:" << ok; VLOG(7) << "VarHandle finish:" << ok;
wait_cond_.notify_all(); wait_cond_.notify_all();
} }
...@@ -128,7 +125,6 @@ class VarHandle { ...@@ -128,7 +125,6 @@ class VarHandle {
std::string name_; std::string name_;
// RPC method name. // RPC method name.
std::string method_; std::string method_;
bool delete_local_scope_;
protected: protected:
std::mutex sync_mutex_; std::mutex sync_mutex_;
......
...@@ -59,15 +59,8 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -59,15 +59,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
"async mode should not recv BATCH_BARRIER_MESSAGE or " "async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"); "COMPLETE_MESSAGE");
} }
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
try { scope);
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
delete scope;
} catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true; return true;
} else { // sync } else { // sync
rpc_server_->WaitCond(kRequestSend); rpc_server_->WaitCond(kRequestSend);
......
...@@ -60,13 +60,14 @@ class VariableResponse { ...@@ -60,13 +60,14 @@ class VariableResponse {
bool create_scope = false) bool create_scope = false)
: scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) { : scope_(scope), dev_ctx_(dev_ctx), create_scope_(create_scope) {
if (create_scope) { if (create_scope) {
local_scope_ = &scope->NewScope(); local_scope_ = scope->NewTmpScope();
} }
} }
virtual ~VariableResponse() { virtual ~VariableResponse() {
if (local_scope_) { if (local_scope_) {
scope_->DeleteScope(local_scope_); delete local_scope_;
local_scope_ = nullptr;
} }
} }
...@@ -86,12 +87,6 @@ class VariableResponse { ...@@ -86,12 +87,6 @@ class VariableResponse {
inline std::string Varname() const { return meta_.varname(); } inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); } inline std::string OutVarname() const { return meta_.out_varname(); }
inline std::string TableName() const { return meta_.table_name(); } inline std::string TableName() const { return meta_.table_name(); }
inline void ReleaseOwnershipOfLocalScope() {
PADDLE_ENFORCE(create_scope_,
"only when create_scope_ is true can you release the "
"ownership of local scope");
local_scope_ = nullptr;
}
// should call parse first. // should call parse first.
framework::Variable* GetVar() { framework::Variable* GetVar() {
......
...@@ -54,6 +54,7 @@ inline int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) { ...@@ -54,6 +54,7 @@ inline int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
return i - 1; return i - 1;
} }
} }
PADDLE_ENFORCE_LT(row, abs_sections.back(), "row should be less then max id");
return abs_sections.size() - 1; return abs_sections.size() - 1;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册