提交 9e4e9e9b 编写于 作者: Q Qiao Longfei

clean rpc server profiler

上级 913b5699
...@@ -22,6 +22,9 @@ ...@@ -22,6 +22,9 @@
#include "paddle/fluid/operators/distributed/grpc_variable_response.h" #include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DEFINE_string(rpc_server_profile_path, "/tmp/profile_ps",
"the profile log file path");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
...@@ -289,7 +292,8 @@ int GRPCVariableResponse::Parse(Source* source) { ...@@ -289,7 +292,8 @@ int GRPCVariableResponse::Parse(Source* source) {
// TODO(panyx0718): Should we allow to customize file dir. // TODO(panyx0718): Should we allow to customize file dir.
platform::DisableProfiler( platform::DisableProfiler(
platform::EventSortingKey::kDefault, platform::EventSortingKey::kDefault,
string::Sprintf("/tmp/profile_ps_%lld", listener_id)); string::Sprintf("%s_%lld", FLAGS_rpc_server_profile_path,
listener_id));
} }
break; break;
} }
......
...@@ -50,7 +50,6 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -50,7 +50,6 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Async // Async
if (!sync_mode_) { if (!sync_mode_) {
VLOG(3) << "async process var: " << varname; VLOG(3) << "async process var: " << varname;
rpc_server_->Profiler().OneStep();
try { try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope); scope);
......
...@@ -20,42 +20,10 @@ ...@@ -20,42 +20,10 @@
#include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DEFINE_int32(rpc_server_profile_period, 0,
"the period of listen_and_serv to do profile");
DEFINE_string(rpc_server_profile_path, "/dev/null",
"the profile log file path");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
RPCServerProfiler::RPCServerProfiler(int profile_period,
const std::string& profile_log_path)
: profile_period_(profile_period), profile_log_path_(profile_log_path) {
step_ = 0;
}
void RPCServerProfiler::OneStep() {
PADDLE_ENFORCE_LE(step_, profile_period_,
"step_ should not be larger then "
"profile_period_");
if (profile_period_ <= 0) {
return;
}
if (step_ == 0) {
auto pf_state = paddle::platform::ProfilerState::kCPU;
paddle::platform::EnableProfiler(pf_state);
}
if (step_ == profile_period_) {
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kTotal,
profile_log_path_);
step_ = 0;
} else {
step_++;
}
}
void RPCServer::ShutDown() { void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown "; LOG(INFO) << "RPCServer ShutDown ";
ShutDownImpl(); ShutDownImpl();
......
...@@ -23,30 +23,16 @@ ...@@ -23,30 +23,16 @@
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
DECLARE_int32(rpc_server_profile_period);
DECLARE_string(rpc_server_profile_path); DECLARE_string(rpc_server_profile_path);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
class RPCServerProfiler {
public:
RPCServerProfiler(int profile_period, const std::string& profile_log_path);
void OneStep();
private:
const int profile_period_;
std::string profile_log_path_;
int step_;
};
class RPCServer { class RPCServer {
public: public:
explicit RPCServer(const std::string& address, int client_num) explicit RPCServer(const std::string& address, int client_num)
: cur_cond_(0), : cur_cond_(0),
profiler_(FLAGS_rpc_server_profile_period,
FLAGS_rpc_server_profile_path),
bind_address_(address), bind_address_(address),
exit_flag_(false), exit_flag_(false),
selected_port_(0), selected_port_(0),
...@@ -86,7 +72,6 @@ class RPCServer { ...@@ -86,7 +72,6 @@ class RPCServer {
void Complete(); void Complete();
void ResetBarrierCounter(); void ResetBarrierCounter();
RPCServerProfiler& Profiler() { return profiler_; }
bool NeedResetAllVars(); bool NeedResetAllVars();
...@@ -101,7 +86,6 @@ class RPCServer { ...@@ -101,7 +86,6 @@ class RPCServer {
std::unordered_map<std::string, int> rpc_cond_map_; std::unordered_map<std::string, int> rpc_cond_map_;
std::atomic<int> cur_cond_; std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_; std::condition_variable rpc_cond_;
RPCServerProfiler profiler_;
protected: protected:
std::string bind_address_; std::string bind_address_;
......
...@@ -134,7 +134,6 @@ void ListenAndServOp::RunSyncLoop( ...@@ -134,7 +134,6 @@ void ListenAndServOp::RunSyncLoop(
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
while (true) { while (true) {
rpc_service_->Profiler().OneStep();
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(distributed::kRequestSend); rpc_service_->SetCond(distributed::kRequestSend);
......
...@@ -226,7 +226,7 @@ RecordBlock::~RecordBlock() { ...@@ -226,7 +226,7 @@ RecordBlock::~RecordBlock() {
void EnableProfiler(ProfilerState state) { void EnableProfiler(ProfilerState state) {
PADDLE_ENFORCE(state != ProfilerState::kDisabled, PADDLE_ENFORCE(state != ProfilerState::kDisabled,
"Can't enbale profling, since the input state is ", "Can't enable profiling, since the input state is ",
"ProfilerState::kDisabled"); "ProfilerState::kDisabled");
std::lock_guard<std::mutex> l(profiler_mu); std::lock_guard<std::mutex> l(profiler_mu);
......
...@@ -118,7 +118,6 @@ def __bootstrap__(): ...@@ -118,7 +118,6 @@ def __bootstrap__():
] ]
if core.is_compiled_with_dist(): if core.is_compiled_with_dist():
read_env_flags.append('rpc_deadline') read_env_flags.append('rpc_deadline')
read_env_flags.append('rpc_server_profile_period')
read_env_flags.append('rpc_server_profile_path') read_env_flags.append('rpc_server_profile_path')
read_env_flags.append('enable_rpc_profiler') read_env_flags.append('enable_rpc_profiler')
read_env_flags.append('rpc_send_thread_num') read_env_flags.append('rpc_send_thread_num')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册