From 184835856c94043a5c27f5da3921cdaba433273c Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 18 Apr 2018 14:44:17 +0800 Subject: [PATCH] fix copy size --- paddle/fluid/operators/detail/sendrecvop_utils.cc | 9 +++++---- paddle/fluid/operators/split_byref_op.h | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 16c612c45a3..69fcffe9bc3 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -82,7 +82,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, platform::CPUPlace cpu; auto& gpu_dev_ctx = static_cast(ctx); - auto copy_size = tensor.memory_size(); + auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); payload = memory::Alloc(cpu, copy_size); memory::Copy(cpu, payload, @@ -99,7 +99,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, } else { payload = tensor.data(); } - payload_size = tensor.memory_size(); + payload_size = tensor.numel() * framework::SizeOfType(tensor.type()); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); } break; case framework::proto::VarType_Type_SELECTED_ROWS: { @@ -118,7 +118,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, platform::CPUPlace cpu; auto& gpu_dev_ctx = static_cast(ctx); - auto copy_size = tensor->memory_size(); + auto copy_size = + tensor->numel() * framework::SizeOfType(tensor->type()); payload = memory::Alloc(cpu, copy_size); memory::Copy(cpu, payload, boost::get(tensor->place()), @@ -133,7 +134,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, } else { payload = slr->mutable_value()->data(); } - payload_size = tensor->memory_size(); + payload_size = tensor->numel() * framework::SizeOfType(tensor->type()); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); } break; default: diff --git a/paddle/fluid/operators/split_byref_op.h b/paddle/fluid/operators/split_byref_op.h index 9b54c7c74ac..a3aad68ea73 100644 --- a/paddle/fluid/operators/split_byref_op.h +++ b/paddle/fluid/operators/split_byref_op.h @@ -32,7 +32,8 @@ class SplitByrefOpKernel : public framework::OpKernel { for (size_t i = 0; i < outs.size(); ++i) { // NOTE: no need to call mutable_data here to allocate memory. auto* out = outs[i]; - *out = std::move(in->Slice(row_offset, out->dims()[0])); + VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0]; + *out = std::move(in->Slice(row_offset, row_offset + out->dims()[0])); row_offset += out->dims()[0]; } } -- GitLab