提交 0f86397d 编写于 作者: T typhoonzero

fix build

上级 17009d06
......@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
......@@ -51,10 +53,12 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
} else if (var->IsType<framework::SelectedRows>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
// NOTE: sendrecv only support RAW type for NCCL_ID
VLOG(3) << "serilizing: setting var type nccl id";
e.WriteUint64(VarMsg::kTypeFieldNumber, 2);
#endif
}
if (!out_name.empty()) {
......@@ -141,17 +145,19 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
// ===========================NCCL ID==================================
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
ncclUniqueId* uid = var->GetMutable<ncclUniqueId>();
e.WriteRawBytes(std::string(uid->internal, NCCL_UNIQUE_ID_BYTES));
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
#ifdef PADDLE_WITH_CUDA
if (var->IsType<ncclUniqueId>()) {
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
......@@ -160,7 +166,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
msg->Swap(&tmp);
return;
}
#endif
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
......
......@@ -17,6 +17,9 @@
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
......@@ -378,19 +381,19 @@ int VariableResponse::Parse(Source* source) {
}
if (meta_.type() == sendrecv::NCCL_ID) {
VLOG(3) << "parse nccl id request";
#ifdef PADDLE_WITH_CUDA
auto* var = scope_->FindVar(meta_.varname());
if (var != nullptr) {
VLOG(3) << "parse nccl id: length " << length;
ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
length)) {
return tag;
}
// memcpy(id->internal, meta_.serialized().c_str(),
// meta_.serialized().size());
}
break;
#else
PADDLE_THROW("Not compiled with CUDA!");
#endif
}
framework::DDim dims = GetDims(meta_.dims());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册