#pragma once #include #include #include #include #include "paddle/platform/assert.h" #include "paddle/platform/hostdevice.h" namespace paddle { namespace framework { // Statically sized, statically indexed dimension template struct Dim { static constexpr int dimensions = i; template HOSTDEVICE Dim(int _head, Args... _tail) : head(_head), tail(_tail...) { static_assert(sizeof...(_tail) == i - 1, "Dim initialized with the wrong number of parameters"); } HOSTDEVICE Dim(int _head, const Dim& _tail) : head(_head), tail(_tail) {} HOSTDEVICE Dim() : head(0), tail() {} /** Construct a Dim from a linear index and size. Uses Fortran order * indexing. */ HOSTDEVICE Dim(int idx, const Dim& size) : head(idx % size.head), tail(idx / size.head, size.tail) {} /** Construct a Dim with each dimension set to the given index */ HOSTDEVICE Dim(int idx) : head(idx), tail(idx) {} HOSTDEVICE bool operator==(const Dim& o) const { return (head == o.head) && (tail == o.tail); } HOSTDEVICE bool operator!=(const Dim& o) const { return !(*this == o); } HOSTDEVICE int& operator[](int idx); HOSTDEVICE int operator[](int idx) const; HOST std::string to_string() const; int head; Dim tail; }; // Base case specialization template <> struct Dim<1> { static constexpr int dimensions = 1; HOSTDEVICE Dim(int _head) : head(_head) {} HOSTDEVICE Dim() : head(0) {} HOSTDEVICE Dim(int idx, const Dim<1>& size) : head(idx) { #ifndef __CUDA_ARCH__ if (idx >= size.head) { throw std::invalid_argument("Index out of range."); } #else PADDLE_ASSERT(idx < size.head); #endif } HOSTDEVICE bool operator==(const Dim<1>& o) const { return (head == o.head); } HOSTDEVICE bool operator!=(const Dim<1>& o) const { return !(*this == o); } HOSTDEVICE int& operator[](int idx); HOSTDEVICE int operator[](int idx) const; int head; }; namespace { // Helper for accessing Dim classes template struct DimGetter { // Return a copy if Dim is const template HOSTDEVICE static int impl(const D& d) { return DimGetter::impl(d.tail); } // Return a reference if Dim is mutable template HOSTDEVICE static int& impl(D& d) { return DimGetter::impl(d.tail); } }; // Eureka! We found the element! template <> struct DimGetter<0> { // Return a copy if Dim is const template HOSTDEVICE static int impl(const D& d) { return d.head; } // Return a reference if Dim is mutable template HOSTDEVICE static int& impl(D& d) { return d.head; } }; template HOSTDEVICE int& indexer(Dim& dim, int idx) { #ifndef __CUDA_ARCH__ if (idx < 0) { throw std::invalid_argument("Tried to access a negative dimension"); } #else PADDLE_ASSERT(idx >= 0); #endif if (idx == 0) { return dim.head; } return indexer(dim.tail, idx - 1); } template <> HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) { #ifndef __CUDA_ARCH__ if (idx != 0) { throw std::invalid_argument("Invalid index"); } #else PADDLE_ASSERT(idx == 0); #endif return dim.head; } template HOSTDEVICE int indexer(const Dim& dim, int idx) { #ifndef __CUDA_ARCH__ if (idx < 0) { throw std::invalid_argument("Tried to access a negative dimension"); } #else PADDLE_ASSERT(idx >= 0); #endif if (idx == 0) { return dim.head; } return indexer(dim.tail, idx - 1); } template <> HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) { #ifndef __CUDA_ARCH__ if (idx != 0) { throw std::invalid_argument("Invalid index"); } #else PADDLE_ASSERT(idx == 0); #endif return dim.head; } } // namespace // Static access to constant Dim template HOSTDEVICE int get(const Dim& d) { return DimGetter::impl(d); } // Static access to mutable Dim template HOSTDEVICE int& get(Dim& d) { return DimGetter::impl(d); } // Dynamic access to constant Dim template HOSTDEVICE int Dim::operator[](int i) const { return indexer(*this, i); } // Dynamic access to mutable Dim template HOSTDEVICE int& Dim::operator[](int i) { return indexer(*this, i); } // Dynamic access to constant Dim inline HOSTDEVICE int Dim<1>::operator[](int i) const { return indexer(*this, i); } // Dynamic access to mutable Dim inline HOSTDEVICE int& Dim<1>::operator[](int i) { return indexer(*this, i); } // Dynamic access to constant Dim // without std::enable_if will try to instantiate this on get<0>(d) template HOSTDEVICE typename std::enable_if<(l > 0), int>::type get(const Dim& d, int i) { return d[i]; } // Dynamic access to mutable Dim template HOSTDEVICE typename std::enable_if<(l > 0), int&>::type get(Dim& d, int i) { return d[i]; } // Dot product of two dims template HOSTDEVICE int linearize(const Dim& a, const Dim& b) { return a.head * b.head + linearize(a.tail, b.tail); } // Base case dot product of two Dims // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline int linearize(const Dim<1>& a, const Dim<1>& b) { return a.head * b.head; } // Product of a Dim template HOSTDEVICE int product(const Dim& a, int prod = 1) { return prod * a.head * product(a.tail); } // Base case product of a Dim // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline int product(const Dim<1>& a, int prod) { return prod * a.head; } // Is 0 <= idx_i < size_i for all i? template HOSTDEVICE bool contained(const Dim& idx, const Dim& size) { return ((0 <= idx.head) && (idx.head < size.head) && contained(idx.tail, size.tail)); } // Base case of is 0 <= idx_i < size_i ? // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) { return ((0 <= idx.head) && (idx.head < size.head)); } /** * \brief Check if a size and a stride create a Fortran order contiguous * block of memory. */ template HOST bool contiguous(const Dim& size, const Dim& stride, int mul = 1) { if (product(size) == 0) return true; int contiguous_stride = get<0>(size) == 1 ? 0 : mul; return (get<0>(stride) == contiguous_stride && contiguous(size.tail, stride.tail, mul * get<0>(size))); } ///\cond HIDDEN // Base case of contiguous, check the nth stride is the size of // the prefix multiply of n-1 dims. template <> inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) { if (get<0>(size) == 0) return true; int contiguous_stride = get<0>(size) == 1 ? 0 : mul; return get<0>(stride) == contiguous_stride; } ///\endcond /** * \brief Compute exclusive prefix-multiply of a Dim. */ template HOSTDEVICE Dim ex_prefix_mul(const Dim& src, int mul = 1) { return Dim(mul, ex_prefix_mul(src.tail, mul * src.head)); } ///\cond HIDDEN // Base case of ex_prefix_mul // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) { return Dim<1>(mul); } ///\endcond /** * \brief Calculate strides of a contiguous array of the given size * * Sets the stride for any dimension with an extent of 1 to 0. * \param size Dim object containing the size of the array. * \param base The base stride to use. * \return Dim object the same size as \p size with the strides. */ template HOSTDEVICE Dim contiguous_strides(const Dim& size, int base = 1) { int stride = size.head == 1 ? 0 : base; return Dim(stride, contiguous_strides(size.tail, base * size.head)); } ///\cond HIDDEN // Base case of contiguous_strides template <> HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) { int stride = size.head == 1 ? 0 : base; return Dim<1>(stride); } ///\endcond /** * Add two dimensions together */ template HOSTDEVICE Dim dim_plus(const Dim& a, const Dim& b) { return Dim(a.head + b.head, dim_plus(a.tail, b.tail)); } // Base case template <> HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) { return Dim<1>(a.head + b.head); } template HOSTDEVICE Dim operator+(const Dim& lhs, const Dim& rhs) { return dim_plus(lhs, rhs); } /** * Multiply two dimensions together */ template HOSTDEVICE Dim dim_mult(const Dim& a, const Dim& b) { return Dim(a.head * b.head, dim_mult(a.tail, b.tail)); } // Base case template <> HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) { return Dim<1>(a.head * b.head); } template HOSTDEVICE Dim operator*(const Dim& lhs, const Dim& rhs) { return dim_mult(lhs, rhs); } /** * \brief Normalize strides to ensure any dimension with extent 1 * has stride 0. * * \param size Dim object containing the size of an array * \param stride Dim object containing stride of an array * \return Dim object the same size as \p size with normalized strides * */ template HOSTDEVICE Dim normalize_strides(const Dim& size, const Dim& stride) { int norm_stride = size.head == 1 ? 0 : stride.head; return Dim(norm_stride, normalize_strides(size.tail, stride.tail)); } ///\cond HIDDEN template <> HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size, const Dim<1>& stride) { int norm_stride = size.head == 1 ? 0 : stride.head; return Dim<1>(norm_stride); } ///\endcond /** * Helper function to create a Dim * * \param idxes The type of Dim constructed depends on the number of params * */ template HOSTDEVICE Dim make_dim(Args... idxes) { return Dim(idxes...); } // Allows us to output a Dim // XXX For some reason, overloading fails to resolve this correctly template typename std::enable_if<(i > 1), std::ostream&>::type operator<<( std::ostream& os, const Dim& d) { os << d.head << ", " << d.tail; return os; } // Base case that allows us to output a Dim // XXX I wish this could be an overload instead of a template template typename std::enable_if<(i == 1), std::ostream&>::type operator<<( std::ostream& os, const Dim& d) { os << d.head; return os; } template HOST std::string Dim::to_string() const { std::stringstream stream; stream << *this; return stream.str(); } template HOSTDEVICE Dim linear_to_dimension(int linear_index, Dim extents) { Dim result; for (int i = 0; i < D - 1; ++i) { result[i] = linear_index % extents[i]; linear_index /= extents[i]; } result[D - 1] = linear_index; return result; } } // namespace framework } // namespace paddle