diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc index 12e8eb0b4da2252b104415aef4156bf100c3e565..bdda5703436765480f353ee964624364f45dbefb 100644 --- a/paddle/fluid/operators/detail/variable_response.cc +++ b/paddle/fluid/operators/detail/variable_response.cc @@ -48,6 +48,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, void* dest, int size) { const void* data = NULL; int size_to_write = 0; + int length = size; + int total_written = 0; if (platform::is_gpu_place(place)) { #ifdef PADDLE_WITH_CUDA @@ -56,16 +58,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, platform::CPUPlace cpu; char* p = reinterpret_cast(dest); - while (size > 0) { + while (total_written < length) { if (!input->GetDirectBufferPointer(&data, &size_to_write)) { return false; } - + // NOTE: if raw buffer is large and have two neighbor fields of raw + // buffers GetDirectBufferPointer can get all of them, use length to + // truncate it. + if (total_written + size_to_write > length) { + size_to_write = length - total_written; + } memory::Copy(boost::get(place), reinterpret_cast(p), cpu, data, size_to_write, gpu_dev_ctx.stream()); p += size_to_write; - size -= size_to_write; + total_written += size_to_write; input->Skip(size_to_write); } @@ -77,16 +84,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input, } char* p = reinterpret_cast(dest); - while (size > 0) { + while (total_written < length) { if (!input->GetDirectBufferPointer(&data, &size_to_write)) { return false; } + // NOTE: if raw buffer is large and have two neighbor fields of raw buffers + // GetDirectBufferPointer can get all of them, use length to truncate it. + if (total_written + size_to_write > length) { + size_to_write = length - total_written; + } // TODO(gongwb): can we avoid copy? platform::CPUPlace cpu; memory::Copy(cpu, reinterpret_cast(p), cpu, data, size_to_write); p += size_to_write; - size -= size_to_write; + total_written += size_to_write; input->Skip(size_to_write); } @@ -153,6 +165,7 @@ bool VariableResponse::CopySelectRowsData( const platform::DeviceContext& ctx, int length) { auto var = scope_->FindVar(meta_.varname()); auto* slr = var->GetMutable(); + slr->mutable_rows()->resize(length / 8); // int64 int64_t* rows_data = slr->mutable_rows()->data(); // copy rows CPU data, GPU data will be copied lazily. @@ -233,7 +246,6 @@ int VariableResponse::Parse(Source* source) { if (tag != 0) { return -1; } - return 0; }