提交 18483585 编写于 作者: T typhoonzero

fix copy size

上级 788636f0
......@@ -82,7 +82,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.memory_size();
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
......@@ -99,7 +99,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else {
payload = tensor.data<void>();
}
payload_size = tensor.memory_size();
payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
case framework::proto::VarType_Type_SELECTED_ROWS: {
......@@ -118,7 +118,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
platform::CPUPlace cpu;
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->memory_size();
auto copy_size =
tensor->numel() * framework::SizeOfType(tensor->type());
payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor->place()),
......@@ -133,7 +134,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else {
payload = slr->mutable_value()->data<void>();
}
payload_size = tensor->memory_size();
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break;
default:
......
......@@ -32,7 +32,8 @@ class SplitByrefOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < outs.size(); ++i) {
// NOTE: no need to call mutable_data here to allocate memory.
auto* out = outs[i];
*out = std::move(in->Slice(row_offset, out->dims()[0]));
VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0];
*out = std::move(in->Slice(row_offset, row_offset + out->dims()[0]));
row_offset += out->dims()[0];
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册