提交 18483585 编写于 作者: T typhoonzero

fix copy size

上级 788636f0
...@@ -82,7 +82,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -82,7 +82,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto& gpu_dev_ctx = auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx); static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.memory_size(); auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
payload = memory::Alloc(cpu, copy_size); payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload, memory::Copy(cpu, payload,
...@@ -99,7 +99,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -99,7 +99,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else { } else {
payload = tensor.data<void>(); payload = tensor.data<void>();
} }
payload_size = tensor.memory_size(); payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break; } break;
case framework::proto::VarType_Type_SELECTED_ROWS: { case framework::proto::VarType_Type_SELECTED_ROWS: {
...@@ -118,7 +118,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -118,7 +118,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
platform::CPUPlace cpu; platform::CPUPlace cpu;
auto& gpu_dev_ctx = auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx); static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->memory_size(); auto copy_size =
tensor->numel() * framework::SizeOfType(tensor->type());
payload = memory::Alloc(cpu, copy_size); payload = memory::Alloc(cpu, copy_size);
memory::Copy(cpu, payload, memory::Copy(cpu, payload,
boost::get<platform::CUDAPlace>(tensor->place()), boost::get<platform::CUDAPlace>(tensor->place()),
...@@ -133,7 +134,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -133,7 +134,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else { } else {
payload = slr->mutable_value()->data<void>(); payload = slr->mutable_value()->data<void>();
} }
payload_size = tensor->memory_size(); payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
} break; } break;
default: default:
......
...@@ -32,7 +32,8 @@ class SplitByrefOpKernel : public framework::OpKernel<T> { ...@@ -32,7 +32,8 @@ class SplitByrefOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < outs.size(); ++i) { for (size_t i = 0; i < outs.size(); ++i) {
// NOTE: no need to call mutable_data here to allocate memory. // NOTE: no need to call mutable_data here to allocate memory.
auto* out = outs[i]; auto* out = outs[i];
*out = std::move(in->Slice(row_offset, out->dims()[0])); VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0];
*out = std::move(in->Slice(row_offset, row_offset + out->dims()[0]));
row_offset += out->dims()[0]; row_offset += out->dims()[0];
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册