未验证 提交 eb62a778 编写于 作者: 武毅 提交者: GitHub

Merge pull request #9409 from typhoonzero/fix_slr_deser

Fix slr deser
......@@ -48,6 +48,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
void* dest, int size) {
const void* data = NULL;
int size_to_write = 0;
int length = size;
int total_written = 0;
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
......@@ -56,16 +58,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
platform::CPUPlace cpu;
char* p = reinterpret_cast<char*>(dest);
while (size > 0) {
while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
// NOTE: if the raw buffer is large and two raw-buffer fields are adjacent,
// GetDirectBufferPointer may return data spanning both of them, so use
// length to truncate size_to_write to the bytes remaining in this field.
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
memory::Copy(boost::get<platform::CUDAPlace>(place),
reinterpret_cast<void*>(p), cpu, data, size_to_write,
gpu_dev_ctx.stream());
p += size_to_write;
size -= size_to_write;
total_written += size_to_write;
input->Skip(size_to_write);
}
......@@ -77,16 +84,21 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
}
char* p = reinterpret_cast<char*>(dest);
while (size > 0) {
while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
// NOTE: if the raw buffer is large and two raw-buffer fields are adjacent,
// GetDirectBufferPointer may return data spanning both, so use length to truncate it.
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
// TODO(gongwb): can we avoid copy?
platform::CPUPlace cpu;
memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write;
size -= size_to_write;
total_written += size_to_write;
input->Skip(size_to_write);
}
......@@ -153,6 +165,7 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->resize(length / 8); // int64
int64_t* rows_data = slr->mutable_rows()->data();
// copy rows CPU data, GPU data will be copied lazily.
......@@ -233,7 +246,6 @@ int VariableResponse::Parse(Source* source) {
if (tag != 0) {
return -1;
}
return 0;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册