diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index e164f57abc5367d7ae82f26b52bb9546e70bda44..8cb4d1793c1a24da3fd55c36810e75630711efce 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -17,7 +17,6 @@ limitations under the License. */ #include #include #include -#include #include "paddle/framework/ddim.h" #include "paddle/framework/enforce.h" #include "paddle/memory/memory.h" @@ -28,15 +27,15 @@ namespace framework { class Tensor { public: - Tensor() : offset_(0) { numel_ = product(dims_); } + Tensor() : numel_(0), offset_(0) {} Tensor& operator=(const Tensor& src) = delete; template const T* data() const { - CheckDimsValidity(); + CheckDimsValidity(); return reinterpret_cast( - reinterpret_cast(holder_->Ptr()) + offset_); + reinterpret_cast(holder_->ptr()) + offset_); } template @@ -51,35 +50,40 @@ class Tensor { "Tensor::numel_ must be larger than zero to call " "Tensor::mutable_data."); if (holder_ == nullptr || - !(holder_->Place() == + !(holder_->place() == place) /* some versions of boost::variant don't have operator!= */ - || holder_->Size() < numel_ * sizeof(T) + offset_) { + || holder_->size() < numel_ * sizeof(T) + offset_) { holder_.reset(new PlaceholderImpl(place, numel_ * sizeof(T))); offset_ = 0; } - return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + + return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } + template void ShareDataFrom(const Tensor& src) { - src.CheckDimsValidity(); + src.CheckDimsValidity(); holder_ = src.holder_; - dims_ = src.dims(); - numel_ = src.numel_; + set_dims(src.dims()); offset_ = src.offset_; } + template void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) { - src.CheckDimsValidity(); - size_t size = src.numel_ * src.holder_->TypeSize(); - holder_.reset(src.holder_->Clone(src.offset_, size, dst_place)); - dims_ = src.dims(); - numel_ = src.numel_; - offset_ = 0; + PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) && + platform::is_cpu_place(dst_place), + "Tensor::CopyFrom only support CPU now."); + src.CheckDimsValidity(); + size_t size = src.numel_ * sizeof(T); + set_dims(src.dims()); + void* src_ptr = static_cast(src.data()); + void* dst_ptr = static_cast(mutable_data(dst_place)); + memcpy(dst_ptr, src_ptr, size); } + template Tensor Slice(const int& begin_idx, const int& end_idx) const { - CheckDimsValidity(); + CheckDimsValidity(); PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0], "Slice index is less than zero or out of bound."); PADDLE_ENFORCE(begin_idx < end_idx, @@ -95,7 +99,7 @@ class Tensor { DDim dst_dims = dims_; dst_dims[0] = end_idx - begin_idx; dst.set_dims(dst_dims); - dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize(); + dst.offset_ = offset_ + begin_idx * base * sizeof(T); return dst; } @@ -115,12 +119,9 @@ class Tensor { // parameter of Variable. struct Placeholder { virtual ~Placeholder() {} - virtual void* Ptr() const = 0; - virtual paddle::platform::Place Place() const = 0; - virtual size_t Size() const = 0; - virtual size_t TypeSize() const = 0; - virtual Placeholder* Clone(size_t begin, size_t size, - paddle::platform::Place place) const = 0; + virtual void* ptr() const = 0; + virtual paddle::platform::Place place() const = 0; + virtual size_t size() const = 0; }; template @@ -144,32 +145,20 @@ class Tensor { place_(place), size_(size) {} - virtual void* Ptr() const { return static_cast(ptr_.get()); } - virtual size_t Size() const { return size_; } - virtual paddle::platform::Place Place() const { return place_; } - virtual size_t TypeSize() const { return sizeof(T); } - // TODO: Clone only support CPU now. GPU support is needed. - virtual Placeholder* Clone(size_t begin, size_t size, - paddle::platform::Place place) const { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(place_) && - paddle::platform::is_cpu_place(place), - "PlaceholderImpl::Clone only support CPU now."); - PlaceholderImpl* dst = new PlaceholderImpl(place, size); - void* begin_ptr = - reinterpret_cast(reinterpret_cast(Ptr()) + begin); - memcpy(dst->Ptr(), begin_ptr, size); - return dst; - } + virtual void* ptr() const { return static_cast(ptr_.get()); } + virtual size_t size() const { return size_; } + virtual paddle::platform::Place place() const { return place_; } std::unique_ptr ptr_; paddle::platform::Place place_; // record the place of ptr_. size_t size_; // size of the memory block. }; - inline void CheckDimsValidity() { + template + inline void CheckDimsValidity() const { PADDLE_ENFORCE(holder_ != nullptr, "Tenosr holds no memory. Call Tensor::mutable_data first."); - PADDLE_ENFORCE(holder_->Size() > numel_ * sizeof(T) + offset_, + PADDLE_ENFORCE(holder_->size() > numel_ * sizeof(T) + offset_, "Tensor's dims_ is out of bound. Call Tensor::mutable_data " "first to re-allocate memory."); } diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 6db0ba8c798ee6ea503e6305213822408807b1e8..eef9cfcd9e7e9fd16af75b073056386f1d7611d4 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -18,7 +18,8 @@ TEST(Tensor, Dims) { using namespace paddle::framework; using namespace paddle::platform; - Tensor tt(make_ddim({2, 3, 4})); + Tensor tt; + tt.set_dims(make_ddim({2, 3, 4})); DDim dims = tt.dims(); ASSERT_EQ(arity(dims), 3); for (int i = 0; i < 3; ++i) { @@ -35,7 +36,7 @@ TEST(Tensor, DataAssert) { } catch (paddle::framework::EnforceNotMet err) { caught = true; std::string msg = - "Tenosr has not been initialized. Call Tensor::mutable_data first."; + "Tenosr holds no memory. Call Tensor::mutable_data first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); ++i) { ASSERT_EQ(what[i], msg[i]);