提交 08e4970e 编写于 作者: X Xin Pan

follow comments

上级 a848303e
......@@ -19,14 +19,16 @@ limitations under the License. */
using ::grpc::ServerAsyncResponseWriter;
DEFINE_int32(rpc_server_handle_send_threads, 20,
"Number of threads used to handle send at rpc server.");
DEFINE_int32(rpc_server_handle_get_threads, 20,
"Number of threads used to handle get at rpc server.");
DEFINE_int32(rpc_server_handle_prefetch_threads, 1,
"Number of threads used to handle prefetch at rpc server.");
namespace paddle {
namespace operators {
namespace detail {
namespace {
const int kNumHandleSendThreads = 20;
const int kNumHandleGetThreads = 20;
const int kNumHandlePrefetchThreads = 1;
} // namespace
enum CallStatus { PROCESS = 0, FINISH };
// reference:
......@@ -268,17 +270,17 @@ void AsyncGRPCServer::RunSyncUpdate() {
TryToRegisterNewPrefetchOne(i);
}
for (int i = 0; i < kNumHandleSendThreads; ++i) {
for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
t_sends_.emplace_back(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));
}
for (int i = 0; i < kNumHandleGetThreads; ++i) {
for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
t_gets_.emplace_back(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
}
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
t_prefetchs_.emplace_back(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
......@@ -290,13 +292,13 @@ void AsyncGRPCServer::RunSyncUpdate() {
condition_ready_.notify_all();
// wait server
server_->Wait();
for (int i = 0; i < kNumHandleSendThreads; ++i) {
for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
t_sends_[i]->join();
}
for (int i = 0; i < kNumHandleGetThreads; ++i) {
for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
t_gets_[i]->join();
}
for (int i = 0; i < kNumHandlePrefetchThreads; ++i) {
for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) {
t_prefetchs_[i]->join();
}
}
......
......@@ -85,9 +85,9 @@ class AsyncGRPCServer final {
void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name,
std::function<void(int)> TryToRegisterNewOne);
void TryToRegisterNewSendOne(int i);
void TryToRegisterNewGetOne(int i);
void TryToRegisterNewPrefetchOne(int i);
void TryToRegisterNewSendOne(int req_id);
void TryToRegisterNewGetOne(int req_id);
void TryToRegisterNewPrefetchOne(int req_id);
void ShutdownQueue();
private:
......
......@@ -31,10 +31,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace detail {
namespace {
const int kStartProfile = 1;
const int kStopProfile = 2;
} // namespace
using VarMsg = sendrecv::VariableMessage;
......@@ -128,9 +124,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request.set_profile(kStartProfile);
request.set_profile(platform::kEnableProfiler);
} else {
request.set_profile(kStopProfile);
request.set_profile(platform::kDisableProfiler);
}
}
if (!out_name.empty()) {
......
......@@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) {
if (listener_id <= 0) {
break;
}
if (profiling == 1 && !platform::IsProfileEnabled()) {
if (profiling == platform::kEnableProfiler &&
!platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (profiling == 2 && platform::IsProfileEnabled()) {
} else if (profiling == platform::kDisableProfiler &&
platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow to customize file dir.
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
......
......@@ -116,6 +116,8 @@ void ResetProfiler();
void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path);
const int kEnableProfiler = 1;
const int kDisableProfiler = 2;
// Test if the profiler is currently enabled.
bool IsProfileEnabled();
// Whether the trainer should send profiling state to PS.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册