diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 29bad7a00a439f393bcc22cbbf716f98681971bf..b405e3877c9f15712dd5bbf08de1e4f7148e0fd9 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -27,7 +27,7 @@ namespace framework {
 
 class Tensor {
  public:
-  Tensor() : numel_(0), offset_(0) {}
+  Tensor() : offset_(0) {}
 
   template <typename T>
   const T* data() const {
@@ -44,30 +44,26 @@ class Tensor {
 
   template <typename T>
   T* mutable_data(platform::Place place) {
-    PADDLE_ENFORCE(numel_ > 0,
-                   "Tensor::numel_ must be larger than zero to call "
+    PADDLE_ENFORCE(product(dims_) > 0,
+                   "Tensor's numel must be larger than zero to call "
                    "Tensor::mutable_data. Call Tensor::set_dim first.");
     if (holder_ == nullptr ||
         !(holder_->place() ==
           place) /* some versions of boost::variant don't have operator!= */
-        || holder_->size() < numel_ * sizeof(T) + offset_) {
+        || holder_->size() < product(dims_) * sizeof(T) + offset_) {
+      if (platform::is_cpu_place(place)) {
+        holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
+            boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
+      } else if (platform::is_gpu_place(place)) {
 #ifdef __CUDACC__
-      switch (place.which()) {
-        case 0:
-          holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
-              boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
-          break;
-
-        case 1:
-          holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
-              boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
-          break;
-      }
+        holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
+            boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
 #else
-      holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
-          boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
+        PADDLE_ENFORCE(false, "'GPUPlace' is not supported in CPU only device.");
 #endif
-
+      } else {
+        PADDLE_ENFORCE(false, "Unknown 'place'.");
+      }
       offset_ = 0;
     }
     return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
@@ -88,7 +84,7 @@ class Tensor {
                    platform::is_cpu_place(dst_place),
                    "Tensor::CopyFrom only support CPU now.");
     src.CheckDims<T>();
-    size_t size = src.numel_ * sizeof(T);
+    size_t size = product(src.dims_) * sizeof(T);
     set_dims(src.dims());
     const void* src_ptr = static_cast<const void*>(src.data<T>());
     void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
@@ -122,7 +118,6 @@ class Tensor {
       return;
     }
     dims_ = dims;
-    numel_ = product(dims_);
   }
 
   DDim dims() const { return dims_; }
@@ -170,16 +165,15 @@ class Tensor {
   inline void CheckDims() const {
     PADDLE_ENFORCE(holder_ != nullptr,
                    "Tenosr holds no memory. Call Tensor::mutable_data first.");
-    PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_,
+    PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
                    "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
                    "first to re-allocate memory.");
   }
 
   std::shared_ptr<Placeholder> holder_;  // holds the memory block if allocated.
   DDim dims_;
-  size_t numel_;   // cache of `product(dims_)`
   size_t offset_;  // marks the begin of tensor data area.
-};  // namespace framework
+};
 }  // namespace framework
 }  // namespace paddle
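
Usage sketch: a minimal, hedged example of how a caller exercises the allocation path rewritten above. The Tensor member functions (set_dims, mutable_data, data) come from this diff; the header paths, the make_ddim helper, and CPUPlace are assumptions drawn from the surrounding paddle/framework and paddle/platform code of the same era, not part of the patch.

#include "paddle/framework/ddim.h"    // assumed location of DDim/make_ddim
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"    // assumed location of CPUPlace

void ExampleCpuTensor() {
  paddle::framework::Tensor t;

  // set_dims must come first: mutable_data now enforces
  // product(dims_) > 0, computed on demand rather than read from the
  // removed numel_ cache.
  t.set_dims(paddle::framework::make_ddim({2, 3}));

  // is_cpu_place(place) routes allocation to the CPUPlace
  // PlaceholderImpl; the holder is re-allocated only when it is absent,
  // lives on a different place, or is smaller than
  // product(dims_) * sizeof(T) + offset_.
  float* buf = t.mutable_data<float>(paddle::platform::CPUPlace());
  for (int i = 0; i < 2 * 3; ++i) {
    buf[i] = static_cast<float>(i);
  }

  // data<T>() returns the same block, shifted by offset_ (reset to 0 by a
  // fresh allocation).
  const float* ro = t.data<float>();
  (void)ro;
}

Design note: dropping the numel_ member removes a cache that could drift out of sync with dims_ (it was refreshed only inside set_dims), at the cost of recomputing product(dims_) on each mutable_data or CheckDims call.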