From a152315be70e311682b5786946bf746896843044 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 19 Nov 2019 11:09:56 +0800 Subject: [PATCH] refine Tensor method, test=develop (#21031) --- paddle/fluid/framework/selected_rows.h | 2 +- paddle/fluid/framework/tensor.cc | 7 ++++--- paddle/fluid/framework/tensor.h | 11 ++++++----- paddle/fluid/framework/tensor_impl.h | 5 +++-- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index e1bdba9b46a..f8a40a5d99a 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -58,7 +58,7 @@ class SelectedRows { rwlock_.reset(new RWLock); } - platform::Place place() const { return value_->place(); } + const platform::Place& place() const { return value_->place(); } const Tensor& value() const { return *value_; } diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index 7b39c5359e8..956049ca6a7 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -34,8 +34,8 @@ size_t Tensor::memory_size() const { return holder_ == nullptr ? 0UL : holder_->size() - offset_; } -void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type, - size_t requested_size) { +void* Tensor::mutable_data(const platform::Place& place, + proto::VarType::Type type, size_t requested_size) { type_ = type; PADDLE_ENFORCE_GE(numel(), 0, "When calling this method, the Tensor's numel must be " @@ -60,7 +60,8 @@ void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type, offset_); } -void* Tensor::mutable_data(platform::Place place, size_t requested_size) { +void* Tensor::mutable_data(const platform::Place& place, + size_t requested_size) { PADDLE_ENFORCE_NOT_NULL( this->holder_, "Cannot invoke mutable data if current hold nothing."); return mutable_data(place, type_, requested_size); diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 8fffecfa0e1..ea28302b35a 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -87,12 +87,12 @@ class Tensor { * @note If not exist, then allocation. */ template - T* mutable_data(platform::Place place, size_t requested_size = 0); + T* mutable_data(const platform::Place& place, size_t requested_size = 0); - void* mutable_data(platform::Place place, proto::VarType::Type type, + void* mutable_data(const platform::Place& place, proto::VarType::Type type, size_t requested_size = 0); - void* mutable_data(platform::Place place, size_t requested_size = 0); + void* mutable_data(const platform::Place& place, size_t requested_size = 0); /** * @brief Return a pointer to mutable memory block. @@ -104,7 +104,8 @@ class Tensor { * @note If not exist, then allocation. */ template - T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0); + T* mutable_data(const DDim& dims, const platform::Place& place, + size_t requested_size = 0); /*! Return the dimensions of the memory block. */ const DDim& dims() const; @@ -128,7 +129,7 @@ class Tensor { */ Tensor Slice(int64_t begin_idx, int64_t end_idx) const; - platform::Place place() const { + const platform::Place& place() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tensor not initialized yet when Tensor::place() is called."); return holder_->place(); diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 34dd3c59fff..f5171b0a8d1 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -51,7 +51,7 @@ inline T* Tensor::data() { } template -inline T* Tensor::mutable_data(DDim dims, platform::Place place, +inline T* Tensor::mutable_data(const DDim& dims, const platform::Place& place, size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); Resize(dims); @@ -59,7 +59,8 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place, } template -inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) { +inline T* Tensor::mutable_data(const platform::Place& place, + size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); return reinterpret_cast( mutable_data(place, DataTypeTrait::DataType(), requested_size)); -- GitLab