diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index 95078093e5905a47cd363815b62a1ed0db089bf2..37544e97eb655287860f8b113e0e1f1ed298474b 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -42,7 +42,8 @@ struct DDimEqualityVisitor { }; bool DDim::operator==(const DDim& d) const { - return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.Get())); + return size() == d.size() && + this->apply_visitor(DDimEqualityVisitor(d.Get())); } bool DDim::operator!=(const DDim& d) const { return !(*this == d); } @@ -61,7 +62,7 @@ struct DDimPlusVisitor { }; DDim DDim::operator+(const DDim& d) const { - PADDLE_ENFORCE(rank_ == d.rank_); + PADDLE_ENFORCE(size() == d.size()); DDim ret; ret.rank_ = rank_; ret.apply_visitor(DDimPlusVisitor(Get(), d.Get())); @@ -82,7 +83,7 @@ struct DDimMulVisitor { }; DDim DDim::operator*(const DDim& d) const { - PADDLE_ENFORCE(rank_ == d.rank_); + PADDLE_ENFORCE(size() == d.size()); DDim ret; ret.rank_ = rank_; ret.apply_visitor(DDimMulVisitor(Get(), d.Get())); @@ -121,13 +122,11 @@ int64_t product(const DDim& ddim) { } DDim slice_ddim(const DDim& dim, int begin, int end) { - PADDLE_ENFORCE(begin >= 0, - "Begin index can't be less than zero in ddim slice."); - int len = end - begin; - DDim ret; - ret.rank_ = len; - dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_); - return ret; + PADDLE_ENFORCE(begin >= 0 && end <= dim.size(), + "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.", + begin, end, dim.size()); + // Constructor of DDim would check whether end - begin is valid + return DDim(dim.Get() + begin, end - begin); } int arity(const DDim& d) { return d.size(); } @@ -138,8 +137,8 @@ struct DDimPrinter { std::ostream& os; explicit DDimPrinter(std::ostream& os_) : os(os_) {} - template - void operator()(const T& t) { + template + void operator()(const Dim& t) { os << t; } }; @@ -152,12 +151,11 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { } DDim flatten_to_2d(const DDim& src, int num_col_dims) { - int rank = src.size(); - return make_ddim({product(slice_ddim(src, 0, num_col_dims)), - product(slice_ddim(src, num_col_dims, rank))}); + return DDim({product(slice_ddim(src, 0, num_col_dims)), + product(slice_ddim(src, num_col_dims, src.size()))}); } -DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } +DDim flatten_to_1d(const DDim& src) { return DDim({product(src)}); } DDim stride(const DDim& ddim) { DDim strides; diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 0d7b12120525b538ac714e8f89156f8d79f8c41d..452072a58762d9a5668e351e98eb6e7c8e38f4b3 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -124,16 +124,16 @@ class DDim { inline int size() const { return rank_; } private: - template - inline Dim& UnsafeCast() { - return const_cast&>(const_cast(this)->UnsafeCast()); + template + inline Dim& UnsafeCast() { + return const_cast&>(const_cast(this)->UnsafeCast()); } - template - inline const Dim& UnsafeCast() const { - static_assert(M >= 0 && M <= kMaxRank, "Invalid rank"); + template + inline const Dim& UnsafeCast() const { + static_assert(D >= 0 && D <= kMaxRank, "Invalid rank"); auto* p = static_cast(&dim_); - return *reinterpret_cast*>(p); + return *reinterpret_cast*>(p); } friend DDim slice_ddim(const DDim& dim, int begin, int end); diff --git a/paddle/fluid/framework/dim.h b/paddle/fluid/framework/dim.h index 21d91167a4475466ec77fce7f47edb5800e74dca..88aee8379d835ce88b6b348aca99eb4a35bbeb5c 100644 --- a/paddle/fluid/framework/dim.h +++ b/paddle/fluid/framework/dim.h @@ -28,17 +28,17 @@ namespace paddle { namespace framework { // Statically sized, statically indexed dimension -template -class Dim : public Array { +template +class Dim : public Array { public: - static_assert(N >= 0, "N must be not less than 0"); + static_assert(D >= 0, "D must be not less than 0"); - static constexpr int kRank = N; - using BaseClass = Array; + static constexpr int kRank = D; + using BaseClass = Array; - inline Dim(int64_t head, const Dim& tail) { + inline Dim(int64_t head, const Dim& tail) { (*this)[0] = head; - new (this->GetMutable() + 1) Dim(tail); + new (this->GetMutable() + 1) Dim(tail); } template @@ -47,7 +47,7 @@ class Dim : public Array { /** Construct a Dim from a linear index and size. Uses Fortran order * indexing. */ - HOSTDEVICE Dim(int64_t idx, const Dim& size); + HOSTDEVICE Dim(int64_t idx, const Dim& size); /** Construct a Dim with each dimension set to the given index */ HOSTDEVICE explicit Dim(int64_t idx) { this->Fill(idx); } @@ -77,42 +77,42 @@ struct FortranOrderIndexingConstructorFunctor { }; } // namespace detail -template -HOSTDEVICE Dim::Dim(int64_t idx, const Dim& size) { - detail::FortranOrderIndexingConstructorFunctor<0, N, N == 0>::Run( +template +HOSTDEVICE Dim::Dim(int64_t idx, const Dim& size) { + detail::FortranOrderIndexingConstructorFunctor<0, D, D == 0>::Run( size.Get(), &idx, this->GetMutable()); } -template -HOSTDEVICE inline int64_t get(const Dim& dim) { +template +HOSTDEVICE inline int64_t get(const Dim& dim) { return dim[idx]; } -template -HOSTDEVICE inline int64_t& get(Dim& dim) { // NOLINT +template +HOSTDEVICE inline int64_t& get(Dim& dim) { // NOLINT return dim[idx]; } -template -HOSTDEVICE inline int64_t get(const Dim& dim, int idx) { +template +HOSTDEVICE inline int64_t get(const Dim& dim, int idx) { return dim[idx]; } -template -HOSTDEVICE inline int64_t& get(Dim& dim, int idx) { // NOLINT +template +HOSTDEVICE inline int64_t& get(Dim& dim, int idx) { // NOLINT return dim[idx]; } // Dot product of two dims -template -HOSTDEVICE inline int64_t linearize(const Dim& a, const Dim& b) { - return UnrollProduct::Run(a.Get(), b.Get()); +template +HOSTDEVICE inline int64_t linearize(const Dim& a, const Dim& b) { + return UnrollProduct::Run(a.Get(), b.Get()); } // Product of a Dim -template -HOSTDEVICE inline int64_t product(const Dim& a) { - return UnrollProduct::Run(a.Get()); +template +HOSTDEVICE inline int64_t product(const Dim& a) { + return UnrollProduct::Run(a.Get()); } // Is 0 <= idx_i < size_i for all i? @@ -135,9 +135,9 @@ struct ContainedFunctor { }; } // namespace detail -template -HOSTDEVICE inline bool contained(const Dim& idx, const Dim& size) { - return detail::ContainedFunctor<0, N, N == 0>::Run(idx.Get(), size.Get()); +template +HOSTDEVICE inline bool contained(const Dim& idx, const Dim& size) { + return detail::ContainedFunctor<0, D, D == 0>::Run(idx.Get(), size.Get()); } /** @@ -160,40 +160,40 @@ struct ExPrefixMulFunctor { }; } // namespace detail -template -HOSTDEVICE inline Dim ex_prefix_mul(const Dim& src) { - Dim ret; - detail::ExPrefixMulFunctor<0, N, N == 0>::Run(src.Get(), ret.GetMutable()); +template +HOSTDEVICE inline Dim ex_prefix_mul(const Dim& src) { + Dim ret; + detail::ExPrefixMulFunctor<0, D, D == 0>::Run(src.Get(), ret.GetMutable()); return ret; } /** * Add two dimensions together */ -template -HOSTDEVICE inline Dim dim_plus(const Dim& a, const Dim& b) { - Dim ret; - UnrollAdd::Run(a.Get(), b.Get(), ret.GetMutable()); +template +HOSTDEVICE inline Dim dim_plus(const Dim& a, const Dim& b) { + Dim ret; + UnrollAdd::Run(a.Get(), b.Get(), ret.GetMutable()); return ret; } -template -HOSTDEVICE inline Dim operator+(const Dim& lhs, const Dim& rhs) { +template +HOSTDEVICE inline Dim operator+(const Dim& lhs, const Dim& rhs) { return dim_plus(lhs, rhs); } /** * Multiply two dimensions together */ -template -HOSTDEVICE inline Dim dim_mult(const Dim& a, const Dim& b) { - Dim ret; - UnrollMul::Run(a.Get(), b.Get(), ret.GetMutable()); +template +HOSTDEVICE inline Dim dim_mult(const Dim& a, const Dim& b) { + Dim ret; + UnrollMul::Run(a.Get(), b.Get(), ret.GetMutable()); return ret; } -template -HOSTDEVICE Dim operator*(const Dim& lhs, const Dim& rhs) { +template +HOSTDEVICE Dim operator*(const Dim& lhs, const Dim& rhs) { return dim_mult(lhs, rhs); } @@ -224,10 +224,10 @@ struct NormalizeStridesFunctor { }; } // namespace detail -template -HOSTDEVICE Dim normalize_strides(const Dim& size, const Dim& stride) { - Dim ret; - detail::NormalizeStridesFunctor<0, N, N == 0>::Run(size.Get(), stride.Get(), +template +HOSTDEVICE Dim normalize_strides(const Dim& size, const Dim& stride) { + Dim ret; + detail::NormalizeStridesFunctor<0, D, D == 0>::Run(size.Get(), stride.Get(), ret.GetMutable()); return ret; } @@ -245,10 +245,10 @@ HOSTDEVICE inline Dim make_dim(Args... idxes) { } // Allows us to output a Dim -template -inline std::ostream& operator<<(std::ostream& os, const Dim& d) { +template +inline std::ostream& operator<<(std::ostream& os, const Dim& d) { os << d[0]; - for (int i = 1; i < N; ++i) { + for (int i = 1; i < D; ++i) { os << ", " << d[i]; } return os; @@ -258,23 +258,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) { return os; } -template -HOST std::string Dim::to_string() const { +template +HOST std::string Dim::to_string() const { std::stringstream stream; stream << *this; return stream.str(); } -template -HOSTDEVICE Dim linear_to_dimension(int linear_index, const Dim& extents) { - Dim result; +template +HOSTDEVICE Dim linear_to_dimension(int linear_index, const Dim& extents) { + Dim result; - for (int i = 0; i < N - 1; ++i) { + for (int i = 0; i < D - 1; ++i) { result[i] = linear_index % extents[i]; linear_index /= extents[i]; } - result[N - 1] = linear_index; + result[D - 1] = linear_index; return result; }