提交 45363069 编写于 作者: Q qiaolongfei

fix prefetch hang problem, add some more logs

上级 fdecae5f
...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include <sys/time.h>
#include <limits>
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
namespace paddle { namespace paddle {
...@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
}); });
req_count_++; req_count_++;
...@@ -109,7 +111,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -109,7 +111,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_); s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
}); });
req_count_++; req_count_++;
...@@ -153,7 +155,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -153,7 +155,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req, s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_); &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s); call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
}); });
req_count_++; req_count_++;
...@@ -169,7 +171,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { ...@@ -169,7 +171,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE); req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s); rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
req_count_++; req_count_++;
} }
...@@ -181,7 +183,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { ...@@ -181,7 +183,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE); req.set_varname(FETCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s); rpc->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
req_count_++; req_count_++;
} }
......
...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include <paddle/fluid/operators/detail/send_recv.pb.h>
#include <limits>
#include <string>
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
...@@ -224,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() { ...@@ -224,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
cq_send_->Shutdown(); cq_send_->Shutdown();
cq_get_->Shutdown(); cq_get_->Shutdown();
cq_prefetch_->Shutdown();
} }
// This URL explains why shutdown is complicated: // This URL explains why shutdown is complicated:
...@@ -236,6 +239,7 @@ void AsyncGRPCServer::ShutDown() { ...@@ -236,6 +239,7 @@ void AsyncGRPCServer::ShutDown() {
void AsyncGRPCServer::TryToRegisterNewSendOne() { void AsyncGRPCServer::TryToRegisterNewSendOne() {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
...@@ -246,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { ...@@ -246,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
void AsyncGRPCServer::TryToRegisterNewGetOne() { void AsyncGRPCServer::TryToRegisterNewGetOne() {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return; return;
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
...@@ -257,6 +262,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { ...@@ -257,6 +262,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
VLOG(4) << "TryToRegisterNewPrefetchOne in"; VLOG(4) << "TryToRegisterNewPrefetchOne in";
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
return; return;
} }
RequestPrefetch* prefetch = RequestPrefetch* prefetch =
...@@ -274,18 +280,21 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, ...@@ -274,18 +280,21 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
void* tag = NULL; void* tag = NULL;
bool ok = false; bool ok = false;
while (true) { while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " while in";
if (!cq->Next(&tag, &ok)) { if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!"; LOG(INFO) << cq_name << " CompletionQueue shutdown!";
break; break;
} }
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
PADDLE_ENFORCE(tag); PADDLE_ENFORCE(tag);
// FIXME(typhoonzero): de-couple the barriers with recv_op // FIXME(typhoonzero): de-couple the barriers with recv_op
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1);
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
RequestBase* base = (RequestBase*)tag; RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference: // reference:
// https://github.com/tensorflow/tensorflow/issues/5596 // https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
......
...@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) { ...@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) {
case GrpcMethod::kGetVariable: case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable"; return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable: case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendREcvService/PrefetchVariable"; return "/sendrecv.SendRecvService/PrefetchVariable";
} }
// Shouldn't be reached. // Shouldn't be reached.
...@@ -117,5 +117,5 @@ class GrpcService final { ...@@ -117,5 +117,5 @@ class GrpcService final {
}; };
} // namespace detail } // namespace detail
} // namespace operator } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <stdint.h> #include <stdint.h>
#include <sys/stat.h>
#include <ostream> #include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -177,7 +168,8 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -177,7 +168,8 @@ class ListenAndServOp : public framework::OperatorBase {
} }
ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope); ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts; VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts
<< "(ms)";
// Reset the received sparse variables, the sum operator would not // Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next // sum the input sparse variables which rows is empty at the next
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册