提交 45af8c1e 编写于 作者: 武毅 提交者: gongweibao

Performance/zero copy variable seriralization (#8839)

上级 12fc76e1
......@@ -187,7 +187,6 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
// TODO(typhoonzero): serialize to ostream
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char*>(&version), sizeof(version));
......
if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc)
endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#include "bytebuffer_stream.h"
namespace paddle {
namespace operators {
namespace detail {
GrpcByteBufferSource::GrpcByteBufferSource() {}
bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) {
cur_ = -1;
left_ = 0;
ptr_ = nullptr;
byte_count_ = 0;
bool ok = src.Dump(&slices_).ok();
if (!ok) {
slices_.clear();
}
return ok;
}
bool GrpcByteBufferSource::Next(const void** data, int* size) {
// Use loop instead of if in case buffer contained empty slices.
while (left_ == 0) {
// Advance to next slice.
cur_++;
if (cur_ >= slices_.size()) {
return false;
}
const ::grpc::Slice& s = slices_[cur_];
left_ = s.size();
ptr_ = reinterpret_cast<const char*>(s.begin());
}
*data = ptr_;
*size = left_;
byte_count_ += left_;
ptr_ += left_;
left_ = 0;
return true;
}
void GrpcByteBufferSource::BackUp(int count) {
ptr_ -= count;
left_ += count;
byte_count_ -= count;
}
bool GrpcByteBufferSource::Skip(int count) {
const void* data;
int size;
while (Next(&data, &size)) {
if (size >= count) {
BackUp(size - count);
return true;
}
// size < count;
count -= size;
}
// error or we have too large count;
return false;
}
google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
return byte_count_;
}
} // namespace detail
} // namespace operators
} // namespace paddle
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#pragma once
#include <grpc++/grpc++.h>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
namespace paddle {
namespace operators {
namespace detail {
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class GrpcByteBufferSource
: public ::google::protobuf::io::ZeroCopyInputStream {
public:
GrpcByteBufferSource();
bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times.
bool Next(const void** data, int* size) override;
void BackUp(int count) override;
bool Skip(int count) override;
::google::protobuf::int64 ByteCount() const override;
private:
std::vector<::grpc::Slice> slices_;
size_t cur_; // Current slice index.
int left_; // Number of bytes in slices_[cur_] left to yield.
const char* ptr_; // Address of next byte in slices_[cur_] to yield.
::google::protobuf::int64 byte_count_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// NOTE: This file was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data.
#pragma once
#include <grpc++/grpc++.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace detail {
char* EncodeVarint32(char* dst, uint32_t v) {
// Operate on characters as unsigneds
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
static const int B = 128;
if (v < (1 << 7)) {
*(ptr++) = v;
} else if (v < (1 << 14)) {
*(ptr++) = v | B;
*(ptr++) = v >> 7;
} else if (v < (1 << 21)) {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = v >> 14;
} else if (v < (1 << 28)) {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = (v >> 14) | B;
*(ptr++) = v >> 21;
} else {
*(ptr++) = v | B;
*(ptr++) = (v >> 7) | B;
*(ptr++) = (v >> 14) | B;
*(ptr++) = (v >> 21) | B;
*(ptr++) = v >> 28;
}
return reinterpret_cast<char*>(ptr);
}
char* EncodeVarint64(char* dst, uint64_t v) {
static const int B = 128;
unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
while (v >= B) {
*(ptr++) = (v & (B - 1)) | B;
v >>= 7;
}
*(ptr++) = static_cast<unsigned char>(v);
return reinterpret_cast<char*>(ptr);
}
int VarintLength(uint64_t v) {
int len = 1;
while (v >= 128) {
v >>= 7;
len++;
}
return len;
}
class ProtoEncodeHelper {
public:
ProtoEncodeHelper(char* buf, int max_size)
: base_(buf), p_(buf), limit_(base_ + max_size) {}
~ProtoEncodeHelper() {
// Make sure callers didn't do operations that went over max_size promised
PADDLE_ENFORCE_LE(p_, limit_);
}
const char* data() const { return base_; }
size_t size() const { return p_ - base_; }
void WriteUint64(int tag, uint64_t v) {
Encode32(combine(tag, WIRETYPE_VARINT));
Encode64(v);
}
void WriteBool(int tag, bool v) {
Encode32(combine(tag, WIRETYPE_VARINT));
EncodeBool(v);
}
void WriteString(int tag, const std::string& v) {
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
Encode32(v.size());
EncodeBytes(v.data(), v.size());
}
void WriteVarlengthBeginning(int tag, uint32_t len) {
Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED));
Encode32(len);
}
void WriteRawBytes(const std::string& v) { EncodeBytes(v.data(), v.size()); }
private:
// Note: this module's behavior must match the protocol buffer wire encoding
// format.
enum {
WIRETYPE_VARINT = 0,
WIRETYPE_LENGTH_DELIMITED = 2,
};
static uint32_t combine(uint32_t tag, uint32_t type) {
return ((tag << 3) | type);
}
inline void Encode32(uint32_t v) {
if (v < 128) {
// Fast path for single-byte values. Many of the calls will use a
// constant value for v, so the comparison will get optimized away
// when Encode32 is inlined into the caller.
*p_ = v;
p_++;
} else {
p_ = EncodeVarint32(p_, v);
}
}
void Encode64(uint64_t v) { p_ = EncodeVarint64(p_, v); }
void EncodeBool(bool v) {
*p_ = (v ? 1 : 0); // Equal to varint32 encoding of 0 or 1
p_++;
}
void EncodeBytes(const char* bytes, int N) {
memcpy(p_, bytes, N);
p_ += N;
}
char* base_;
char* p_;
char* limit_; // Just for CHECKs
};
} // detail
} // operators
} // paddle
......@@ -33,10 +33,34 @@ enum VarType {
}
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
bytes serialized = 3;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// tensor data
bytes serialized = 7;
// selected_rows data
bytes rows = 8;
}
message VoidMessage {}
......@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
namespace paddle {
namespace operators {
......@@ -63,6 +68,242 @@ void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
}
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg) {
using VarMsg = sendrecv::VariableMessage;
sendrecv::VariableMessage request;
std::string header;
request.AppendToString(&header);
// When using GPU, need to free the copied CPU buffer
// when the ByteBuffer destroies
// TODO(typhoonzero): add unref here, if we have dependent
// parallelism execution, need to know when to free the tensor.
DestroyCallback destroy_callback = [](void* backing) {};
void* buf = malloc(1024);
void* payload;
size_t payload_size;
ProtoEncodeHelper e((char*)buf, 1024);
e.WriteString(VarMsg::kVarnameFieldNumber, name);
if (var->IsType<framework::LoDTensor>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
} else if (var->IsType<framework::SelectedRows>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
}
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
auto tensor = var->Get<framework::LoDTensor>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(tensor.type()));
for (auto& dim : framework::vectorize(tensor.dims())) {
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
auto lod = tensor.lod(); // std::vector<Vector<size_t>>
if (lod.size() > 0) {
e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
for (auto& each : lod) {
e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
2 + // tag + varintlength of submessage
1 + // kLodDataFieldNumber
each.size());
// auto copied from GPU
for (auto& d : each) {
e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
}
}
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.memory_size();
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()),
copy_size, gpu_dev_ctx.stream());
destroy_callback = [](void* backing) {
std::cout << "destroy payload" << std::endl;
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
#endif
} else {
payload = tensor.data<void>();
}
payload_size = tensor.memory_size();
std::string tmp(reinterpret_cast<char*>(payload), payload_size);
for (int i = 0; i < tmp.size(); ++i) {
printf("%02X ", tmp.data()[i]);
}
printf("\n");
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
case framework::proto::VarType_Type_SELECTED_ROWS: {
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
auto* slr = var->GetMutable<framework::SelectedRows>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(slr->value().type()));
for (auto& dim : framework::vectorize(slr->value().dims())) {
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->memory_size();
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()),
copy_size, gpu_dev_ctx.stream());
ctx.Wait();
float* ttt = reinterpret_cast<float*>(payload);
for (int i = 0; i < copy_size / 4; i++) {
std::cout << "copied to cpu: " << ttt[i] << std::endl;
}
destroy_callback = [](void* backing) {
std::cout << "destroy..." << std::endl;
// platform::CPUPlace cpu;
// memory::Free(cpu, backing);
};
#endif
} else {
payload = slr->mutable_value()->data<void>();
}
payload_size = tensor->memory_size();
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
default:
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
slices[0] = ::grpc::Slice(e.size());
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
slices[1] = ::grpc::Slice(
grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
static_cast<char*>(payload)),
::grpc::Slice::STEAL_REF);
if (framework::ToVarType(var->Type()) ==
framework::proto::VarType_Type_SELECTED_ROWS) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2((char*)buf, 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size =
slr->rows().capacity() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
slices[3] = ::grpc::Slice(
grpc_slice_new_with_user_data(
const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size,
[](void* backing) {
// TODO(typhoonzero): add unref here, same as above.
},
const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF);
num_slices = 4;
}
::grpc::ByteBuffer tmp(&slices[0], num_slices);
msg->Swap(&tmp);
}
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
framework::Variable* var) {
sendrecv::VariableMessage meta;
GrpcByteBufferSource source;
source.Init(msg);
::google::protobuf::io::CodedInputStream input(&source);
// do zerocopy parsing
PADDLE_ENFORCE(meta.ParseFromCodedStream(&input));
PADDLE_ENFORCE(input.ConsumedEntireMessage());
// dims is needed by both tensor and selectedrows
std::vector<int> vecdims;
for (auto& d : meta.dims()) {
vecdims.push_back(d);
}
framework::DDim dims = framework::make_ddim(vecdims);
if (meta.type() == sendrecv::LOD_TENSOR) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta.data_type()));
framework::LoD lod;
for (int i = 0; i < meta.lod_level(); ++i) {
framework::Vector<size_t> v;
for (int j = 0; j < meta.lod(i).lod_data_size(); ++j) {
v.push_back(meta.lod(i).lod_data(j));
}
lod.push_back(v);
}
tensor->set_lod(lod);
// How to avoid copying and use the message buffer directly?
// Maybe need to find a way to release all memory except tensor content.
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(boost::get<platform::CUDAPlace>(tensor->place()),
tensor_data, cpu,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size(), gpu_dev_ctx.stream());
#endif
} else {
memcpy(tensor_data,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size());
}
} else if (meta.type() == sendrecv::SELECTED_ROWS) {
auto* slr = var->GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
int64_t* rows_data = slr->mutable_rows()->data();
tensor->Resize(dims);
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta.data_type()));
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(boost::get<platform::CUDAPlace>(tensor->place()),
tensor_data, cpu,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size(), gpu_dev_ctx.stream());
#endif
} else {
memcpy(tensor_data,
reinterpret_cast<const void*>(meta.serialized().data()),
meta.serialized().size());
}
// copy rows CPU data, GPU data will be copied lazly
memcpy(rows_data, reinterpret_cast<const void*>(meta.rows().data()),
meta.rows().size());
}
}
} // namespace detail
} // namespace operators
} // namespace paddle
\ No newline at end of file
......@@ -33,6 +33,14 @@ namespace detail {
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
typedef void (*DestroyCallback)(void*);
inline int64_t GetTimestamp() {
return std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
}
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
......@@ -40,6 +48,32 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var,
void DeserializeFromMessage(const sendrecv::VariableMessage& msg,
const platform::DeviceContext& ctx,
framework::Variable* var);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg);
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
framework::Variable* var);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
switch (type) {
case sendrecv::VariableMessage::FP32:
return typeid(float); // NOLINT
case sendrecv::VariableMessage::FP64:
return typeid(double); // NOLINT
case sendrecv::VariableMessage::INT32:
return typeid(int); // NOLINT
case sendrecv::VariableMessage::INT64:
return typeid(int64_t); // NOLINT
case sendrecv::VariableMessage::BOOL:
return typeid(bool); // NOLINT
default:
PADDLE_THROW("Not support type %d", type);
}
}
} // namespace detail
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
void RunSerdeTestTensor(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
float* orig_tensor_data = tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 31.9);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), 0);
// deserialize
std::vector<::grpc::Slice> slices;
(void)msg.Dump(&slices);
std::string tmp;
for (const auto& s : slices) {
tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
}
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 4);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data());
for (int i = 0; i < varmsg.serialized().size(); ++i) {
printf("%02X ", varmsg.serialized().data()[i]);
}
printf("\n");
for (int i = 0; i < tensor_numel; ++i) {
std::cout << "#####tensor data: " << tensor_data[i] << std::endl;
EXPECT_EQ(tensor_data[i], orig_tensor_data[i]);
std::cout << "test end 1 " << std::endl;
}
std::cout << "tensor data end " << std::endl;
// deserialize zero-copy
framework::Variable var2;
operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
auto tensor2 = var2.Get<framework::LoDTensor>();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2.data<float>());
}
EXPECT_EQ(varmsg.lod_level(), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);
for (int i = 0; i < tensor_numel; ++i)
EXPECT_EQ(tensor_data2[i], orig_tensor_data[i]);
}
void RunSerdeTestSelectedRows(platform::Place place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
int tensor_numel = 2 * 10;
float* orig_tensor_data = tensor->mutable_data<float>(place);
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
EXPECT_GT(msg.Length(), 0);
// deserialize
std::vector<::grpc::Slice> slices;
(void)msg.Dump(&slices);
std::string tmp;
for (const auto& s : slices) {
tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
}
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);
const float* tensor_data =
reinterpret_cast<const float*>(varmsg.serialized().data());
const int64_t* rows_data =
reinterpret_cast<const int64_t*>(varmsg.rows().data());
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_EQ(tensor_data[i], orig_tensor_data[i]);
}
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
// deserialize zero-copy
framework::Variable var2;
operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
auto* slr2 = var2.GetMutable<framework::SelectedRows>();
auto* tensor2 = slr2->mutable_value();
auto* rows2 = slr2->mutable_rows();
float* tensor_data2 = nullptr;
framework::Tensor tmp_tensor;
if (platform::is_gpu_place(ctx.GetPlace())) {
platform::CPUPlace cpu;
framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
tensor_data2 = tmp_tensor.data<float>();
} else {
tensor_data2 = const_cast<float*>(tensor2->data<float>());
}
const int64_t* rows_data2 = rows2->data();
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_EQ(tensor_data2[i], orig_tensor_data[i]);
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
}
// TEST(SelectedRows, CPU) {
// platform::CPUPlace place;
// RunSerdeTestSelectedRows(place);
// }
// TEST(SelectedRows, GPU) {
// platform::CUDAPlace place;
// RunSerdeTestSelectedRows(place);
// }
TEST(Tensor, CPU) {
platform::CPUPlace place;
RunSerdeTestTensor(place);
}
TEST(Tensor, GPU) {
platform::CUDAPlace place;
RunSerdeTestTensor(place);
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册