diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 9b465b85b0a02ffe990ab669a22f78e923e24f99..8b7533ce712b0a01060842b6f71449ed6bd23e2c 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -187,7 +187,6 @@ bool TensorContainsInf(const framework::Tensor& tensor) { void TensorToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx) { - // TODO(typhoonzero): serialize to ostream { // the 1st field, uint32_t version constexpr uint32_t version = 0; os.write(reinterpret_cast(&version), sizeof(version)); diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 0581bd2ac55218a2955fcb260d8b61cac0d210b5..94395ccfbcbd74ee40552a5c70dc8b8063a5f851 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,3 +1,6 @@ if(WITH_DISTRIBUTE) - grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc) endif() diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.cc b/paddle/fluid/operators/detail/bytebuffer_stream.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9488156e073e515926240c9bb66d7b6edf8f82e --- /dev/null +++ b/paddle/fluid/operators/detail/bytebuffer_stream.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// NOTE: This file was originally created by tensorflow +// (https://github.com/tensorflow/tensorflow/) we borrow this +// file and did some modifications so that we can send gRPC +// requests without too much copying of the tensor data. + +#include "bytebuffer_stream.h" + +namespace paddle { +namespace operators { +namespace detail { + +GrpcByteBufferSource::GrpcByteBufferSource() {} + +bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) { + cur_ = -1; + left_ = 0; + ptr_ = nullptr; + byte_count_ = 0; + bool ok = src.Dump(&slices_).ok(); + if (!ok) { + slices_.clear(); + } + return ok; +} + +bool GrpcByteBufferSource::Next(const void** data, int* size) { + // Use loop instead of if in case buffer contained empty slices. + while (left_ == 0) { + // Advance to next slice. + cur_++; + if (cur_ >= slices_.size()) { + return false; + } + const ::grpc::Slice& s = slices_[cur_]; + left_ = s.size(); + ptr_ = reinterpret_cast(s.begin()); + } + + *data = ptr_; + *size = left_; + byte_count_ += left_; + ptr_ += left_; + left_ = 0; + return true; +} + +void GrpcByteBufferSource::BackUp(int count) { + ptr_ -= count; + left_ += count; + byte_count_ -= count; +} + +bool GrpcByteBufferSource::Skip(int count) { + const void* data; + int size; + while (Next(&data, &size)) { + if (size >= count) { + BackUp(size - count); + return true; + } + // size < count; + count -= size; + } + // error or we have too large count; + return false; +} + +google::protobuf::int64 GrpcByteBufferSource::ByteCount() const { + return byte_count_; +} + +} // namespace detail +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.h b/paddle/fluid/operators/detail/bytebuffer_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..099deb12d0e436427c147ab9b1eb553b712e14fb --- /dev/null +++ b/paddle/fluid/operators/detail/bytebuffer_stream.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// NOTE: This file was originally created by tensorflow +// (https://github.com/tensorflow/tensorflow/) we borrow this +// file and did some modifications so that we can send gRPC +// requests without too much copying of the tensor data. + +#pragma once + +#include +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" + +namespace paddle { +namespace operators { +namespace detail { + +// A ZeroCopyInputStream that reads from a grpc::ByteBuffer. +class GrpcByteBufferSource + : public ::google::protobuf::io::ZeroCopyInputStream { + public: + GrpcByteBufferSource(); + bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times. + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + ::google::protobuf::int64 ByteCount() const override; + + private: + std::vector<::grpc::Slice> slices_; + size_t cur_; // Current slice index. + int left_; // Number of bytes in slices_[cur_] left to yield. + const char* ptr_; // Address of next byte in slices_[cur_] to yield. + ::google::protobuf::int64 byte_count_; +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/proto_encoder_helper.h b/paddle/fluid/operators/detail/proto_encoder_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..4a7bfb8bd586fe84c9243bc64117d146c4386674 --- /dev/null +++ b/paddle/fluid/operators/detail/proto_encoder_helper.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// NOTE: This file was originally created by tensorflow +// (https://github.com/tensorflow/tensorflow/) we borrow this +// file and did some modifications so that we can send gRPC +// requests without too much copying of the tensor data. + +#pragma once + +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace detail { + +char* EncodeVarint32(char* dst, uint32_t v) { + // Operate on characters as unsigneds + unsigned char* ptr = reinterpret_cast(dst); + static const int B = 128; + if (v < (1 << 7)) { + *(ptr++) = v; + } else if (v < (1 << 14)) { + *(ptr++) = v | B; + *(ptr++) = v >> 7; + } else if (v < (1 << 21)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = v >> 14; + } else if (v < (1 << 28)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = v >> 21; + } else { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = (v >> 21) | B; + *(ptr++) = v >> 28; + } + return reinterpret_cast(ptr); +} + +char* EncodeVarint64(char* dst, uint64_t v) { + static const int B = 128; + unsigned char* ptr = reinterpret_cast(dst); + while (v >= B) { + *(ptr++) = (v & (B - 1)) | B; + v >>= 7; + } + *(ptr++) = static_cast(v); + return reinterpret_cast(ptr); +} + +int VarintLength(uint64_t v) { + int len = 1; + while (v >= 128) { + v >>= 7; + len++; + } + return len; +} + +class ProtoEncodeHelper { + public: + ProtoEncodeHelper(char* buf, int max_size) + : base_(buf), p_(buf), limit_(base_ + max_size) {} + + ~ProtoEncodeHelper() { + // Make sure callers didn't do operations that went over max_size promised + PADDLE_ENFORCE_LE(p_, limit_); + } + + const char* data() const { return base_; } + size_t size() const { return p_ - base_; } + + void WriteUint64(int tag, uint64_t v) { + Encode32(combine(tag, WIRETYPE_VARINT)); + Encode64(v); + } + void WriteBool(int tag, bool v) { + Encode32(combine(tag, WIRETYPE_VARINT)); + EncodeBool(v); + } + void WriteString(int tag, const std::string& v) { + Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED)); + Encode32(v.size()); + EncodeBytes(v.data(), v.size()); + } + void WriteVarlengthBeginning(int tag, uint32_t len) { + Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED)); + Encode32(len); + } + void WriteRawBytes(const std::string& v) { EncodeBytes(v.data(), v.size()); } + + private: + // Note: this module's behavior must match the protocol buffer wire encoding + // format. + enum { + WIRETYPE_VARINT = 0, + WIRETYPE_LENGTH_DELIMITED = 2, + }; + static uint32_t combine(uint32_t tag, uint32_t type) { + return ((tag << 3) | type); + } + inline void Encode32(uint32_t v) { + if (v < 128) { + // Fast path for single-byte values. Many of the calls will use a + // constant value for v, so the comparison will get optimized away + // when Encode32 is inlined into the caller. + *p_ = v; + p_++; + } else { + p_ = EncodeVarint32(p_, v); + } + } + void Encode64(uint64_t v) { p_ = EncodeVarint64(p_, v); } + void EncodeBool(bool v) { + *p_ = (v ? 1 : 0); // Equal to varint32 encoding of 0 or 1 + p_++; + } + void EncodeBytes(const char* bytes, int N) { + memcpy(p_, bytes, N); + p_ += N; + } + + char* base_; + char* p_; + char* limit_; // Just for CHECKs +}; + +} // detail +} // operators +} // paddle diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 8f962b4c69cc83dc2ab98b7dc27e18bc4b42bf18..b0215d4a80c9440f09c35434903fd6166b03e8b0 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -33,10 +33,34 @@ enum VarType { } message VariableMessage { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + } + + message LodData { repeated int64 lod_data = 1; } + string varname = 1; // TODO(Yancey1989): reference framework::proto::VarDesc::VarType VarType type = 2; - bytes serialized = 3; + // bool persistable is not needed for sending. + // tensor info: + Type data_type = 3; + repeated int64 dims = 4; + + // lod details: + int64 lod_level = 5; + repeated LodData lod = 6; + // tensor data + bytes serialized = 7; + // selected_rows data + bytes rows = 8; } message VoidMessage {} diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 169fd40fd950a74e61a4ed06a370f25b533957db..64d181f4083dfcd43a59cad1cca21ec63df4d85f 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/operators/detail/bytebuffer_stream.h" +#include "paddle/fluid/operators/detail/proto_encoder_helper.h" namespace paddle { namespace operators { @@ -63,6 +68,242 @@ void DeserializeFromMessage(const sendrecv::VariableMessage& msg, } } +void SerializeToByteBuffer(const std::string& name, framework::Variable* var, + const platform::DeviceContext& ctx, + ::grpc::ByteBuffer* msg) { + using VarMsg = sendrecv::VariableMessage; + sendrecv::VariableMessage request; + std::string header; + request.AppendToString(&header); + // When using GPU, need to free the copied CPU buffer + // when the ByteBuffer destroies + // TODO(typhoonzero): add unref here, if we have dependent + // parallelism execution, need to know when to free the tensor. + DestroyCallback destroy_callback = [](void* backing) {}; + + void* buf = malloc(1024); + void* payload; + size_t payload_size; + ProtoEncodeHelper e((char*)buf, 1024); + e.WriteString(VarMsg::kVarnameFieldNumber, name); + if (var->IsType()) { + e.WriteUint64(VarMsg::kTypeFieldNumber, 0); + } else if (var->IsType()) { + e.WriteUint64(VarMsg::kTypeFieldNumber, 1); + } + + switch (framework::ToVarType(var->Type())) { + case framework::proto::VarType_Type_LOD_TENSOR: { + auto tensor = var->Get(); + e.WriteUint64(VarMsg::kDataTypeFieldNumber, + framework::ToDataType(tensor.type())); + for (auto& dim : framework::vectorize(tensor.dims())) { + e.WriteUint64(VarMsg::kDimsFieldNumber, dim); + } + auto lod = tensor.lod(); // std::vector> + if (lod.size() > 0) { + e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size()); + + for (auto& each : lod) { + e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber, + 2 + // tag + varintlength of submessage + 1 + // kLodDataFieldNumber + each.size()); + // auto copied from GPU + for (auto& d : each) { + e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d); + } + } + } + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE(platform::is_gpu_place(tensor.place())); + platform::CPUPlace cpu; + auto& gpu_dev_ctx = + static_cast(ctx); + auto copy_size = tensor.memory_size(); + payload = memory::Alloc(cpu, copy_size); + memory::Copy(cpu, payload, + boost::get(tensor.place()), + reinterpret_cast(tensor.data()), + copy_size, gpu_dev_ctx.stream()); + destroy_callback = [](void* backing) { + std::cout << "destroy payload" << std::endl; + platform::CPUPlace cpu; + memory::Free(cpu, backing); + }; +#endif + } else { + payload = tensor.data(); + } + payload_size = tensor.memory_size(); + + std::string tmp(reinterpret_cast(payload), payload_size); + for (int i = 0; i < tmp.size(); ++i) { + printf("%02X ", tmp.data()[i]); + } + printf("\n"); + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); + } break; + case framework::proto::VarType_Type_SELECTED_ROWS: { + // TODO(typhoonzero): selectedrows implement should not use unique_ptr + auto* slr = var->GetMutable(); + e.WriteUint64(VarMsg::kDataTypeFieldNumber, + framework::ToDataType(slr->value().type())); + for (auto& dim : framework::vectorize(slr->value().dims())) { + e.WriteUint64(VarMsg::kDimsFieldNumber, dim); + } + e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0); + auto* tensor = slr->mutable_value(); + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + platform::CPUPlace cpu; + auto& gpu_dev_ctx = + static_cast(ctx); + auto copy_size = tensor->memory_size(); + payload = memory::Alloc(cpu, copy_size); + memory::Copy(cpu, payload, + boost::get(tensor->place()), + reinterpret_cast(tensor->data()), + copy_size, gpu_dev_ctx.stream()); + ctx.Wait(); + float* ttt = reinterpret_cast(payload); + for (int i = 0; i < copy_size / 4; i++) { + std::cout << "copied to cpu: " << ttt[i] << std::endl; + } + destroy_callback = [](void* backing) { + std::cout << "destroy..." << std::endl; + // platform::CPUPlace cpu; + // memory::Free(cpu, backing); + }; +#endif + } else { + payload = slr->mutable_value()->data(); + } + payload_size = tensor->memory_size(); + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); + } break; + default: + PADDLE_THROW("Serialize does not support type: %s", + typeid(var->Type()).name()); + break; + } + // steal reference of tensor data + ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows + int num_slices = 2; // only SelectedRows have rows buffer + slices[0] = ::grpc::Slice(e.size()); + memcpy(const_cast(slices[0].begin()), e.data(), e.size()); + slices[1] = ::grpc::Slice( + grpc_slice_new_with_user_data(payload, payload_size, destroy_callback, + static_cast(payload)), + ::grpc::Slice::STEAL_REF); + + if (framework::ToVarType(var->Type()) == + framework::proto::VarType_Type_SELECTED_ROWS) { + auto* slr = var->GetMutable(); + + ProtoEncodeHelper e2((char*)buf, 128); + // NOTE: rows is of type int64_t + size_t rows_memory_size = + slr->rows().capacity() * framework::SizeOfType(typeid(int64_t)); + e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); + slices[2] = ::grpc::Slice(e2.size()); + memcpy(const_cast(slices[2].begin()), e2.data(), e2.size()); + + slices[3] = ::grpc::Slice( + grpc_slice_new_with_user_data( + const_cast( + reinterpret_cast(slr->rows().data())), + rows_memory_size, + [](void* backing) { + // TODO(typhoonzero): add unref here, same as above. + }, + const_cast( + reinterpret_cast(slr->rows().data()))), + ::grpc::Slice::STEAL_REF); + num_slices = 4; + } + + ::grpc::ByteBuffer tmp(&slices[0], num_slices); + msg->Swap(&tmp); +} + +void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, + const platform::DeviceContext& ctx, + framework::Variable* var) { + sendrecv::VariableMessage meta; + GrpcByteBufferSource source; + source.Init(msg); + ::google::protobuf::io::CodedInputStream input(&source); + // do zerocopy parsing + PADDLE_ENFORCE(meta.ParseFromCodedStream(&input)); + PADDLE_ENFORCE(input.ConsumedEntireMessage()); + // dims is needed by both tensor and selectedrows + std::vector vecdims; + for (auto& d : meta.dims()) { + vecdims.push_back(d); + } + framework::DDim dims = framework::make_ddim(vecdims); + + if (meta.type() == sendrecv::LOD_TENSOR) { + auto* tensor = var->GetMutable(); + tensor->Resize(dims); + void* tensor_data = tensor->mutable_data( + ctx.GetPlace(), + paddle::operators::detail::ToTypeIndex(meta.data_type())); + framework::LoD lod; + for (int i = 0; i < meta.lod_level(); ++i) { + framework::Vector v; + for (int j = 0; j < meta.lod(i).lod_data_size(); ++j) { + v.push_back(meta.lod(i).lod_data(j)); + } + lod.push_back(v); + } + tensor->set_lod(lod); + // How to avoid copying and use the message buffer directly? + // Maybe need to find a way to release all memory except tensor content. + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + platform::CPUPlace cpu; + auto& gpu_dev_ctx = static_cast(ctx); + memory::Copy(boost::get(tensor->place()), + tensor_data, cpu, + reinterpret_cast(meta.serialized().data()), + meta.serialized().size(), gpu_dev_ctx.stream()); +#endif + } else { + memcpy(tensor_data, + reinterpret_cast(meta.serialized().data()), + meta.serialized().size()); + } + } else if (meta.type() == sendrecv::SELECTED_ROWS) { + auto* slr = var->GetMutable(); + auto* tensor = slr->mutable_value(); + int64_t* rows_data = slr->mutable_rows()->data(); + tensor->Resize(dims); + void* tensor_data = tensor->mutable_data( + ctx.GetPlace(), + paddle::operators::detail::ToTypeIndex(meta.data_type())); + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + platform::CPUPlace cpu; + auto& gpu_dev_ctx = static_cast(ctx); + memory::Copy(boost::get(tensor->place()), + tensor_data, cpu, + reinterpret_cast(meta.serialized().data()), + meta.serialized().size(), gpu_dev_ctx.stream()); +#endif + } else { + memcpy(tensor_data, + reinterpret_cast(meta.serialized().data()), + meta.serialized().size()); + } + // copy rows CPU data, GPU data will be copied lazly + memcpy(rows_data, reinterpret_cast(meta.rows().data()), + meta.rows().size()); + } +} + } // namespace detail } // namespace operators -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index 670d0e162473750d0a5f8e9025ef1cf9a9ef407c..65704db5ae2604c8e462ffc2828085ecb2893e43 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -33,6 +33,14 @@ namespace detail { #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" +typedef void (*DestroyCallback)(void*); + +inline int64_t GetTimestamp() { + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + void SerializeToMessage(const std::string& name, const framework::Variable* var, const platform::DeviceContext& ctx, sendrecv::VariableMessage* msg); @@ -40,6 +48,32 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var, void DeserializeFromMessage(const sendrecv::VariableMessage& msg, const platform::DeviceContext& ctx, framework::Variable* var); + +void SerializeToByteBuffer(const std::string& name, framework::Variable* var, + const platform::DeviceContext& ctx, + ::grpc::ByteBuffer* msg); + +void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, + const platform::DeviceContext& ctx, + framework::Variable* var); + +inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { + switch (type) { + case sendrecv::VariableMessage::FP32: + return typeid(float); // NOLINT + case sendrecv::VariableMessage::FP64: + return typeid(double); // NOLINT + case sendrecv::VariableMessage::INT32: + return typeid(int); // NOLINT + case sendrecv::VariableMessage::INT64: + return typeid(int64_t); // NOLINT + case sendrecv::VariableMessage::BOOL: + return typeid(bool); // NOLINT + default: + PADDLE_THROW("Not support type %d", type); + } +} + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/detail/test_serde.cc b/paddle/fluid/operators/detail/test_serde.cc new file mode 100644 index 0000000000000000000000000000000000000000..8054c89ecfe7ba8273564c9d480ae6f20c5b4286 --- /dev/null +++ b/paddle/fluid/operators/detail/test_serde.cc @@ -0,0 +1,195 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; + +void RunSerdeTestTensor(platform::Place place) { + // serialize var to ByteBuffer + framework::Variable var; + auto* tensor = var.GetMutable(); + tensor->Resize(framework::make_ddim({4, 8, 4, 2})); + framework::LoD lod; + lod.push_back(framework::Vector({1, 3, 8})); + tensor->set_lod(lod); + int tensor_numel = 4 * 8 * 4 * 2; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + float* orig_tensor_data = tensor->mutable_data(place); + math::set_constant(ctx, tensor, 31.9); + + ::grpc::ByteBuffer msg; + operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); + EXPECT_GT(msg.Length(), 0); + + // deserialize + std::vector<::grpc::Slice> slices; + (void)msg.Dump(&slices); + std::string tmp; + for (const auto& s : slices) { + tmp.append(reinterpret_cast(s.begin()), s.size()); + } + sendrecv::VariableMessage varmsg; + EXPECT_TRUE(varmsg.ParseFromString(tmp)); + EXPECT_EQ(varmsg.varname(), "myvar"); + EXPECT_EQ(varmsg.type(), 0); + EXPECT_EQ(varmsg.dims()[0], 4); + EXPECT_EQ(varmsg.dims()[1], 8); + EXPECT_EQ(varmsg.dims()[2], 4); + EXPECT_EQ(varmsg.dims()[3], 2); + EXPECT_EQ(varmsg.lod_level(), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); + EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); + + const float* tensor_data = + reinterpret_cast(varmsg.serialized().data()); + for (int i = 0; i < varmsg.serialized().size(); ++i) { + printf("%02X ", varmsg.serialized().data()[i]); + } + printf("\n"); + for (int i = 0; i < tensor_numel; ++i) { + std::cout << "#####tensor data: " << tensor_data[i] << std::endl; + EXPECT_EQ(tensor_data[i], orig_tensor_data[i]); + std::cout << "test end 1 " << std::endl; + } + std::cout << "tensor data end " << std::endl; + + // deserialize zero-copy + framework::Variable var2; + operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + auto tensor2 = var2.Get(); + float* tensor_data2 = nullptr; + framework::Tensor tmp_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + platform::CPUPlace cpu; + framework::TensorCopy(tensor2, cpu, &tmp_tensor); + tensor_data2 = tmp_tensor.data(); + } else { + tensor_data2 = const_cast(tensor2.data()); + } + + EXPECT_EQ(varmsg.lod_level(), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); + EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); + for (int i = 0; i < tensor_numel; ++i) + EXPECT_EQ(tensor_data2[i], orig_tensor_data[i]); +} + +void RunSerdeTestSelectedRows(platform::Place place) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + + // serialize var to ByteBuffer + framework::Variable var; + auto* slr = var.GetMutable(); + auto* tensor = slr->mutable_value(); + auto* rows = slr->mutable_rows(); + + tensor->Resize(framework::make_ddim({2, 10})); + int tensor_numel = 2 * 10; + float* orig_tensor_data = tensor->mutable_data(place); + math::set_constant(ctx, tensor, 32.7); + rows->push_back(3); + rows->push_back(10); + + ::grpc::ByteBuffer msg; + operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); + EXPECT_GT(msg.Length(), 0); + + // deserialize + std::vector<::grpc::Slice> slices; + (void)msg.Dump(&slices); + std::string tmp; + for (const auto& s : slices) { + tmp.append(reinterpret_cast(s.begin()), s.size()); + } + sendrecv::VariableMessage varmsg; + EXPECT_TRUE(varmsg.ParseFromString(tmp)); + + EXPECT_EQ(varmsg.varname(), "myvar"); + EXPECT_EQ(varmsg.type(), 1); + + const float* tensor_data = + reinterpret_cast(varmsg.serialized().data()); + const int64_t* rows_data = + reinterpret_cast(varmsg.rows().data()); + for (int i = 0; i < tensor_numel; ++i) { + EXPECT_EQ(tensor_data[i], orig_tensor_data[i]); + } + EXPECT_EQ(rows_data[0], 3); + EXPECT_EQ(rows_data[1], 10); + // deserialize zero-copy + framework::Variable var2; + operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + + auto* slr2 = var2.GetMutable(); + auto* tensor2 = slr2->mutable_value(); + auto* rows2 = slr2->mutable_rows(); + float* tensor_data2 = nullptr; + framework::Tensor tmp_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + platform::CPUPlace cpu; + framework::TensorCopy(*tensor2, cpu, &tmp_tensor); + tensor_data2 = tmp_tensor.data(); + } else { + tensor_data2 = const_cast(tensor2->data()); + } + const int64_t* rows_data2 = rows2->data(); + + for (int i = 0; i < tensor_numel; ++i) { + EXPECT_EQ(tensor_data2[i], orig_tensor_data[i]); + } + EXPECT_EQ(rows_data2[0], 3); + EXPECT_EQ(rows_data2[1], 10); +} + +// TEST(SelectedRows, CPU) { +// platform::CPUPlace place; +// RunSerdeTestSelectedRows(place); +// } + +// TEST(SelectedRows, GPU) { +// platform::CUDAPlace place; +// RunSerdeTestSelectedRows(place); +// } + +TEST(Tensor, CPU) { + platform::CPUPlace place; + RunSerdeTestTensor(place); +} + +TEST(Tensor, GPU) { + platform::CUDAPlace place; + RunSerdeTestTensor(place); +} \ No newline at end of file