From ca39515e24be50931000be632134dce2e4a23d3f Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 11 Jul 2017 18:52:09 +0800 Subject: [PATCH] Add several interfaces for Tensor class 1. Add member variable 'DDim dims_' and a getter function 'dims()'. 'dims' is supposed to hold tensor's shape during Op::InferShape. 2. Remove 'mutable_data' which use default Place. User must specify a explicit Place when call 'mutable_data'. 3. A PlaceHolder may be shared by more than one tensor, and some of them may be the others' slices. So we add a new member variable 'offset_' for Tensor, which is used to show the byte offset between PlaceHolder::ptr_ and where tensor's data really begins. 4. Add functions 'ShareDataFrom' and 'Slice' for Tensor. TODO: Tensor needs a 'CopyFrom' function. --- paddle/framework/tensor.h | 57 ++++++++++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index ce5d98b04e..d40edb190c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include "paddle/framework/ddim.h" @@ -26,31 +27,65 @@ namespace framework { class Tensor { public: + Tensor() : offset_(0) {} + + Tensor(const DDim& dims) : dims_(dims), offset_(0) {} + template const T* data() const { - PADDLE_ENFORCE(holder_ != nullptr, - "Tensor::data must be called after Tensor::mutable_data."); - return static_cast(holder_->Ptr()); + PADDLE_ENFORCE( + holder_ != nullptr, + "Tenosr has not been initialized. Call Tensor::mutable_data first."); + return reinterpret_cast( + reinterpret_cast(holder_->Ptr()) + offset_); } template ::value>::type* = nullptr> T* mutable_data(DDim dims, paddle::platform::Place place) { + dims_ = dims; if (holder_ == nullptr || !(holder_->Place() == place) /* some versions of boost::variant don't have operator!= */ - || holder_->Size() < product(dims) * sizeof(T)) { + || holder_->Size() < product(dims) * sizeof(T) + offset_) { holder_.reset(new PlaceholderImpl(place, product(dims) * sizeof(T))); + offset_ = 0; } - return static_cast(holder_->Ptr()); + return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + + offset_); } - template ::value>::type* = nullptr> - T* mutable_data(DDim dims) { - return mutable_data(dims, paddle::platform::get_place()); + void ShareDataFrom(const Tensor& src) { + PADDLE_ENFORCE(src.holder_ != nullptr, + "Tenosr 'src' has not been initialized."); + holder_ = src.holder_; + dims_ = src.dims_; + offset_ = src.offset_; } + Tensor Slice(const int& begin_idx, const int& end_idx) { + PADDLE_ENFORCE(holder_ != nullptr, + "The sliced tenosr has not been initialized."); + PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0], + "Slice index is less than zero or out of bound."); + PADDLE_ENFORCE(begin_idx < end_idx, + "Begin index must be less than end index."); + PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1."); + std::vector d = vectorize(dims_); + int base = 1; + for (size_t i = 1; i < d.size(); ++i) { + base *= d[i]; + } + Tensor dst; + dst.holder_ = holder_; + dst.dims_ = dims_; + dst.dims_[0] = end_idx - begin_idx; + dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize(); + return dst; + } + + DDim dims() const { return dims_; } + private: // Placeholder hides type T, so it doesn't appear as a template // parameter of Variable. @@ -59,6 +94,7 @@ class Tensor { virtual void* Ptr() const = 0; virtual paddle::platform::Place Place() const = 0; virtual size_t Size() const = 0; + virtual size_t TypeSize() const = 0; }; template @@ -85,6 +121,7 @@ class Tensor { virtual void* Ptr() const { return static_cast(ptr_.get()); } virtual size_t Size() const { return size_; } virtual paddle::platform::Place Place() const { return place_; } + virtual size_t TypeSize() const { return sizeof(T); } std::unique_ptr ptr_; paddle::platform::Place place_; // record the place of ptr_. @@ -92,6 +129,8 @@ class Tensor { }; std::shared_ptr holder_; // holds the memory block if allocated. + DDim dims_; + size_t offset_; // marks the begin of tensor data area. }; } // namespace framework -- GitLab