提交 0f86397d 编写于 作者: T typhoonzero

fix build: guard the NCCL include and ncclUniqueId code paths with PADDLE_WITH_CUDA

上级 17009d06
...@@ -14,7 +14,9 @@ limitations under the License. */ ...@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA
#include <nccl.h> #include <nccl.h>
#endif
#include <sys/time.h> #include <sys/time.h>
#include <thread> // NOLINT #include <thread> // NOLINT
...@@ -51,10 +53,12 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -51,10 +53,12 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 0); e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 1); e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) { } else if (var->IsType<ncclUniqueId>()) {
// NOTE: sendrecv only support RAW type for NCCL_ID // NOTE: sendrecv only support RAW type for NCCL_ID
VLOG(3) << "serilizing: setting var type nccl id"; VLOG(3) << "serilizing: setting var type nccl id";
e.WriteUint64(VarMsg::kTypeFieldNumber, 2); e.WriteUint64(VarMsg::kTypeFieldNumber, 2);
#endif
} }
if (!out_name.empty()) { if (!out_name.empty()) {
...@@ -141,17 +145,19 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -141,17 +145,19 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} }
payload_size = tensor->numel() * framework::SizeOfType(tensor->type()); payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) { } else if (var->IsType<ncclUniqueId>()) {
// ===========================NCCL ID================================== // ===========================NCCL ID==================================
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES); NCCL_UNIQUE_ID_BYTES);
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>(); ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES)); e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
#endif
} else { } else {
PADDLE_THROW("Serialize does not support type: %s", PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name()); typeid(var->Type()).name());
} }
#ifdef PADDLE_WITH_CUDA
if (var->IsType<ncclUniqueId>()) { if (var->IsType<ncclUniqueId>()) {
// for serialize NCCL_ID // for serialize NCCL_ID
::grpc::Slice slices(e.size()); ::grpc::Slice slices(e.size());
...@@ -160,7 +166,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -160,7 +166,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
msg->Swap(&tmp); msg->Swap(&tmp);
return; return;
} }
#endif
// steal reference of tensor data // steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer int num_slices = 2; // only SelectedRows have rows buffer
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...@@ -378,19 +381,19 @@ int VariableResponse::Parse(Source* source) { ...@@ -378,19 +381,19 @@ int VariableResponse::Parse(Source* source) {
} }
if (meta_.type() == sendrecv::NCCL_ID) { if (meta_.type() == sendrecv::NCCL_ID) {
VLOG(3) << "parse nccl id request"; #ifdef PADDLE_WITH_CUDA
auto* var = scope_->FindVar(meta_.varname()); auto* var = scope_->FindVar(meta_.varname());
if (var != nullptr) { if (var != nullptr) {
VLOG(3) << "parse nccl id: length " << length;
ncclUniqueId* id = var->GetMutable<ncclUniqueId>(); ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal, if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
length)) { length)) {
return tag; return tag;
} }
// memcpy(id->internal, meta_.serialized().c_str(),
// meta_.serialized().size());
} }
break; break;
#else
PADDLE_THROW("Not compiled with CUDA!");
#endif
} }
framework::DDim dims = GetDims(meta_.dims()); framework::DDim dims = GetDims(meta_.dims());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册