diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index a658537430e27a82f43d3e8ff5d3a1abc6c48369..8962b76a12cdb1b6396eca704ca6336eca412b49 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -20,23 +20,19 @@ class Tensor { using paddle::platform::get_place; public: - explicit Tensor(DDim dims) : dims_(dims), place_(get_place()) {} - explicit Tensor(DDim dims, Place place) : dims_(dims), place_(place) {} - template const T* data() const { - PADDLE_ASSERT(holder_ != nullptr); - PADDLE_ASSERT(holder_->Place() == place_); - PADDLE_ASSERT(holder_->Size() >= dims_.product() * sizeof(T)); + PADDLE_ASSERT(holder_ != nullptr, + "Tensor::data must be called after Tensor::mutable_data"); return static_cast(holder->Ptr()); } template ::value>::type> - T* mutable_data() { - if (holder_ == nullptr || holder_->Place() != place_ || - holder_->Size() < dims_.product() * sizeof(T)) { - holder_.reset(new PlaceholderImpl(place_, dims.product() * sizeof(T))); + T* mutable_data(DDim dims, Place place) { + if (holder_ == nullptr || holder_->Place() != place || + holder_->Size() < dims.product() * sizeof(T)) { + holder_.reset(new PlaceholderImpl(place, dims.product() * sizeof(T))); } return static_cast(holder_->Ptr()); } @@ -44,16 +40,7 @@ class Tensor { template ::value>::type> T* mutable_data(DDim dims) { - dims_ = dims; - return mutable_data(); - } - - template ::value>::type> - T* mutable_data(DDim dims, Place place) { - dims_ = dims; - place_ = place; - return mutable_data(); + return mutable_data(dims, paddle::platform::get_place()); } private: @@ -69,7 +56,7 @@ class Tensor { template struct PlaceholderImpl : public Placeholder { PlaceholderImpl(Place pl, size_t size) - : ptr_(memory::Alloc(pl, size), paddle::memory::Deleter(pl)), + : ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)), place_(pl), size_(size) {} @@ -83,8 +70,6 @@ class Tensor { }; std::unique_ptr holder_; // holds the memory block if allocated. - DDim dims_; // could be smallers than the holder_->Size(). - paddle::platform::Place place_; }; } // namespace framework