Unverified commit 47ad8d49, authored by yuyang18

Fix deserialize bug

Parent 97b774df
@@ -15,6 +15,7 @@
 #include <algorithm>
 #include <limits>
 #include <vector>
+#include "paddle/fluid/framework/data_type.h"

 namespace paddle {
 namespace framework {
@@ -261,7 +262,8 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
     os.write(out.data(), size);
   }
   {  // the 3rd field, tensor data
-    uint64_t size = tensor.memory_size();
+    uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type());
+
     auto* data_ptr = tensor.data<void>();
     PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
                    "Index overflow when writing tensor");
@@ -331,6 +333,9 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
     tensor->Resize(framework::make_ddim(dims));
     void* buf;
     auto ctx = platform::CPUDeviceContext();
+    size_t size =
+        tensor->numel() *
+        framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
     if (platform::is_gpu_place(dev_ctx.GetPlace())) {
 #ifdef PADDLE_WITH_CUDA
       Tensor cpu_tensor;
@@ -338,7 +343,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
       framework::VisitDataType(
           desc.data_type(),
           DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
-      is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
+      is.read(static_cast<char*>(buf), size);
       auto dst_place = dev_ctx.GetPlace();
       framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
 #else
@@ -348,7 +353,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
       framework::VisitDataType(
           desc.data_type(),
           DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
-      is.read(static_cast<char*>(buf), tensor->memory_size());
+      is.read(static_cast<char*>(buf), size);
     }
   }
 }
...
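For readers skimming the diff: both TensorToStream and TensorFromStream previously sized the tensor payload with memory_size(), i.e. the size of the underlying allocation, while the fix switches to numel() * SizeOfType(...), i.e. the bytes actually described by the current shape and data type. The sketch below is a minimal, self-contained illustration of how those two byte counts can diverge when the backing allocation is larger than the current shape needs, which is presumably the deserialize mismatch this commit addresses; ToyTensor is a hypothetical stand-in, not the Paddle Tensor class.

```cpp
// Minimal sketch, not the Paddle API: a stand-in tensor whose backing
// allocation ("holder") can be larger than the bytes implied by its current
// shape, so an analogue of memory_size() and an analogue of
// numel() * SizeOfType(type()) report different sizes.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <vector>

struct ToyTensor {              // hypothetical stand-in for framework::Tensor
  std::vector<float> holder;    // backing allocation, may be oversized
  std::int64_t numel = 0;       // elements described by the current shape

  // Analogue of Tensor::memory_size(): bytes owned by the allocation.
  std::size_t memory_size() const { return holder.size() * sizeof(float); }

  // Analogue of numel() * SizeOfType(type()): bytes of live data only.
  std::size_t data_bytes() const {
    return static_cast<std::size_t>(numel) * sizeof(float);
  }
};

int main() {
  ToyTensor t;
  t.holder.assign(100, 1.0f);  // allocation holds 100 floats (400 bytes)
  t.numel = 60;                // but the current shape covers only 60 of them

  // Writing memory_size() bytes would serialize 160 bytes beyond the real
  // payload, while a reader that allocates from the shape in the proto
  // expects only data_bytes(). Computing the byte count from numel and the
  // element type keeps writer and reader in agreement.
  std::ostringstream os;
  os.write(reinterpret_cast<const char*>(t.holder.data()),
           static_cast<std::streamsize>(t.data_bytes()));
  assert(os.str().size() == t.data_bytes());
  assert(t.memory_size() != t.data_bytes());  // the two sizes really differ
  return 0;
}
```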