未验证 提交 a152315b 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine Tensor method, test=develop (#21031)

上级 b5d8ba83
...@@ -58,7 +58,7 @@ class SelectedRows { ...@@ -58,7 +58,7 @@ class SelectedRows {
rwlock_.reset(new RWLock); rwlock_.reset(new RWLock);
} }
platform::Place place() const { return value_->place(); } const platform::Place& place() const { return value_->place(); }
const Tensor& value() const { return *value_; } const Tensor& value() const { return *value_; }
......
...@@ -34,8 +34,8 @@ size_t Tensor::memory_size() const { ...@@ -34,8 +34,8 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_; return holder_ == nullptr ? 0UL : holder_->size() - offset_;
} }
void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type, void* Tensor::mutable_data(const platform::Place& place,
size_t requested_size) { proto::VarType::Type type, size_t requested_size) {
type_ = type; type_ = type;
PADDLE_ENFORCE_GE(numel(), 0, PADDLE_ENFORCE_GE(numel(), 0,
"When calling this method, the Tensor's numel must be " "When calling this method, the Tensor's numel must be "
...@@ -60,7 +60,8 @@ void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type, ...@@ -60,7 +60,8 @@ void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
offset_); offset_);
} }
void* Tensor::mutable_data(platform::Place place, size_t requested_size) { void* Tensor::mutable_data(const platform::Place& place,
size_t requested_size) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
this->holder_, "Cannot invoke mutable data if current hold nothing."); this->holder_, "Cannot invoke mutable data if current hold nothing.");
return mutable_data(place, type_, requested_size); return mutable_data(place, type_, requested_size);
......
...@@ -87,12 +87,12 @@ class Tensor { ...@@ -87,12 +87,12 @@ class Tensor {
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
T* mutable_data(platform::Place place, size_t requested_size = 0); T* mutable_data(const platform::Place& place, size_t requested_size = 0);
void* mutable_data(platform::Place place, proto::VarType::Type type, void* mutable_data(const platform::Place& place, proto::VarType::Type type,
size_t requested_size = 0); size_t requested_size = 0);
void* mutable_data(platform::Place place, size_t requested_size = 0); void* mutable_data(const platform::Place& place, size_t requested_size = 0);
/** /**
* @brief Return a pointer to mutable memory block. * @brief Return a pointer to mutable memory block.
...@@ -104,7 +104,8 @@ class Tensor { ...@@ -104,7 +104,8 @@ class Tensor {
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0); T* mutable_data(const DDim& dims, const platform::Place& place,
size_t requested_size = 0);
/*! Return the dimensions of the memory block. */ /*! Return the dimensions of the memory block. */
const DDim& dims() const; const DDim& dims() const;
...@@ -128,7 +129,7 @@ class Tensor { ...@@ -128,7 +129,7 @@ class Tensor {
*/ */
Tensor Slice(int64_t begin_idx, int64_t end_idx) const; Tensor Slice(int64_t begin_idx, int64_t end_idx) const;
platform::Place place() const { const platform::Place& place() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::place() is called."); holder_, "Tensor not initialized yet when Tensor::place() is called.");
return holder_->place(); return holder_->place();
......
...@@ -51,7 +51,7 @@ inline T* Tensor::data() { ...@@ -51,7 +51,7 @@ inline T* Tensor::data() {
} }
template <typename T> template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place, inline T* Tensor::mutable_data(const DDim& dims, const platform::Place& place,
size_t requested_size) { size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims); Resize(dims);
...@@ -59,7 +59,8 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place, ...@@ -59,7 +59,8 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place,
} }
template <typename T> template <typename T>
inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) { inline T* Tensor::mutable_data(const platform::Place& place,
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>( return reinterpret_cast<T*>(
mutable_data(place, DataTypeTrait<T>::DataType(), requested_size)); mutable_data(place, DataTypeTrait<T>::DataType(), requested_size));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册