From 53f6c6991aa749305bc585d067fa761579fcf995 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 19 Dec 2018 13:32:36 +0000 Subject: [PATCH] polish code test=develop --- paddle/fluid/framework/ddim.cc | 50 +++++++--------------------------- paddle/fluid/framework/ddim.h | 46 ++++++++++++++++++++++--------- paddle/fluid/framework/dim.h | 9 +++--- 3 files changed, 48 insertions(+), 57 deletions(-) diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index 3640138e18..f7fee04c1e 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -18,32 +18,10 @@ limitations under the License. */ namespace paddle { namespace framework { -template -struct DDimAssignFunctor { - static_assert(std::is_integral::value, "T must be integral type"); - using result_type = void; - explicit DDimAssignFunctor(const T* in) : in_(in) {} - - template - inline void operator()(Dim& dim) { // NOLINT - UnrollAssign::Run(in_, dim.data()); - } - - const T* in_; -}; - -DDim::DDim(const int* d, int n) : rank_(n) { - this->apply_visitor(DDimAssignFunctor(d)); -} - -DDim::DDim(const int64_t* d, int n) : rank_(n) { - this->apply_visitor(DDimAssignFunctor(d)); -} - template Dim make_dim(const int64_t* d) { Dim ret; - for (int i = 0; i < N; ++i) ret[i] = d[i]; + fix_dim_assign(d, ret.GetMutable()); return ret; } @@ -64,14 +42,14 @@ struct DDimEqualityVisitor { template inline bool operator()(const Dim& self) const { - return UnrollCompare::Run(self.data(), d_); + return UnrollCompare::Run(self.Get(), d_); } const int64_t* d_; }; bool DDim::operator==(const DDim& d) const { - return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.data())); + return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.Get())); } bool DDim::operator!=(const DDim& d) const { return !(*this == d); } @@ -82,7 +60,7 @@ struct DDimPlusVisitor { template inline void operator()(Dim& self) const { - UnrollAdd::Run(d1_, d2_, self.data()); + UnrollAdd::Run(d1_, d2_, self.GetMutable()); } const int64_t* d1_; @@ -93,7 +71,7 @@ DDim DDim::operator+(const DDim& d) const { PADDLE_ENFORCE(rank_ == d.rank_); DDim ret; ret.rank_ = rank_; - ret.apply_visitor(DDimPlusVisitor(data(), d.data())); + ret.apply_visitor(DDimPlusVisitor(Get(), d.Get())); return ret; } @@ -103,7 +81,7 @@ struct DDimMulVisitor { template inline void operator()(Dim& self) const { - UnrollMul::Run(d1_, d2_, self.data()); + UnrollMul::Run(d1_, d2_, self.GetMutable()); } const int64_t* d1_; @@ -114,7 +92,7 @@ DDim DDim::operator*(const DDim& d) const { PADDLE_ENFORCE(rank_ == d.rank_); DDim ret; ret.rank_ = rank_; - ret.apply_visitor(DDimMulVisitor(data(), d.data())); + ret.apply_visitor(DDimMulVisitor(Get(), d.Get())); return ret; } @@ -124,9 +102,7 @@ void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } // NOLINT std::vector vectorize(const DDim& ddim) { std::vector result(DDim::kMaxRank); - for (int i = 0; i < ddim.size(); ++i) { - result[i] = ddim[i]; - } + dynamic_dim_assign(ddim.Get(), result.data(), ddim.size()); result.resize(ddim.size()); return result; } @@ -135,9 +111,7 @@ std::vector vectorize(const DDim& ddim) { // which does not fit cudnn inputs. std::vector vectorize2int(const DDim& ddim) { std::vector result(DDim::kMaxRank); - for (int i = 0; i < ddim.size(); ++i) { - result[i] = ddim[i]; - } + dynamic_dim_assign(ddim.Get(), result.data(), ddim.size()); result.resize(ddim.size()); return result; } @@ -154,15 +128,11 @@ int64_t product(const DDim& ddim) { } DDim slice_ddim(const DDim& dim, int begin, int end) { - PADDLE_ENFORCE(begin < end, - "Begin index must be less than end index in ddim slice."); PADDLE_ENFORCE(begin >= 0, "Begin index can't be less than zero in ddim slice."); DDim ret; ret.rank_ = end - begin; - for (int i = 0; i < ret.rank_; ++i) { - ret[i] = dim[i + begin]; - } + dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_); return ret; } diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index bff710040e..e65d451cde 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -22,6 +22,29 @@ limitations under the License. */ namespace paddle { namespace framework { +template +inline void dynamic_dim_assign(const T1* in, T2* out, int n) { +#define STATIC_DIM_ASSIGN_CASE(rank) \ + case rank: \ + static_dim_assign(in, out); \ + return + switch (n) { + STATIC_DIM_ASSIGN_CASE(0); + STATIC_DIM_ASSIGN_CASE(1); + STATIC_DIM_ASSIGN_CASE(2); + STATIC_DIM_ASSIGN_CASE(3); + STATIC_DIM_ASSIGN_CASE(4); + STATIC_DIM_ASSIGN_CASE(5); + STATIC_DIM_ASSIGN_CASE(6); + STATIC_DIM_ASSIGN_CASE(7); + STATIC_DIM_ASSIGN_CASE(8); + STATIC_DIM_ASSIGN_CASE(9); + default: + PADDLE_THROW("Invalid rank %d", n); + } +#undef STATIC_DIM_ASSIGN_CASE +} + /** * \brief A dynamically sized dimension. * @@ -33,8 +56,13 @@ class DDim { DDim() : rank_(1) { dim_[0] = 0; } - DDim(const int* d, int n); - DDim(const int64_t* d, int n); + DDim(const int* d, int n) : rank_(n) { + dynamic_dim_assign(d, dim_.GetMutable(), n); + } + + DDim(const int64_t* d, int n) : rank_(n) { + dynamic_dim_assign(d, dim_.GetMutable(), n); + } template /*implicit*/ DDim(const Dim& in) : rank_(D) { // NOLINT @@ -81,19 +109,11 @@ class DDim { DDim operator*(const DDim& d) const; - // Make DDim act like std::vector - using iterator = int64_t*; - using const_iterator = const int64_t*; - - int64_t* data() { return dim_.data(); } - const int64_t* data() const { return dim_.data(); } + inline const int64_t* Get() const { return dim_.Get(); } - iterator begin() { return data(); } - const_iterator begin() const { return data(); } - iterator end() { return data() + rank_; } - const_iterator end() const { return data() + rank_; } + inline int64_t* GetMutable() { return dim_.GetMutable(); } - int size() const { return rank_; } + inline int size() const { return rank_; } private: template diff --git a/paddle/fluid/framework/dim.h b/paddle/fluid/framework/dim.h index 3ae60a3119..21d91167a4 100644 --- a/paddle/fluid/framework/dim.h +++ b/paddle/fluid/framework/dim.h @@ -54,10 +54,6 @@ class Dim : public Array { HOSTDEVICE Dim() = default; - HOSTDEVICE int64_t* data() { return this->GetMutable(); } - - HOSTDEVICE const int64_t* data() const { return this->Get(); } - HOST std::string to_string() const; }; @@ -283,5 +279,10 @@ HOSTDEVICE Dim linear_to_dimension(int linear_index, const Dim& extents) { return result; } +template +inline void static_dim_assign(const T1* in, T2* out) { + UnrollAssign::Run(in, out); +} + } // namespace framework } // namespace paddle -- GitLab