未验证 提交 61343fbf 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #10531 from typhoonzero/refine_grpc_serde_code

Refine serde code
......@@ -29,60 +29,26 @@ namespace paddle {
namespace operators {
namespace detail {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
using VarMsg = sendrecv::VariableMessage;
// When using GPU, need to free the copied CPU buffer
// when the ByteBuffer destroies
// TODO(typhoonzero): add unref here, if we have dependent
// parallelism execution, need to know when to free the tensor.
DestroyCallback destroy_callback = [](void* backing) {};
using VarMsg = sendrecv::VariableMessage;
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
void* payload = nullptr;
size_t payload_size;
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
}
e.WriteString(VarMsg::kVarnameFieldNumber, name);
if (var->IsType<framework::LoDTensor>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
} else if (var->IsType<framework::SelectedRows>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
}
if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
}
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) {
auto tensor = var->Get<framework::LoDTensor>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(tensor.type()));
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
for (auto& dim : framework::vectorize(tensor.dims())) {
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
request->add_dims(dim);
}
auto lod = tensor.lod(); // std::vector<Vector<size_t>>
const framework::LoD lod = tensor.lod();
if (lod.size() > 0) {
e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size());
request->set_lod_level(lod.size());
for (auto& each : lod) {
e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber,
2 + // tag + varintlength of submessage
1 + // kLodDataFieldNumber
each.size());
// auto copied from GPU
VarMsg::LodData* lod_inner = request->add_lod();
for (auto& d : each) {
e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d);
lod_inner->add_lod_data(d);
}
}
}
......@@ -90,68 +56,100 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
payload = memory::Alloc(cpu, copy_size);
*payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()),
copy_size, gpu_dev_ctx.stream());
memory::Copy(cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
gpu_dev_ctx.stream());
ctx.Wait();
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
#endif
} else {
payload = tensor.data<void>();
*payload = tensor.data<void>();
}
payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
case framework::proto::VarType_Type_SELECTED_ROWS: {
// TODO(typhoonzero): selectedrows implement should not use unique_ptr
*payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
}
void GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) {
auto* slr = var->GetMutable<framework::SelectedRows>();
e.WriteUint64(VarMsg::kDataTypeFieldNumber,
framework::ToDataType(slr->value().type()));
request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
request->set_lod_level(0);
request->set_slr_height(slr->height());
for (auto& dim : framework::vectorize(slr->value().dims())) {
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
request->add_dims(dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size =
tensor->numel() * framework::SizeOfType(tensor->type());
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
*payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, *payload,
boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()),
copy_size, gpu_dev_ctx.stream());
reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
gpu_dev_ctx.stream());
ctx.Wait();
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
#endif
} else {
payload = slr->mutable_value()->data<void>();
*payload = slr->mutable_value()->data<void>();
}
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
default:
*payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
}
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
VarMsg request;
void* payload = nullptr;
size_t payload_size;
request.set_varname(name);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
request.set_profile(platform::IsProfileEnabled());
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
} else if (var->IsType<framework::SelectedRows>()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
break;
}
if (platform::is_gpu_place(ctx.GetPlace())) {
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback = [](void* backing) {
platform::CPUPlace cpu;
memory::Free(cpu, backing);
};
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
......@@ -162,12 +160,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
static_cast<char*>(payload)),
::grpc::Slice::STEAL_REF);
if (framework::ToVarType(var->Type()) ==
framework::proto::VarType_Type_SELECTED_ROWS) {
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
......@@ -178,10 +173,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
grpc_slice_new_with_user_data(
const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size,
[](void* backing) {
// TODO(typhoonzero): add unref here, same as above.
},
rows_memory_size, [](void* backing) {},
const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF);
......
......@@ -117,11 +117,11 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// serialize var to ByteBuffer
framework::Variable var;
auto* tensor = var.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({4, 8, 4, 2}));
tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
framework::LoD lod;
lod.push_back(framework::Vector<size_t>({1, 3, 8}));
tensor->set_lod(lod);
int tensor_numel = 4 * 8 * 4 * 2;
int tensor_numel = 512 * 8 * 4 * 2;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
tensor->mutable_data<float>(place);
......@@ -142,7 +142,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
EXPECT_TRUE(varmsg.ParseFromString(tmp));
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 0);
EXPECT_EQ(varmsg.dims()[0], 4);
EXPECT_EQ(varmsg.dims()[0], 512);
EXPECT_EQ(varmsg.dims()[1], 8);
EXPECT_EQ(varmsg.dims()[2], 4);
EXPECT_EQ(varmsg.dims()[3], 2);
......
......@@ -210,15 +210,15 @@ bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0;
if (!input->ReadVarintSizeAsInt(&length)) {
int num_bytes = 0;
if (!input->ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
for (int i = 0; i < length; i++) {
int start_pos = input->CurrentPosition();
while (input->CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return false;
return tag;
}
lod->push_back(v);
}
......@@ -275,8 +275,8 @@ int VariableResponse::Parse(Source* source) {
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint64_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
uint32_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
......@@ -284,8 +284,8 @@ int VariableResponse::Parse(Source* source) {
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
uint32_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
......@@ -305,11 +305,12 @@ int VariableResponse::Parse(Source* source) {
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int length = 0;
if (!input.ReadVarintSizeAsInt(&length)) {
int num_bytes = 0;
if (!input.ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
for (int i = 0; i < length; i++) {
int start_pos = input.CurrentPosition();
while (input.CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
......@@ -318,7 +319,6 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
......@@ -372,9 +372,9 @@ int VariableResponse::Parse(Source* source) {
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
......@@ -382,14 +382,14 @@ int VariableResponse::Parse(Source* source) {
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0,
"lod info should be got first!");
if (!CopyLodTensorData(&input, *dev_ctx_, dims, length)) {
if (!CopyLodTensorData(&input, *dev_ctx_, dims, num_bytes)) {
return tag;
}
break;
}
if (meta_.type() == sendrecv::SELECTED_ROWS) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, length)) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, num_bytes)) {
return tag;
}
break;
......@@ -403,13 +403,13 @@ int VariableResponse::Parse(Source* source) {
meta_.varname() != "",
"meta info should be got first!");
int length = 0;
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, length)) {
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return tag;
}
break;
......
......@@ -18,7 +18,9 @@ import math
import distributed_splitter as splitter
from .. import core
from ..framework import Program, default_main_program, Variable, Parameter
from ..framework import Program, default_main_program, \
default_startup_program, \
Variable, Parameter, grad_var_name
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
......@@ -244,7 +246,7 @@ class DistributeTranspiler:
]
grad_list = [
grad for grad in grad_list
if grad.name != framework.grad_var_name(self.table_name)
if grad.name != grad_var_name(self.table_name)
]
self.table_param_grad = [
param_grad for param_grad in params_grads
......@@ -494,7 +496,7 @@ class DistributeTranspiler:
were split to several blocks.
"""
s_prog = Program()
orig_s_prog = framework.default_startup_program()
orig_s_prog = default_startup_program()
params = self.param_grad_ep_mapping[endpoint]["params"]
def _get_splited_name_and_shape(varname):
......@@ -619,7 +621,7 @@ class DistributeTranspiler:
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
all_ops = program.global_block().ops
table_grad_name = framework.grad_var_name(self.table_name)
table_grad_name = grad_var_name(self.table_name)
for op in all_ops:
if table_grad_name in op.output_arg_names:
op_index = list(all_ops).index(op)
......@@ -692,7 +694,7 @@ class DistributeTranspiler:
persistable=True)
grad_var = _clone_var(
pserver_program.global_block(),
self.origin_program.global_block().vars[framework.grad_var_name(
self.origin_program.global_block().vars[grad_var_name(
self.table_name)],
persistable=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册