From 55d301722fac0454e7769e4b16d77aa9ab907042 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 19 Jul 2017 16:41:11 +0800 Subject: [PATCH] Simplify Tensor implimentation ATTENTION: some interfaces changed: 1. void Tensor::set_dims(const DDim& dims) ==> void Tensor::Resize(const DDim& dims). 2. void Tensor::ShareDataFrom(const Tensor& src) ==> void Tensor::ShareDataWith(const Tensor& src) 3. DDim Tensor::dims() const ==> const DDim& Tensor::dims() const --- paddle/framework/tensor.h | 65 +++++++++++------------------- paddle/framework/tensor_test.cc | 10 ++--- paddle/memory/memory.h | 10 +++++ paddle/operators/add_op.cc | 2 +- paddle/operators/mul_op.cc | 2 +- paddle/operators/rowwise_add_op.cc | 2 +- paddle/operators/sigmoid_op.cc | 2 +- paddle/operators/softmax_op.cc | 2 +- paddle/pybind/pybind.cc | 2 +- paddle/pybind/tensor_bind.h | 2 +- 10 files changed, 45 insertions(+), 54 deletions(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 1dd421cdb..a0f0bb1ff 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -40,21 +40,21 @@ class Tensor { template const T* data() const { - CheckDims(); + EnforceSufficientMemory(); return reinterpret_cast( reinterpret_cast(holder_->ptr()) + offset_); } template T* raw_data() const { - CheckDims(); + EnforceSufficientMemory(); return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } template T* mutable_data(DDim dims, platform::Place place) { - set_dims(dims); + Resize(dims); return mutable_data(place); } @@ -147,11 +147,9 @@ class Tensor { } template - void ShareDataFrom(const Tensor& src) { - src.CheckDims(); - holder_ = src.holder_; - set_dims(src.dims()); - offset_ = src.offset_; + void ShareDataWith(const Tensor& src) { + src.EnforceSufficientMemory(); + *this = src; } template @@ -159,9 +157,9 @@ class Tensor { PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) && platform::is_cpu_place(dst_place), "Tensor::CopyFrom only support CPU now."); - src.CheckDims(); + src.EnforceSufficientMemory(); size_t size = product(src.dims_) * sizeof(T); - set_dims(src.dims()); + Resize(src.dims()); const void* src_ptr = static_cast(src.data()); void* dst_ptr = static_cast(mutable_data(dst_place)); memcpy(dst_ptr, src_ptr, size); @@ -169,34 +167,25 @@ class Tensor { template Tensor Slice(const int& begin_idx, const int& end_idx) const { - CheckDims(); - PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0], - "Slice index is less than zero or out of bound."); + EnforceSufficientMemory(); + PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero."); + PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound."); PADDLE_ENFORCE(begin_idx < end_idx, "Begin index must be less than end index."); PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1."); - std::vector d = vectorize(dims_); - int base = 1; - for (size_t i = 1; i < d.size(); ++i) { - base *= d[i]; - } + int base = product(dims_) / dims_[0]; Tensor dst; dst.holder_ = holder_; DDim dst_dims = dims_; dst_dims[0] = end_idx - begin_idx; - dst.set_dims(dst_dims); + dst.Resize(dst_dims); dst.offset_ = offset_ + begin_idx * base * sizeof(T); return dst; } - void set_dims(const DDim& dims) { - if (dims == dims_) { - return; - } - dims_ = dims; - } + void Resize(const DDim& dims) { dims_ = dims; } - DDim dims() const { return dims_; } + const DDim& dims() const { return dims_; } private: // Placeholder hides type T, so it doesn't appear as a template @@ -211,21 +200,9 @@ class Tensor { template struct PlaceholderImpl : public Placeholder { - private: - template - class Deleter { - public: - Deleter(PType place) : place_(place) {} - void operator()(T* ptr) { memory::Free(place_, static_cast(ptr)); } - - private: - PType place_; - }; - - public: PlaceholderImpl(PlaceType place, size_t size) : ptr_(static_cast(memory::Alloc(place, size)), - Deleter(place)), + memory::PodDeleter(place)), place_(place), size_(size) {} @@ -234,13 +211,13 @@ class Tensor { virtual paddle::platform::Place place() const { return place_; } virtual std::type_index type() const { return std::type_index(typeid(T)); } - std::unique_ptr> ptr_; + std::unique_ptr> ptr_; platform::Place place_; // record the place of ptr_. size_t size_; // size of the memory block. }; template - inline void CheckDims() const { + inline void EnforceSufficientMemory() const { PADDLE_ENFORCE(holder_ != nullptr, "Tenosr holds no memory. Call Tensor::mutable_data first."); PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_, @@ -250,7 +227,11 @@ class Tensor { std::shared_ptr holder_; // holds the memory block if allocated. DDim dims_; - size_t offset_; // marks the begin of tensor data area. + // A PlaceHolder may be shared by more than one tensor. Some of them may be + // slices of the others. So the offset_ is introduced here to indicate the + // byte offset between PlaceHolder::ptr_ and where tensor's data really + // begins. + size_t offset_; template friend struct paddle::pybind::details::CastToPyBufferImpl; }; diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 84c6f0cf6..a78bdd41b 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -19,7 +19,7 @@ TEST(Tensor, Dims) { using namespace paddle::framework; using namespace paddle::platform; Tensor tt; - tt.set_dims(make_ddim({2, 3, 4})); + tt.Resize(make_ddim({2, 3, 4})); DDim dims = tt.dims(); ASSERT_EQ(arity(dims), 3); for (int i = 0; i < 3; ++i) { @@ -97,7 +97,7 @@ TEST(Tensor, MutableData) { #endif } -TEST(Tensor, ShareDataFrom) { +TEST(Tensor, ShareDataWith) { using namespace paddle::framework; using namespace paddle::platform; { @@ -106,7 +106,7 @@ TEST(Tensor, ShareDataFrom) { // Try to share data form uninitialized tensor bool caught = false; try { - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataWith(src_tensor); } catch (EnforceNotMet err) { caught = true; std::string msg = @@ -119,7 +119,7 @@ TEST(Tensor, ShareDataFrom) { ASSERT_TRUE(caught); src_tensor.mutable_data(make_ddim({2, 3, 4}), CPUPlace()); - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataWith(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } @@ -128,7 +128,7 @@ TEST(Tensor, ShareDataFrom) { Tensor src_tensor; Tensor dst_tensor; src_tensor.mutable_data(make_ddim({2, 3, 4}), GPUPlace()); - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataWith(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } #endif diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 2d6f4fd2a..f5890fb84 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -28,5 +28,15 @@ void Free(Place, void*); template size_t Used(Place); +template +class PodDeleter { + public: + PodDeleter(PlaceType place) : place_(place) {} + void operator()(T* ptr) { Free(place_, static_cast(ptr)); } + + private: + PlaceType place_; +}; + } // namespace memory } // namespace paddle diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 41d044cdb..858a65089 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -31,7 +31,7 @@ protected: "Inputs/Outputs of AddOp must all be set"); PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), "Two input of Add Op's dimension must be same."); - outputs[0]->set_dims(inputs[0]->dims()); + outputs[0]->Resize(inputs[0]->dims()); } }; diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 713b2a5dc..e7bda6a7d 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -33,7 +33,7 @@ protected: dim0[1] == dim1[0], "First matrix's width must be equal with second matrix's height."); PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output"); - outputs[0]->set_dims({dim0[0], dim1[1]}); + outputs[0]->Resize({dim0[0], dim1[1]}); } }; diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 414bafd04..97d42c193 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -30,7 +30,7 @@ protected: PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1"); - outputs[0]->set_dims(inputs[0]->dims()); + outputs[0]->Resize(inputs[0]->dims()); } }; diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index 45ae277c5..e87fb78d3 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -24,7 +24,7 @@ protected: const std::vector &outputs) const override { PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); - outputs[0]->set_dims(inputs[0]->dims()); + outputs[0]->Resize(inputs[0]->dims()); } }; diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 4ca7be359..07302d67d 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -25,7 +25,7 @@ protected: PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax"); - outputs[0]->set_dims(inputs[0]->dims()); + outputs[0]->Resize(inputs[0]->dims()); } }; diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fc9c6544c..56d6fe4dd 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -42,7 +42,7 @@ PYBIND11_PLUGIN(core) { [](const pd::Tensor& self) { return pd::vectorize(self.dims()); }) .def("set_dims", [](pd::Tensor& self, const std::vector& dim) { - self.set_dims(pd::make_ddim(dim)); + self.Resize(pd::make_ddim(dim)); }) .def("alloc_float", [](pd::Tensor& self) { diff --git a/paddle/pybind/tensor_bind.h b/paddle/pybind/tensor_bind.h index b96516643..995e102bf 100644 --- a/paddle/pybind/tensor_bind.h +++ b/paddle/pybind/tensor_bind.h @@ -86,7 +86,7 @@ void PyTensorSetFromArray( dims.push_back((int)array.shape()[i]); } - self.set_dims(framework::make_ddim(dims)); + self.Resize(framework::make_ddim(dims)); auto *dst = self.mutable_data(paddle::platform::CPUPlace()); std::memcpy(dst, array.data(), sizeof(T) * array.size()); } -- GitLab