diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index b57958591fb752132407c35958db0781d0e023f0..cd1b4de426a49fa66dbbf8cf7d09990ac8d21227 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -79,11 +79,11 @@ class Tensor { inline const DDim& dims() const; /*! Resize the dimensions of the memory block. */ - inline void Resize(const DDim& dims); + inline Tensor& Resize(const DDim& dims); /*! The internal of two tensors share the same memory block. */ template - inline void ShareDataWith(const Tensor& src); + inline Tensor& ShareDataWith(const Tensor& src); /** * @brief Copy the content of external tensor to a new place. diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 8d9bec6dc9c3f0af822a0d8cd8588dc932970652..7d7263b899afb7a2128548f264065a8013b6f0c9 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -23,9 +23,11 @@ template inline void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tenosr holds no memory. Call Tensor::mutable_data first."); - PADDLE_ENFORCE_GE(holder_->size(), product(dims_) * sizeof(T) + offset_, - "Tensor's dims_ is out of bound. Call Tensor::mutable_data " - "first to re-allocate memory."); + PADDLE_ENFORCE_GE( + holder_->size(), product(dims_) * sizeof(T) + offset_, + "Tensor's dims_ is out of bound. Call Tensor::mutable_data " + "first to re-allocate memory.\n" + "or maybe the required data-type mismatches the data already stored."); } template @@ -78,9 +80,10 @@ inline T* Tensor::mutable_data(platform::Place place) { } template -inline void Tensor::ShareDataWith(const Tensor& src) { +inline Tensor& Tensor::ShareDataWith(const Tensor& src) { src.check_memory_size(); *this = src; + return *this; } template @@ -136,7 +139,10 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { return dst; } -inline void Tensor::Resize(const DDim& dims) { dims_ = dims; } +inline Tensor& Tensor::Resize(const DDim& dims) { + dims_ = dims; + return *this; +} inline const DDim& Tensor::dims() const { return dims_; }