From ad64ca5da20e696d66cfcf9011d16a81e8ef8ff8 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 12 Sep 2017 10:45:11 +0800 Subject: [PATCH] Call Tensor::numel() everywhere. --- paddle/framework/tensor.h | 5 ++++- paddle/framework/tensor_impl.h | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index fc54ed697f6..19051db539d 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -162,7 +162,10 @@ class Tensor { /*! points to dimensions of memory block. */ DDim dims_; - /*! the element count of tensor. */ + /** + * A cache of the number of elements in a tensor. + * Would be 0 for an uninitialized tensor. + */ int64_t numel_; /** diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 03678784b46..5e32bfcac69 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -24,7 +24,7 @@ inline void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tenosr holds no memory. Call Tensor::mutable_data first."); PADDLE_ENFORCE_GE( - holder_->size(), numel_ * sizeof(T) + offset_, + holder_->size(), numel() * sizeof(T) + offset_, "Tensor's dims_ is out of bound. Call Tensor::mutable_data " "first to re-allocate memory.\n" "or maybe the required data-type mismatches the data already stored."); @@ -54,11 +54,11 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) { template inline T* Tensor::mutable_data(platform::Place place) { static_assert(std::is_pod::value, "T must be POD"); - PADDLE_ENFORCE_GT(numel_, 0, + PADDLE_ENFORCE_GT(numel(), 0, "Tensor's numel must be larger than zero to call " "Tensor::mutable_data. Call Tensor::set_dim first."); /* some versions of boost::variant don't have operator!= */ - int64_t size = numel_ * sizeof(T); + int64_t size = numel() * sizeof(T); if (holder_ == nullptr || !(holder_->place() == place) || holder_->size() < size + offset_) { if (platform::is_cpu_place(place)) { @@ -131,7 +131,7 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { PADDLE_ENFORCE_LT(begin_idx, end_idx, "Begin index must be less than end index."); PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1."); - size_t base = numel_ / dims_[0]; + size_t base = numel() / dims_[0]; Tensor dst; dst.holder_ = holder_; DDim dst_dims = dims_; -- GitLab