提交 dbf9f6f4 编写于 作者: Y Yu Yang

Fix distribute compile

test=develop
上级 1d4d4e73
......@@ -4,6 +4,7 @@ paddle/operators/tensor.save
python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/
python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/
python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/
paddle/fluid/operators/distributed/send_recv.proto
*.DS_Store
*.vs
build/
......
......@@ -155,6 +155,8 @@ class Tensor {
void clear() { holder_ = nullptr; }
const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }
private:
/*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_;
......
......@@ -32,17 +32,21 @@ namespace paddle {
namespace operators {
namespace distributed {
static void SerializeDestroyCallback(void* payload) {
if (payload != nullptr) {
auto* shared_payload =
reinterpret_cast<std::shared_ptr<memory::Allocation>*>(payload);
delete shared_payload;
}
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
platform::RecordRPCEvent record_event("serial", &ctx);
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
VarMsg request;
void* payload = nullptr;
size_t payload_size;
std::shared_ptr<memory::Allocation>* payload = nullptr;
request.set_varname(name);
// Note: normally the profiler is enabled in 1 trainer, hence only
......@@ -61,10 +65,12 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
payload = new std::shared_ptr<memory::Allocation>(
GetTensorPayload(var, ctx, &request));
} else if (var->IsType<framework::SelectedRows>()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
payload = new std::shared_ptr<memory::Allocation>(
GetSelectedRowsPayload(var, ctx, &request));
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request.set_type(::sendrecv::NCCL_ID);
......@@ -74,17 +80,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
typeid(var->Type()).name());
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback = [](void* backing) {
platform::CUDAPinnedPlace cuda_pinned;
memory::Free(cuda_pinned, backing);
};
#endif
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
......@@ -108,16 +103,18 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
return;
}
#endif
PADDLE_ENFORCE_NOT_NULL(payload);
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
payload->get()->size());
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
slices[0] = ::grpc::Slice(e.size());
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
slices[1] = ::grpc::Slice(
grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
static_cast<char*>(payload)),
slices[1] = ::grpc::Slice(grpc_slice_new_with_user_data(
payload->get()->ptr(), payload->get()->size(),
SerializeDestroyCallback, payload),
::grpc::Slice::STEAL_REF);
if (var->IsType<framework::SelectedRows>()) {
......
......@@ -28,16 +28,35 @@ namespace distributed {
using VarMsg = sendrecv::VariableMessage;
static std::shared_ptr<memory::Allocation> GetCommunicationAllocationFromTensor(
const platform::DeviceContext& ctx, const framework::Tensor& tensor) {
if (is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
void* GetVarPayLoad(const std::string varname, int64_t size) {
PADDLE_ENFORCE(is_gpu_place(tensor.place()));
auto& gpu_dev_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
platform::CUDAPinnedPlace cuda_pinned;
return memory::Alloc(cuda_pinned, size);
}
#endif
auto result = memory::AllocShared(
cuda_pinned, copy_size, memory::allocation::Allocator::kCrossDevice);
memory::Copy(cuda_pinned, result->ptr(),
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
gpu_dev_ctx.stream());
void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) {
ctx.Wait();
return result;
#else
return nullptr; // THIS SHOULD NOT HAPPENED.
#endif
} else {
return tensor.Holder();
}
}
std::shared_ptr<memory::Allocation> GetTensorPayload(
framework::Variable* var, const platform::DeviceContext& ctx,
VarMsg* request) {
auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
......@@ -56,31 +75,12 @@ void GetTensorPayload(framework::Variable* var,
}
}
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
// platform::CUDAPinnedPlace cuda_pinned;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
*payload = GetVarPayLoad(request->varname(), copy_size);
platform::CUDAPinnedPlace cuda_pinned;
memory::Copy(cuda_pinned, *payload,
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
*payload = tensor.data<void>();
}
*payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
return GetCommunicationAllocationFromTensor(ctx, tensor);
}
void GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) {
std::shared_ptr<memory::Allocation> GetSelectedRowsPayload(
framework::Variable* var, const platform::DeviceContext& ctx,
VarMsg* request) {
auto* slr = var->GetMutable<framework::SelectedRows>();
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
......@@ -92,23 +92,7 @@ void GetSelectedRowsPayload(framework::Variable* var,
}
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
*payload = GetVarPayLoad(request->varname(), copy_size);
platform::CUDAPinnedPlace cuda_pinned;
memory::Copy(cuda_pinned, *payload,
boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
gpu_dev_ctx.stream());
ctx.Wait();
#endif
} else {
*payload = slr->mutable_value()->data<void>();
}
*payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
return GetCommunicationAllocationFromTensor(ctx, *tensor);
}
} // namespace distributed
......
......@@ -33,13 +33,13 @@ namespace distributed {
using VarMsg = sendrecv::VariableMessage;
void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size);
std::shared_ptr<memory::Allocation> GetTensorPayload(
framework::Variable* var, const platform::DeviceContext& ctx,
VarMsg* request);
void GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size);
std::shared_ptr<memory::Allocation> GetSelectedRowsPayload(
framework::Variable* var, const platform::DeviceContext& ctx,
VarMsg* request);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
switch (type) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册