Add DDim::size() and drop Tensor's cached element count.

Summary of the change carried by this patch:
  * ddim.h / ddim.cc: new member DDim::size(), which returns arity(*this).
  * ddim_test.cc: asserts that size() agrees with arity() (both 3).
  * tensor.h: removes the numel_ member; every former use now evaluates
    product(dims_) directly, so set_dims() no longer has a cache to refresh.
  * tensor.h: mutable_data() picks the allocation with is_cpu_place() /
    is_gpu_place() instead of switching on place.which(). The GPU allocation
    is only compiled under __CUDACC__.
  * tensor.h: removes the stray "// namespace framework" comment that was
    attached to the closing brace of class Tensor.

Review fix folded into this revision:
  The two failure branches of mutable_data() called PADDLE_ENFORCE(true, ...).
  Every other PADDLE_ENFORCE in this patch passes the condition that must
  hold, so a literal `true` can never fire: an unsupported or unknown place
  would skip allocation and fall through to holder_->ptr(). Both branches now
  pass `false` so they always report the error.

NOTE(review): this patch was recovered from text collapsed onto a few long
lines in which every template argument list (`<...>`) was missing. Line
breaks, indentation and the restored template arguments are reconstructed
from context and from the hunk line counts; confirm them against the
repository before applying. The blob ids on the `index` lines are those of
the original change and do not reflect the review fix.

diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 73f5499ad15752237a73ca27e0cd0fe2c5e86b4e..b6ad8b60aaf7b5657e1f18defe3c8d2c2ebbc060 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -117,6 +117,8 @@ int DDim::operator[](int idx) const {
   return boost::apply_visitor(DynamicConstIndexer(idx), var);
 }
 
+ssize_t DDim::size() const { return arity(*this); }
+
 bool DDim::operator==(DDim d) const {
   if (var.which() != d.getVar().which()) {
     return false;
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index a0c2a8a74afdefd4a504ec6fa730238e077efbb5..7bc21a1e3455bcf8889c18084e8c96e7a3c76a13 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -50,6 +50,8 @@ struct DDim {
 
   DDimVar getVar() { return var; }
 
+  ssize_t size() const;
+
   bool operator==(DDim d) const;
 
   bool operator!=(DDim d) const;
diff --git a/paddle/framework/ddim_test.cc b/paddle/framework/ddim_test.cc
index 6a099f2aeb4aa117bca8695aa326fbd1272a43d6..9d18a2972ce62139430b240b4599854b14290a32 100644
--- a/paddle/framework/ddim_test.cc
+++ b/paddle/framework/ddim_test.cc
@@ -49,6 +49,7 @@ TEST(DDim, Equality) {
 
   // arity of a DDim
   EXPECT_EQ(paddle::framework::arity(ddim), 3);
+  EXPECT_EQ(ddim.size(), 3);
 
   // product of a DDim
   EXPECT_EQ(paddle::framework::product(vddim), 45);
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 29bad7a00a439f393bcc22cbbf716f98681971bf..b405e3877c9f15712dd5bbf08de1e4f7148e0fd9 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -27,7 +27,7 @@ namespace framework {
 
 class Tensor {
  public:
-  Tensor() : numel_(0), offset_(0) {}
+  Tensor() : offset_(0) {}
 
   template <typename T>
   const T* data() const {
@@ -44,30 +44,26 @@ class Tensor {
 
   template <typename T>
   T* mutable_data(platform::Place place) {
-    PADDLE_ENFORCE(numel_ > 0,
-                   "Tensor::numel_ must be larger than zero to call "
+    PADDLE_ENFORCE(product(dims_) > 0,
+                   "Tensor's numel must be larger than zero to call "
                    "Tensor::mutable_data. Call Tensor::set_dim first.");
     if (holder_ == nullptr ||
         !(holder_->place() ==
           place) /* some versions of boost::variant don't have operator!= */
-        || holder_->size() < numel_ * sizeof(T) + offset_) {
+        || holder_->size() < product(dims_) * sizeof(T) + offset_) {
+      if (platform::is_cpu_place(place)) {
+        holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
+            boost::get<platform::CPUPlace>(place), product(dims_) * sizeof(T)));
+      } else if (platform::is_gpu_place(place)) {
 #ifdef __CUDACC__
-      switch (place.which()) {
-        case 0:
-          holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
-              boost::get<platform::GPUPlace>(place), numel_ * sizeof(T)));
-          break;
-
-        case 1:
-          holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
-              boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
-          break;
-      }
+        holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
+            boost::get<platform::GPUPlace>(place), product(dims_) * sizeof(T)));
 #else
-      holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
-          boost::get<platform::CPUPlace>(place), numel_ * sizeof(T)));
+        PADDLE_ENFORCE(false, "'GPUPlace' is not supported in CPU only device.");
 #endif
-
+      } else {
+        PADDLE_ENFORCE(false, "Unknown 'place'.");
+      }
       offset_ = 0;
     }
     return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
@@ -88,7 +84,7 @@ class Tensor {
                    platform::is_cpu_place(dst_place),
                    "Tensor::CopyFrom only support CPU now.");
     src.CheckDims<T>();
-    size_t size = src.numel_ * sizeof(T);
+    size_t size = product(src.dims_) * sizeof(T);
     set_dims(src.dims());
     const void* src_ptr = static_cast<const void*>(src.data<T>());
     void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
@@ -122,7 +118,6 @@ class Tensor {
       return;
     }
     dims_ = dims;
-    numel_ = product(dims_);
   }
 
   DDim dims() const { return dims_; }
@@ -170,16 +165,15 @@ class Tensor {
   inline void CheckDims() const {
     PADDLE_ENFORCE(holder_ != nullptr,
                    "Tenosr holds no memory. Call Tensor::mutable_data first.");
-    PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_,
+    PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
                    "Tensor's dims_ is out of bound. Call Tensor::mutable_data "
                    "first to re-allocate memory.");
   }
 
   std::shared_ptr<Placeholder> holder_;  // holds the memory block if allocated.
   DDim dims_;
-  size_t numel_;   // cache of `product(dims_)`
   size_t offset_;  // marks the begin of tensor data area.
-};  // namespace framework
+};
 
 }  // namespace framework
 }  // namespace paddle