提交 6edfae42 编写于 作者: Y Yancey1989

reset received vars on pserver

上级 f76f42c2
......@@ -67,24 +67,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
LOG(FATAL) << "sync: Can not find server side var: " << varname;
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
}
}
}
return true;
}
void RequestSendHandler::ResetSparseVarRecorder() {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
for (auto* var : sparse_vars_) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
sparse_vars_.clear();
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
......
......@@ -41,11 +41,6 @@ class RequestSendHandler final : public RequestHandler {
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
void ResetSparseVarRecorder();
private:
std::mutex mutex_sparse_vars_;
std::vector<framework::Variable*> sparse_vars_;
};
class RequestGetHandler final : public RequestHandler {
......
......@@ -101,6 +101,8 @@ void RPCServer::Complete() {
{
std::unique_lock<std::mutex> lock(mutex_);
client_num_--;
need_reset_all_vars_ = true;
VLOG(4) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--;
......@@ -109,6 +111,11 @@ void RPCServer::Complete() {
barrier_cond_.notify_all();
}
bool RPCServer::NeedResetAllVars() {
std::unique_lock<std::mutex> lock(mutex_);
return need_reset_all_vars_;
}
int RPCServer::GetClientNum() {
std::unique_lock<std::mutex> lock(mutex_);
return client_num_;
......@@ -120,6 +127,7 @@ void RPCServer::ResetBarrierCounter() {
for (auto& t : barrier_counter_) {
t.second = 0;
}
need_reset_all_vars_ = false;
}
void RPCServer::RegisterRPC(const std::string& rpc_name,
......
......@@ -49,7 +49,8 @@ class RPCServer {
bind_address_(address),
exit_flag_(false),
selected_port_(0),
client_num_(client_num) {}
client_num_(client_num),
need_reset_all_vars_(false) {}
virtual ~RPCServer() {}
virtual void StartServer() = 0;
......@@ -86,6 +87,8 @@ class RPCServer {
void ResetBarrierCounter();
RPCServerProfiler& Profiler() { return profiler_; }
bool NeedResetAllVars();
protected:
virtual void ShutDownImpl() = 0;
......@@ -104,6 +107,7 @@ class RPCServer {
std::atomic<int> exit_flag_;
int selected_port_;
int client_num_;
bool need_reset_all_vars_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_;
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
......@@ -101,9 +102,10 @@ static int64_t GetTimestamp() {
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id) const {
const int checkpoint_point_block_id,
const std::vector<std::string> &recv_varnames) const {
VLOG(2) << "RunSyncLoop";
size_t num_blocks = program->Size();
auto optimize_blocks =
......@@ -166,8 +168,8 @@ void ListenAndServOp::RunSyncLoop(
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
// reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast<distributed::RequestSendHandler *>(request_send_handler_.get())
->ResetSparseVarRecorder();
ResetReceivedVars(recv_varnames, recv_scope, dev_ctx,
!rpc_service_->NeedResetAllVars());
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
......@@ -175,6 +177,33 @@ void ListenAndServOp::RunSyncLoop(
} // while(true)
}
void ListenAndServOp::ResetReceivedVars(
const std::vector<std::string> &recv_varnames, framework::Scope *recv_scope,
platform::DeviceContext *dev_ctx, bool only_sparse_vars) const {
for (auto &varname : recv_varnames) {
auto var = recv_scope->FindVar(varname);
if (var == nullptr) {
VLOG(2) << "can not find var " << varname << " in received scope";
continue;
}
if (var->IsType<framework::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
if (!only_sparse_vars) {
if (var->IsType<framework::LoDTensor>()) {
math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
static_cast<float>(0));
} else if (var->IsType<framework::Tensor>()) {
math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
static_cast<float>(0));
} else {
PADDLE_THROW(
"received var should be in [SelectedRows, LoDTensor, Tensor]");
}
}
}
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope) const {
......@@ -258,6 +287,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin");
auto inputs = Inputs("X");
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
......@@ -351,8 +381,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use.
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
checkpoint_block_id);
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id, inputs);
} else {
RunAsyncLoop(&executor, program, &recv_scope);
}
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -48,8 +49,10 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id) const;
const int checkpoint_point_block_id,
const std::vector<std::string>& recv_varnames) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
......@@ -64,6 +67,11 @@ class ListenAndServOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
void ResetReceivedVars(const std::vector<std::string>& recv_varnames,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx,
bool only_sparse_vars = true) const;
protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册