diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 661dfa69fe1580ff3890f12defcd124225be0c06..ae60ab15325ef101feb7270a4f5d840cb2112be0 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, // stub context SendProcessor* s = new SendProcessor(ch); s->Prepare(var_h, time_out); - s->response_call_back_ = NULL; + s->response_call_back_ = nullptr; auto call = s->stub_g_.PrepareUnaryCall( s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index f6229b71bc01a6de51f50f5fe880ada6e15e74dd..dabce7414d2f0dca74193f1cd10c341793c10ec9 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); class BaseProcessor { public: - explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; } + explicit BaseProcessor(std::shared_ptr ch) { + context_ = nullptr; + } virtual ~BaseProcessor() {} @@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor { ::grpc::GenericStub stub_g_; ::grpc::ByteBuffer reply_; - RequestSendCallBack response_call_back_ = NULL; + RequestSendCallBack response_call_back_ = nullptr; }; typedef std::function diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 7ca694886e9209a49e214352f5babc473a1f275a..1cdfe01170cf9a39f2dea85696642058d9cc81f0 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -261,8 +261,8 @@ void AsyncGRPCServer::ShutdownQueue() { // This URL explains why shutdown is complicate: void AsyncGRPCServer::ShutDown() { is_shut_down_ = true; - ShutdownQueue(); server_->Shutdown(); + 
ShutdownQueue(); } void AsyncGRPCServer::TryToRegisterNewSendOne() { diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 99b87b8c6cb3e597778b88c395e4abf400d82c39..0e1592eed41dc4f4a8666695ba32b26b9f6f7b15 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -47,6 +47,8 @@ class AsyncGRPCServer final { explicit AsyncGRPCServer(const std::string &address, bool sync_mode) : address_(address), sync_mode_(sync_mode) {} + ~AsyncGRPCServer() {} + void RunSyncUpdate(); // functions to sync server barrier status. diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index d754630fd78f90423ed58fbf1791432873f0f7ef..207ea3cb8b34e4c99b0bb148d4b2f29cd5969b98 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -53,109 +53,106 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, e.WriteUint64(VarMsg::kTypeFieldNumber, 1); } else if (var->IsType()) { // NOTE: sendrecv only support RAW type for NCCL_ID + VLOG(3) << "serilizing: setting var type nccl id"; e.WriteUint64(VarMsg::kTypeFieldNumber, 2); } if (!out_name.empty()) { e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name); } - switch (framework::ToVarType(var->Type())) { - case framework::proto::VarType_Type_LOD_TENSOR: { - auto tensor = var->Get(); - e.WriteUint64(VarMsg::kDataTypeFieldNumber, - framework::ToDataType(tensor.type())); - for (auto& dim : framework::vectorize(tensor.dims())) { - e.WriteUint64(VarMsg::kDimsFieldNumber, dim); - } - auto lod = tensor.lod(); // std::vector> - if (lod.size() > 0) { - e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size()); - - for (auto& each : lod) { - e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber, - 2 + // tag + varintlength of submessage - 1 + // kLodDataFieldNumber - each.size()); - // auto copied from GPU - 
for (auto& d : each) { - e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d); - } + if (var->IsType()) { + // ===========================Tensor================================== + auto tensor = var->Get(); + e.WriteUint64(VarMsg::kDataTypeFieldNumber, + framework::ToDataType(tensor.type())); + for (auto& dim : framework::vectorize(tensor.dims())) { + e.WriteUint64(VarMsg::kDimsFieldNumber, dim); + } + auto lod = tensor.lod(); // std::vector> + if (lod.size() > 0) { + e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size()); + + for (auto& each : lod) { + e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber, + 2 + // tag + varintlength of submessage + 1 + // kLodDataFieldNumber + each.size()); + // auto copied from GPU + for (auto& d : each) { + e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d); } } - if (platform::is_gpu_place(ctx.GetPlace())) { + } + if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE(platform::is_gpu_place(tensor.place())); + PADDLE_ENFORCE(platform::is_gpu_place(tensor.place())); + platform::CPUPlace cpu; + auto& gpu_dev_ctx = static_cast(ctx); + auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); + payload = memory::Alloc(cpu, copy_size); + + memory::Copy(cpu, payload, + boost::get(tensor.place()), + reinterpret_cast(tensor.data()), + copy_size, gpu_dev_ctx.stream()); + ctx.Wait(); + destroy_callback = [](void* backing) { platform::CPUPlace cpu; - auto& gpu_dev_ctx = - static_cast(ctx); - auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); - payload = memory::Alloc(cpu, copy_size); - - memory::Copy(cpu, payload, - boost::get(tensor.place()), - reinterpret_cast(tensor.data()), - copy_size, gpu_dev_ctx.stream()); - ctx.Wait(); - destroy_callback = [](void* backing) { - platform::CPUPlace cpu; - memory::Free(cpu, backing); - }; + memory::Free(cpu, backing); + }; #endif - } else { - payload = tensor.data(); - } - payload_size = tensor.numel() * 
framework::SizeOfType(tensor.type()); - e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); - } break; - case framework::proto::VarType_Type_SELECTED_ROWS: { - // TODO(typhoonzero): selectedrows implement should not use unique_ptr - auto* slr = var->GetMutable(); - e.WriteUint64(VarMsg::kDataTypeFieldNumber, - framework::ToDataType(slr->value().type())); - for (auto& dim : framework::vectorize(slr->value().dims())) { - e.WriteUint64(VarMsg::kDimsFieldNumber, dim); - } - e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0); - e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height()); - auto* tensor = slr->mutable_value(); - if (platform::is_gpu_place(ctx.GetPlace())) { + } else { + payload = tensor.data(); + } + payload_size = tensor.numel() * framework::SizeOfType(tensor.type()); + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); + } else if (var->IsType()) { + // ===========================SELECTED + // ROWS================================== + // TODO(typhoonzero): selectedrows implement should not use unique_ptr + auto* slr = var->GetMutable(); + e.WriteUint64(VarMsg::kDataTypeFieldNumber, + framework::ToDataType(slr->value().type())); + for (auto& dim : framework::vectorize(slr->value().dims())) { + e.WriteUint64(VarMsg::kDimsFieldNumber, dim); + } + e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0); + e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height()); + auto* tensor = slr->mutable_value(); + if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef PADDLE_WITH_CUDA + platform::CPUPlace cpu; + auto& gpu_dev_ctx = static_cast(ctx); + auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type()); + payload = memory::Alloc(cpu, copy_size); + memory::Copy(cpu, payload, + boost::get(tensor->place()), + reinterpret_cast(tensor->data()), + copy_size, gpu_dev_ctx.stream()); + ctx.Wait(); + destroy_callback = [](void* backing) { platform::CPUPlace cpu; - auto& gpu_dev_ctx = - static_cast(ctx); - auto copy_size = - 
tensor->numel() * framework::SizeOfType(tensor->type()); - payload = memory::Alloc(cpu, copy_size); - memory::Copy(cpu, payload, - boost::get(tensor->place()), - reinterpret_cast(tensor->data()), - copy_size, gpu_dev_ctx.stream()); - ctx.Wait(); - destroy_callback = [](void* backing) { - platform::CPUPlace cpu; - memory::Free(cpu, backing); - }; + memory::Free(cpu, backing); + }; #endif - } else { - payload = slr->mutable_value()->data(); - } - payload_size = tensor->numel() * framework::SizeOfType(tensor->type()); - e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); - } break; - case framework::proto::VarType_Type_RAW: { - e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, - NCCL_UNIQUE_ID_BYTES); - ncclUniqueId* uid = var->GetMutable(); - e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES)); - } break; - default: - PADDLE_THROW("Serialize does not support type: %s", - typeid(var->Type()).name()); - break; + } else { + payload = slr->mutable_value()->data(); + } + payload_size = tensor->numel() * framework::SizeOfType(tensor->type()); + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); + } else if (var->IsType()) { + // ===========================NCCL ID================================== + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, + NCCL_UNIQUE_ID_BYTES); + ncclUniqueId* uid = var->GetMutable(); + e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES)); + } else { + PADDLE_THROW("Serialize does not support type: %s", + typeid(var->Type()).name()); } - if (framework::ToVarType(var->Type()) == framework::proto::VarType_Type_RAW) { + if (var->IsType()) { // for serialize NCCL_ID ::grpc::Slice slices(e.size()); memcpy(const_cast(slices.begin()), e.data(), e.size()); diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 81d755f5fcfbf501a65209a320ce0a0097ff659c..64fd84736dc9c38256b739e96474367628aa8b19 100644 --- 
a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -371,19 +371,26 @@ int VariableResponse::Parse(Source* source) { meta_.type() == sendrecv::NCCL_ID) && meta_.varname() != "", "meta info should be got first!"); + int length = 0; + if (wt != WIRETYPE_LENGTH_DELIMITED || + !ReadVarintSizeAsInt(&input, &length)) { + return tag; + } + if (meta_.type() == sendrecv::NCCL_ID) { + VLOG(3) << "parse nccl id request"; auto* var = scope_->FindVar(meta_.varname()); if (var != nullptr) { + VLOG(3) << "parse nccl id: length " << length; ncclUniqueId* id = var->GetMutable(); - memcpy(id->internal, meta_.serialized().c_str(), - meta_.serialized().size()); + if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal, + length)) { + return tag; + } + // memcpy(id->internal, meta_.serialized().c_str(), + // meta_.serialized().size()); } - } - - int length = 0; - if (wt != WIRETYPE_LENGTH_DELIMITED || - !ReadVarintSizeAsInt(&input, &length)) { - return tag; + break; } framework::DDim dims = GetDims(meta_.dims()); diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc index afb228fa6ffa3a7230d878913239650eb6574fbc..8d28be35a874c3312db2c8a8172dbfed6f0168e5 100644 --- a/paddle/fluid/operators/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -37,7 +37,8 @@ class GenNCCLIdOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& dev_ctx = *pool.Get(dev_place); + // put nccl id in CPUPlace + auto& dev_ctx = *pool.Get(platform::CPUPlace()); int trainer_id = Attr("trainer_id"); framework::Scope& local_scope = scope.NewScope(); @@ -60,9 +61,11 @@ class GenNCCLIdOp : public framework::OperatorBase { Attr>("endpoint_list"); detail::RPCClient client; for (auto& ep : endpoint_list) { + VLOG(3) << "sending nccl 
id to " << ep; client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID"); } client.Wait(); + VLOG(3) << "sending completed..."; } void GetIdByServer(framework::Scope* scope, @@ -78,9 +81,14 @@ class GenNCCLIdOp : public framework::OperatorBase { server_thread_.reset(new std::thread(std::bind( &detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_.get()))); - + rpc_service_->SetCond(0); + VLOG(3) << "start getting nccl id from trainer 0..."; auto recv = rpc_service_->Get(); - rpc_service_->ShutDown(); + VLOG(3) << "got nccl id and stop server..."; + // rpc_service_->SetCond(1); + // rpc_service_->ShutDown(); + rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); + VLOG(3) << "rpc server stopped"; // TODO(wuyi): reinit nccl communicators }