diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 3787b139a5b40bfc18df27e443edb8649fb711ae..bdda5703436765480f353ee964624364f45dbefb 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -48,6 +48,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, void* dest, int size) { const void* data = NULL; int size_to_write = 0; + int length = size; + int total_written = 0; if (platform::is_gpu_place(place)) { #ifdef PADDLE_WITH_CUDA @@ -56,16 +58,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, platform::CPUPlace cpu; char* p = reinterpret_cast(dest); - while (size > 0) { + while (total_written < length) { if (!input->GetDirectBufferPointer(&data, &size_to_write)) { return false; } - + // NOTE: if raw buffer is large and have two neighbor fields of raw + // buffers GetDirectBufferPointer can get all of them, use length to + // truncate it. + if (total_written + size_to_write > length) { + size_to_write = length - total_written; + } memory::Copy(boost::get(place), reinterpret_cast(p), cpu, data, size_to_write, gpu_dev_ctx.stream()); p += size_to_write; - size -= size_to_write; + total_written += size_to_write; input->Skip(size_to_write); } @@ -77,16 +84,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, } char* p = reinterpret_cast(dest); - while (size > 0) { + while (total_written < length) { if (!input->GetDirectBufferPointer(&data, &size_to_write)) { return false; } + // NOTE: if raw buffer is large and have two neighbor fields of raw buffers + // GetDirectBufferPointer can get all of them, use length to truncate it. + if (total_written + size_to_write > length) { + size_to_write = length - total_written; + } // TODO(gongwb): can we avoid copy? platform::CPUPlace cpu; memory::Copy(cpu, reinterpret_cast(p), cpu, data, size_to_write); p += size_to_write; - size -= size_to_write; + total_written += size_to_write; input->Skip(size_to_write); } @@ -234,7 +246,6 @@ int VariableResponse::Parse(Source* source) { if (tag != 0) { return -1; } - return 0; }