diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 05b5f3977cbed2f08df73c6d8ba2fff687db3313..e9360ab4c79d23bdf9f84d0c0d407af6d39bde3e 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -38,7 +38,7 @@ def str2bool(v): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - '--batch_size', type=int, default=128, help="Batch size for training.") + '--batch_size', type=int, default=16, help="Batch size for training.") parser.add_argument( '--learning_rate', type=float, @@ -61,7 +61,7 @@ parser.add_argument( parser.add_argument( '--data_set', type=str, - default='cifar10', + default='flowers', choices=['cifar10', 'flowers'], help='Optional dataset for benchmark.') parser.add_argument( @@ -200,26 +200,30 @@ def main(): fetch_list=[avg_cost, batch_acc, batch_size]) return loss, acc, b_size - if args.profile and args.task_index == 0: - # warmup. - for batch_id, data in enumerate(train_reader()): - if batch_id > 5: break - run_step(batch_id, data) - with profiler.profiler('All', 'total', '/tmp/profile_vgg'): + if args.profile: + with profiler.profiler('All', 'total', + '/tmp/profile_vgg_%d' % args.task_index): for batch_id, data in enumerate(train_reader()): if batch_id > 5: break run_step(batch_id, data) + total_time = 0.0 + count = 0 for batch_id, data in enumerate(train_reader()): ts = time.time() loss, acc, b_size = run_step(batch_id, data) iters += 1 num_samples += len(data) train_pass_acc.add(value=acc, weight=b_size) + + duration = time.time() - ts + total_time += duration + count += len(data) print( "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, " - "Speed = %.2f img/s" % (pass_id, iters, loss, acc, - len(data) / (time.time() - ts)) + "Speed = %.2f (%.2f) img/s" % (pass_id, iters, loss, acc, + len(data) / duration, + count / total_time) ) # The accuracy is the accumulation of batches, but not the current batch. pass_elapsed = time.time() - start_time diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index ae60ab15325ef101feb7270a4f5d840cb2112be0..47892b1bcc073d24ea617ea1c680138a88925177 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -196,9 +197,14 @@ bool RPCClient::Wait() { const size_t kReqCnt = req_count_; bool a[kReqCnt]; std::vector> waits(req_count_); + std::mutex mu; for (int i = 0; i < req_count_; i++) { - waits[i] = framework::AsyncIO([i, &a, this] { a[i] = Proceed(); }); + waits[i] = framework::AsyncIO([i, &a, &mu, this] { + bool ret = Proceed(); + std::lock_guard l(mu); + a[i] = ret; + }); } for (int i = 0; i < req_count_; i++) { diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index eb114a47d99541402f748bfffcf6b10fde3e78e2..58faead2bdf9a89749e08207d964836bbf5cb68e 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -19,10 +19,16 @@ limitations under the License. */ using ::grpc::ServerAsyncResponseWriter; +DEFINE_int32(rpc_server_handle_send_threads, 20, + "Number of threads used to handle send at rpc server."); +DEFINE_int32(rpc_server_handle_get_threads, 20, + "Number of threads used to handle get at rpc server."); +DEFINE_int32(rpc_server_handle_prefetch_threads, 1, + "Number of threads used to handle prefetch at rpc server."); + namespace paddle { namespace operators { namespace detail { - enum CallStatus { PROCESS = 0, FINISH }; // reference: @@ -63,18 +69,20 @@ class RequestSend final : public RequestBase { explicit RequestSend(GrpcService::AsyncService* service, ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, ReceivedQueue* queue, - const platform::DeviceContext* dev_ctx) + const platform::DeviceContext* dev_ctx, int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), queue_(queue), - responder_(&ctx_) { + responder_(&ctx_), + req_id_(req_id) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { request_.reset(new VariableResponse(scope, dev_ctx_, true)); } int method_id = static_cast(detail::GrpcMethod::kSendVariable); - service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, - cq_, cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id))); } virtual ~RequestSend() {} @@ -86,15 +94,17 @@ class RequestSend final : public RequestBase { VLOG(3) << "RequestSend " << var_name; queue_->Push(std::make_pair(var_name, request_)); - sendrecv::VoidMessage reply; - responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(req_id_))); } protected: + sendrecv::VoidMessage reply_; std::shared_ptr request_; ReceivedQueue* queue_; ServerAsyncResponseWriter responder_; + int req_id_; }; class RequestGet final : public RequestBase { @@ -103,14 +113,17 @@ class RequestGet final : public RequestBase { ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, - framework::BlockingQueue* queue) + framework::BlockingQueue* queue, + int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), - queue_(queue) { + queue_(queue), + req_id_(req_id) { auto method_id = static_cast(detail::GrpcMethod::kGetVariable); - service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, - cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, &request_, &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id_))); } virtual ~RequestGet() {} @@ -123,13 +136,13 @@ class RequestGet final : public RequestBase { VLOG(3) << "RequestGet " << var_name; auto* var = scope_->FindVar(var_name); - ::grpc::ByteBuffer reply; if (var_name != FETCH_BARRIER_MESSAGE) { - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); } - responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(req_id_))); if (var_name == FETCH_BARRIER_MESSAGE) { sendrecv::VariableMessage msg; @@ -140,9 +153,11 @@ class RequestGet final : public RequestBase { protected: sendrecv::VariableMessage request_; + ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::BlockingQueue* queue_; + int req_id_; }; class RequestPrefetch final : public RequestBase { @@ -153,21 +168,24 @@ class RequestPrefetch final : public RequestBase { const platform::DeviceContext* dev_ctx, framework::Executor* executor, framework::ProgramDesc* program, - framework::ExecutorPrepareContext* prefetch_ctx) + framework::ExecutorPrepareContext* prefetch_ctx, + int req_id) : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), - prefetch_ctx_(prefetch_ctx) { + prefetch_ctx_(prefetch_ctx), + req_id_(req_id) { if (sync_mode_) { request_.reset(new VariableResponse(scope, dev_ctx_, false)); } else { request_.reset(new VariableResponse(scope, dev_ctx_, true)); } int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); - service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, - cq_, cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id_))); } virtual ~RequestPrefetch() {} @@ -176,7 +194,6 @@ class RequestPrefetch final : public RequestBase { virtual void Process() { // prefetch process... - ::grpc::ByteBuffer reply; std::string var_name = request_->OutVarname(); VLOG(3) << "RequestPrefetch " << var_name; @@ -186,19 +203,22 @@ class RequestPrefetch final : public RequestBase { InitializeVariable(var, var_desc->GetType()); executor_->RunPreparedContext(prefetch_ctx_, scope_); - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); - responder_.Finish(reply, ::grpc::Status::OK, this); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(req_id_))); } protected: std::shared_ptr request_; + ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::Executor* executor_; framework::ProgramDesc* program_; framework::ExecutorPrepareContext* prefetch_ctx_; + int req_id_; }; void AsyncGRPCServer::WaitClientGet(int count) { @@ -232,24 +252,39 @@ void AsyncGRPCServer::RunSyncUpdate() { LOG(INFO) << "Server listening on " << address_ << " selected port: " << selected_port_; - std::function send_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); - std::function get_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); - std::function prefetch_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this); - - // TODO(wuyi): Run these "HandleRequest" in thread pool - t_send_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, - cq_send_.get(), "cq_send", send_register))); - t_get_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, - cq_get_.get(), "cq_get", get_register))); - t_prefetch_.reset(new std::thread( - std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), - "cq_prefetch", prefetch_register))); + std::function send_register = std::bind( + &AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1); + std::function get_register = std::bind( + &AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1); + std::function prefetch_register = + std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this, + std::placeholders::_1); + for (int i = 0; i < kSendReqsBufSize; ++i) { + TryToRegisterNewSendOne(i); + } + for (int i = 0; i < kGetReqsBufSize; ++i) { + TryToRegisterNewGetOne(i); + } + for (int i = 0; i < kPrefetchReqsBufSize; ++i) { + TryToRegisterNewPrefetchOne(i); + } + + for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { + t_sends_.emplace_back( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, + cq_send_.get(), "cq_send", send_register))); + } + for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { + t_gets_.emplace_back( + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, + cq_get_.get(), "cq_get", get_register))); + } + for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { + t_prefetchs_.emplace_back(new std::thread( + std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), + "cq_prefetch", prefetch_register))); + } { std::lock_guard lock(this->mutex_ready_); ready_ = 1; @@ -257,9 +292,15 @@ void AsyncGRPCServer::RunSyncUpdate() { condition_ready_.notify_all(); // wait server server_->Wait(); - t_send_->join(); - t_get_->join(); - t_prefetch_->join(); + for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { + t_sends_[i]->join(); + } + for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { + t_gets_[i]->join(); + } + for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { + t_prefetchs_[i]->join(); + } } void AsyncGRPCServer::ShutdownQueue() { @@ -276,47 +317,48 @@ void AsyncGRPCServer::ShutDown() { server_->Shutdown(); } -void AsyncGRPCServer::TryToRegisterNewSendOne() { +void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; return; } RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_, - scope_, &var_recv_queue_, dev_ctx_); + scope_, &var_recv_queue_, dev_ctx_, i); + send_reqs_[i] = static_cast(send); VLOG(4) << "Create RequestSend status:" << send->Status(); } -void AsyncGRPCServer::TryToRegisterNewGetOne() { +void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; return; } RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, - dev_ctx_, &var_get_queue_); + dev_ctx_, &var_get_queue_, req_id); + get_reqs_[req_id] = static_cast(get); VLOG(4) << "Create RequestGet status:" << get->Status(); } -void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { +void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { std::unique_lock lock(cq_mutex_); if (is_shut_down_) { VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne"; return; } - RequestPrefetch* prefetch = - new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_, - dev_ctx_, executor_, program_, prefetch_ctx_.get()); + RequestPrefetch* prefetch = new RequestPrefetch( + &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, + program_, prefetch_ctx_.get(), req_id); + prefetch_reqs_[req_id] = static_cast(prefetch); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } // FIXME(typhoonzero): change cq_name to enum. -void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, - const std::string& cq_name, - std::function TryToRegisterNewOne) { - TryToRegisterNewOne(); - +void AsyncGRPCServer::HandleRequest( + ::grpc::ServerCompletionQueue* cq, const std::string& cq_name, + std::function TryToRegisterNewOne) { void* tag = NULL; bool ok = false; @@ -327,8 +369,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, break; } VLOG(3) << "HandleRequest for " << cq_name << " get Next"; - - PADDLE_ENFORCE(tag); + int req_id = static_cast(reinterpret_cast(tag)); if (sync_mode_) { // FIXME(typhoonzero): de-couple the barriers with recv_op @@ -337,7 +378,17 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond"; } - RequestBase* base = reinterpret_cast(tag); + RequestBase* base = nullptr; + { + std::lock_guard l(cq_mutex_); + if (cq_name == "cq_get") { + base = get_reqs_[req_id]; + } else if (cq_name == "cq_send") { + base = send_reqs_[req_id]; + } else if (cq_name == "cq_prefetch") { + base = prefetch_reqs_[req_id]; + } + } // reference: // https://github.com/tensorflow/tensorflow/issues/5596 // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM @@ -345,19 +396,19 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, if (!ok) { LOG(WARNING) << cq_name << " recv no regular event:argument name[" << base->GetReqName() << "]"; - TryToRegisterNewOne(); + TryToRegisterNewOne(req_id); delete base; continue; } switch (base->Status()) { case PROCESS: { - TryToRegisterNewOne(); base->Process(); VLOG(4) << cq_name << " PROCESS status:" << base->Status(); break; } case FINISH: { + TryToRegisterNewOne(req_id); VLOG(4) << cq_name << " FINISH status:" << base->Status(); delete base; break; diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 238aaa29634a7eff65429c27aa3538a185723eb2..bdff9801a928699f8391bfb68c1c7bd2d75aa642 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include // NOLINT #include +#include #include "grpc++/grpc++.h" #include "paddle/fluid/framework/blocking_queue.h" @@ -30,6 +31,7 @@ limitations under the License. */ #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -82,19 +84,27 @@ class AsyncGRPCServer final { protected: void HandleRequest(::grpc::ServerCompletionQueue *cq, const std::string &cq_name, - std::function TryToRegisterNewOne); - void TryToRegisterNewSendOne(); - void TryToRegisterNewGetOne(); - void TryToRegisterNewPrefetchOne(); + std::function TryToRegisterNewOne); + void TryToRegisterNewSendOne(int req_id); + void TryToRegisterNewGetOne(int req_id); + void TryToRegisterNewPrefetchOne(int req_id); void ShutdownQueue(); private: + static const int kSendReqsBufSize = 100; + static const int kGetReqsBufSize = 100; + static const int kPrefetchReqsBufSize = 10; + std::mutex cq_mutex_; volatile bool is_shut_down_ = false; std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_; std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_; + RequestBase *send_reqs_[kSendReqsBufSize]; + RequestBase *get_reqs_[kGetReqsBufSize]; + RequestBase *prefetch_reqs_[kPrefetchReqsBufSize]; + GrpcService::AsyncService service_; std::unique_ptr<::grpc::Server> server_; @@ -113,8 +123,10 @@ class AsyncGRPCServer final { mutable int barrier_cond_step_; std::condition_variable barrier_condition_; - std::unique_ptr t_send_; - std::unique_ptr t_get_; + std::vector> t_sends_; + std::vector> t_gets_; + std::vector> t_prefetchs_; + std::unique_ptr t_prefetch_; std::unique_ptr prefetch_ctx_; diff --git a/paddle/fluid/operators/detail/grpc_service.h b/paddle/fluid/operators/detail/grpc_service.h index e6dab2f5a3a4280f3979417c3ca2d884a0b8ff2f..e0505c2b9d0903837713d7e0032b01ab091c2e04 100644 --- a/paddle/fluid/operators/detail/grpc_service.h +++ b/paddle/fluid/operators/detail/grpc_service.h @@ -25,6 +25,8 @@ #include #include "paddle/fluid/operators/detail/variable_response.h" +#include "paddle/fluid/platform/profiler.h" + // NOTE: This method was originally created by tensorflow // (https://github.com/tensorflow/tensorflow/) we borrow this // method and did some modifications so that we can parse gRPC diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 9478c5702bcbf99fc88207b8c4843dbccf8a5925..a244afc46f3247c7e6e8481b09b5c729a2a569f7 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -70,10 +70,10 @@ message VariableMessage { bytes rows = 9; // Look up table block execution output variable name. string out_varname = 10; - // If true, the ps server will start profiling, the ps + // If 1, the ps server will start profiling, the ps // server stops profiling and generates a profile to /tmp/profile_ps_* - // when profile switches from true to false. - bool profile = 11; + // when profile switches from 1 to 2. + int64 profile = 11; } message VoidMessage {} diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index e6ee598db04dd9e0075b39a50d1d4e878d73086d..3bae56532d655a1725e18276e09e0cade47b5c68 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -123,7 +123,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, // 1 trainer returns true for ShouldSendProfileState(). It tells PS // servers the trainer's profiling state so that PS can follow the // trainer. - request.set_profile(platform::IsProfileEnabled()); + if (platform::ShouldSendProfileState()) { + if (platform::IsProfileEnabled()) { + request.set_profile(platform::kEnableProfiler); + } else { + request.set_profile(platform::kDisableProfiler); + } + } if (!out_name.empty()) { request.set_out_varname(out_name); } diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 462e303096e609c6797ca8cc16266ec3621623fc..24cb91a3bb820a0e5d51aaa49154434919080f69 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -449,8 +449,8 @@ int VariableResponse::Parse(Source* source) { break; } case sendrecv::VariableMessage::kProfileFieldNumber: { - bool profiling; - if (!input.ReadRaw(reinterpret_cast(&profiling), 1)) { + uint64_t profiling = 0; + if (!input.ReadVarint64(&profiling)) { return tag; } meta_.set_profile(profiling); @@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) { if (listener_id <= 0) { break; } - if (profiling && !platform::IsProfileEnabled()) { + if (profiling == platform::kEnableProfiler && + !platform::IsProfileEnabled()) { platform::EnableProfiler(platform::ProfilerState::kCPU); - } else if (!profiling && platform::IsProfileEnabled()) { + } else if (profiling == platform::kDisableProfiler && + platform::IsProfileEnabled()) { // TODO(panyx0718): Should we allow to customize file dir. platform::DisableProfiler( platform::EventSortingKey::kDefault, diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index c9e10631680a6ea3876f555a3a6e6c12f79b39d5..1a9be044e024e4b1dda5ef7d515c65f3a7513710 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -245,7 +245,6 @@ class DeviceTracerImpl : public DeviceTracer { void Enable() { std::lock_guard l(trace_mu_); if (enabled_) { - fprintf(stderr, "DeviceTracer already enabled\n"); return; } EnableActivity(); diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index 643bb6183d144ec11a4890d9ea1ca970acb08b4c..bf43925373a12cd9ff2155d68c42d0266ba4df60 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -116,6 +116,8 @@ void ResetProfiler(); void DisableProfiler(EventSortingKey sorted_key, const std::string& profile_path); +const int kEnableProfiler = 1; +const int kDisableProfiler = 2; // Test if the profiler is currently enabled. bool IsProfileEnabled(); // Whether the trainer should send profiling state to PS.