From 47ad8d490984ebec9a4124d90d443576e6cf5fa8 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Thu, 19 Jul 2018 20:19:13 +0800 Subject: [PATCH] Fix deserialize bug --- paddle/fluid/framework/tensor_util.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index f98011e896..ab693004cf 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -15,6 +15,7 @@ #include #include #include +#include "paddle/fluid/framework/data_type.h" namespace paddle { namespace framework { @@ -261,7 +262,8 @@ void TensorToStream(std::ostream& os, const Tensor& tensor, os.write(out.data(), size); } { // the 3rd field, tensor data - uint64_t size = tensor.memory_size(); + uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type()); + auto* data_ptr = tensor.data(); PADDLE_ENFORCE(size < std::numeric_limits::max(), "Index overflow when writing tensor"); @@ -331,6 +333,9 @@ void TensorFromStream(std::istream& is, Tensor* tensor, tensor->Resize(framework::make_ddim(dims)); void* buf; auto ctx = platform::CPUDeviceContext(); + size_t size = + tensor->numel() * + framework::SizeOfType(framework::ToTypeIndex(desc.data_type())); if (platform::is_gpu_place(dev_ctx.GetPlace())) { #ifdef PADDLE_WITH_CUDA Tensor cpu_tensor; @@ -338,7 +343,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor, framework::VisitDataType( desc.data_type(), DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace())); - is.read(static_cast(buf), cpu_tensor.memory_size()); + is.read(static_cast(buf), size); auto dst_place = dev_ctx.GetPlace(); framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor); #else @@ -348,7 +353,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor, framework::VisitDataType( desc.data_type(), DeserializedDataFunctor(&buf, tensor, ctx.GetPlace())); - is.read(static_cast(buf), tensor->memory_size()); + is.read(static_cast(buf), size); } } } -- GitLab