// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include "paddle/platform/assert.h" #include "paddle/platform/hostdevice.h" namespace paddle { namespace framework { // Statically sized, statically indexed dimension template struct Dim { static constexpr int dimensions = i; template HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) { static_assert(sizeof...(_tail) == i - 1, "Dim initialized with the wrong number of parameters"); } HOSTDEVICE Dim(int64_t _head, const Dim& _tail) : head(_head), tail(_tail) {} HOSTDEVICE Dim() : head(0), tail() {} /** Construct a Dim from a linear index and size. Uses Fortran order * indexing. */ HOSTDEVICE Dim(int64_t idx, const Dim& size) : head(idx % size.head), tail(idx / size.head, size.tail) {} /** Construct a Dim with each dimension set to the given index */ HOSTDEVICE Dim(int64_t idx) : head(idx), tail(idx) {} HOSTDEVICE bool operator==(const Dim& o) const { return (head == o.head) && (tail == o.tail); } HOSTDEVICE bool operator!=(const Dim& o) const { return !(*this == o); } HOSTDEVICE int64_t& operator[](int idx); HOSTDEVICE int64_t operator[](int idx) const; HOST std::string to_string() const; int64_t head; Dim tail; }; // Base case specialization template <> struct Dim<1> { static constexpr int dimensions = 1; HOSTDEVICE Dim(int64_t _head) : head(_head) {} HOSTDEVICE Dim() : head(0) {} HOSTDEVICE Dim(int idx, const Dim<1>& size) : head(idx) { #ifndef __CUDA_ARCH__ if (idx >= size.head) { throw std::invalid_argument("Index out of range."); } #else PADDLE_ASSERT(idx < size.head); #endif } HOSTDEVICE bool operator==(const Dim<1>& o) const { return (head == o.head); } HOSTDEVICE bool operator!=(const Dim<1>& o) const { return !(*this == o); } HOSTDEVICE int64_t& operator[](int idx); HOSTDEVICE int64_t operator[](int idx) const; int64_t head; }; namespace { // Helper for accessing Dim classes template struct DimGetter { // Return a copy if Dim is const template HOSTDEVICE static int64_t impl(const D& d) { return DimGetter::impl(d.tail); } // Return a reference if Dim is mutable template HOSTDEVICE static int64_t& impl(D& d) { return DimGetter::impl(d.tail); } }; // Eureka! We found the element! template <> struct DimGetter<0> { // Return a copy if Dim is const template HOSTDEVICE static int64_t impl(const D& d) { return d.head; } // Return a reference if Dim is mutable template HOSTDEVICE static int64_t& impl(D& d) { return d.head; } }; template HOSTDEVICE int64_t& indexer(Dim& dim, int idx) { #ifndef __CUDA_ARCH__ if (idx < 0) { throw std::invalid_argument("Tried to access a negative dimension"); } #else PADDLE_ASSERT(idx >= 0); #endif if (idx == 0) { return dim.head; } return indexer(dim.tail, idx - 1); } template <> HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) { #ifndef __CUDA_ARCH__ if (idx != 0) { throw std::invalid_argument("Invalid index"); } #else PADDLE_ASSERT(idx == 0); #endif return dim.head; } template HOSTDEVICE int64_t indexer(const Dim& dim, int idx) { #ifndef __CUDA_ARCH__ if (idx < 0) { throw std::invalid_argument("Tried to access a negative dimension"); } #else PADDLE_ASSERT(idx >= 0); #endif if (idx == 0) { return dim.head; } return indexer(dim.tail, idx - 1); } template <> HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) { #ifndef __CUDA_ARCH__ if (idx != 0) { throw std::invalid_argument("Invalid index"); } #else PADDLE_ASSERT(idx == 0); #endif return dim.head; } } // namespace // Static access to constant Dim template HOSTDEVICE int64_t get(const Dim& d) { return DimGetter::impl(d); } // Static access to mutable Dim template HOSTDEVICE int64_t& get(Dim& d) { return DimGetter::impl(d); } // Dynamic access to constant Dim template HOSTDEVICE int64_t Dim::operator[](int i) const { return indexer(*this, i); } // Dynamic access to mutable Dim template HOSTDEVICE int64_t& Dim::operator[](int i) { return indexer(*this, i); } // Dynamic access to constant Dim inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const { return indexer(*this, i); } // Dynamic access to mutable Dim inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) { return indexer(*this, i); } // Dynamic access to constant Dim // without std::enable_if will try to instantiate this on get<0>(d) template HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim& d, int i) { return d[i]; } // Dynamic access to mutable Dim template HOSTDEVICE typename std::enable_if<(l > 0), int64_t&>::type get(Dim& d, int i) { return d[i]; } // Dot product of two dims template HOSTDEVICE int64_t linearize(const Dim& a, const Dim& b) { return a.head * b.head + linearize(a.tail, b.tail); } // Base case dot product of two Dims // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) { return a.head * b.head; } // Product of a Dim template HOSTDEVICE int64_t product(const Dim& a, int prod = 1) { return prod * a.head * product(a.tail); } // Base case product of a Dim // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) { return prod * a.head; } // Is 0 <= idx_i < size_i for all i? template HOSTDEVICE bool contained(const Dim& idx, const Dim& size) { return ((0 <= idx.head) && (idx.head < size.head) && contained(idx.tail, size.tail)); } // Base case of is 0 <= idx_i < size_i ? // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) { return ((0 <= idx.head) && (idx.head < size.head)); } /** * \brief Compute exclusive prefix-multiply of a Dim. */ template HOSTDEVICE Dim ex_prefix_mul(const Dim& src, int mul = 1) { return Dim(mul, ex_prefix_mul(src.tail, mul * src.head)); } ///\cond HIDDEN // Base case of ex_prefix_mul // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) { return Dim<1>(mul); } ///\endcond /** * Add two dimensions together */ template HOSTDEVICE Dim dim_plus(const Dim& a, const Dim& b) { return Dim(a.head + b.head, dim_plus(a.tail, b.tail)); } // Base case template <> HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) { return Dim<1>(a.head + b.head); } template HOSTDEVICE Dim operator+(const Dim& lhs, const Dim& rhs) { return dim_plus(lhs, rhs); } /** * Multiply two dimensions together */ template HOSTDEVICE Dim dim_mult(const Dim& a, const Dim& b) { return Dim(a.head * b.head, dim_mult(a.tail, b.tail)); } // Base case template <> HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) { return Dim<1>(a.head * b.head); } template HOSTDEVICE Dim operator*(const Dim& lhs, const Dim& rhs) { return dim_mult(lhs, rhs); } /** * \brief Normalize strides to ensure any dimension with extent 1 * has stride 0. * * \param size Dim object containing the size of an array * \param stride Dim object containing stride of an array * \return Dim object the same size as \p size with normalized strides * */ template HOSTDEVICE Dim normalize_strides(const Dim& size, const Dim& stride) { int norm_stride = size.head == 1 ? 0 : stride.head; return Dim(norm_stride, normalize_strides(size.tail, stride.tail)); } ///\cond HIDDEN template <> HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size, const Dim<1>& stride) { int norm_stride = size.head == 1 ? 0 : stride.head; return Dim<1>(norm_stride); } ///\endcond /** * Helper function to create a Dim * * \param idxes The type of Dim constructed depends on the number of params * */ template HOSTDEVICE Dim make_dim(Args... idxes) { return Dim(idxes...); } // Allows us to output a Dim // XXX For some reason, overloading fails to resolve this correctly template typename std::enable_if<(i > 1), std::ostream&>::type operator<<( std::ostream& os, const Dim& d) { os << d.head << ", " << d.tail; return os; } // Base case that allows us to output a Dim // XXX I wish this could be an overload instead of a template template typename std::enable_if<(i == 1), std::ostream&>::type operator<<( std::ostream& os, const Dim& d) { os << d.head; return os; } template HOST std::string Dim::to_string() const { std::stringstream stream; stream << *this; return stream.str(); } template HOSTDEVICE Dim linear_to_dimension(int linear_index, Dim extents) { Dim result; for (int i = 0; i < D - 1; ++i) { result[i] = linear_index % extents[i]; linear_index /= extents[i]; } result[D - 1] = linear_index; return result; } } // namespace framework } // namespace paddle