提交 cc1c6afb 编写于 作者: Y yi.wu

fix slr serde

上级 094d5096
...@@ -48,6 +48,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, ...@@ -48,6 +48,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
void* dest, int size) { void* dest, int size) {
const void* data = NULL; const void* data = NULL;
int size_to_write = 0; int size_to_write = 0;
int length = size;
int total_written = 0;
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -56,16 +58,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, ...@@ -56,16 +58,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
platform::CPUPlace cpu; platform::CPUPlace cpu;
char* p = reinterpret_cast<char*>(dest); char* p = reinterpret_cast<char*>(dest);
while (size > 0) { while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) { if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false; return false;
} }
// NOTE: if raw buffer is large and have two neighbor fields of raw
// buffers GetDirectBufferPointer can get all of them, use length to
// truncate it.
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
memory::Copy(boost::get<platform::CUDAPlace>(place), memory::Copy(boost::get<platform::CUDAPlace>(place),
reinterpret_cast<void*>(p), cpu, data, size_to_write, reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream()); gpu_dev_ctx.stream());
p += size_to_write; p += size_to_write;
size -= size_to_write; total_written += size_to_write;
input->Skip(size_to_write); input->Skip(size_to_write);
} }
...@@ -77,16 +84,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, ...@@ -77,16 +84,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
} }
char* p = reinterpret_cast<char*>(dest); char* p = reinterpret_cast<char*>(dest);
while (size > 0) { while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) { if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false; return false;
} }
// NOTE: if raw buffer is large and have two neighbor fields of raw buffers
// GetDirectBufferPointer can get all of them, use length to truncate it.
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
// TODO(gongwb): can we avoid copy? // TODO(gongwb): can we avoid copy?
platform::CPUPlace cpu; platform::CPUPlace cpu;
memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write); memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write; p += size_to_write;
size -= size_to_write; total_written += size_to_write;
input->Skip(size_to_write); input->Skip(size_to_write);
} }
...@@ -234,7 +246,6 @@ int VariableResponse::Parse(Source* source) { ...@@ -234,7 +246,6 @@ int VariableResponse::Parse(Source* source) {
if (tag != 0) { if (tag != 0) {
return -1; return -1;
} }
return 0; return 0;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册