/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include "platform/hostdevice.h" namespace paddle_mobile { namespace framework { // Statically sized, statically indexed dimension template struct Dim { static constexpr int dimensions = i; template HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) { static_assert(sizeof...(_tail) == i - 1, "Dim initialized with the wrong number of parameters"); } HOSTDEVICE Dim(int64_t _head, const Dim &_tail) : head(_head), tail(_tail) {} HOSTDEVICE Dim() : head(0), tail() {} /** Construct a Dim from a linear index and size. Uses Fortran * order * indexing. */ HOSTDEVICE Dim(int64_t idx, const Dim &size) : head(idx % size.head), tail(idx / size.head, size.tail) {} /** Construct a Dim with each dimension set to the given index */ HOSTDEVICE Dim(int64_t idx) : head(idx), tail(idx) {} HOSTDEVICE bool operator==(const Dim &o) const { return (head == o.head) && (tail == o.tail); } HOSTDEVICE bool operator!=(const Dim &o) const { return !(*this == o); } HOSTDEVICE int64_t &operator[](int idx); HOSTDEVICE int64_t operator[](int idx) const; HOST std::string to_string() const; int64_t head; Dim tail; }; // Base case specialization template <> struct Dim<0> { static constexpr int dimensions = 0; HOSTDEVICE Dim(int64_t _head) {} HOSTDEVICE Dim() {} HOSTDEVICE Dim(int idx, const Dim<0> &size) { #ifndef __CUDA_ARCH__ if (idx > 0) { throw std::invalid_argument("Index out of range."); } #else PADDLE_ASSERT(idx == 0); #endif } HOSTDEVICE bool operator==(const Dim<0> &o) const { return true; } HOSTDEVICE bool operator!=(const Dim<0> &o) const { return false; } HOSTDEVICE int64_t &operator[](int idx); HOSTDEVICE int64_t operator[](int idx) const; }; namespace { // Helper for accessing Dim classes template struct DimGetter { // Return a copy if Dim is const template HOSTDEVICE static int64_t impl(const D &d) { return DimGetter::impl(d.tail); } // Return a reference if Dim is mutable template HOSTDEVICE static int64_t &impl(D &d) { return DimGetter::impl(d.tail); } }; // Eureka! We found the element! template <> struct DimGetter<0> { // Return a copy if Dim is const template HOSTDEVICE static int64_t impl(const D &d) { return d.head; } // Return a reference if Dim is mutable template HOSTDEVICE static int64_t &impl(D &d) { return d.head; } }; template HOSTDEVICE int64_t &indexer(Dim &dim, int idx) { #ifndef __CUDA_ARCH__ if (idx < 0) { throw std::invalid_argument("Tried to access a negative dimension"); } #else PADDLE_ASSERT(idx >= 0); #endif if (idx == 0) { return dim.head; } return indexer(dim.tail, idx - 1); } template <> HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) { #ifndef __CUDA_ARCH__ throw std::invalid_argument("Invalid index"); #else PADDLE_ASSERT(false); #if CUDA_VERSION < 8000 // On CUDA versions previous to 8.0, only __shared__ variables // could be declared as static in the device code. int64_t head = 0; #else static int64_t head = 0; #endif return head; #endif } template HOSTDEVICE int64_t indexer(const Dim &dim, int idx) { #ifndef __CUDA_ARCH__ if (idx < 0) { throw std::invalid_argument("Tried to access a negative dimension"); } #else PADDLE_ASSERT(idx >= 0); #endif if (idx == 0) { return dim.head; } return indexer(dim.tail, idx - 1); } template <> HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) { #ifndef __CUDA_ARCH__ throw std::invalid_argument("Invalid index"); #else PADDLE_ASSERT(false); #if CUDA_VERSION < 8000 // On CUDA versions previous to 8.0, only __shared__ variables // could be declared as static in the device code. int64_t head = 0; #else static int64_t head = 0; #endif return head; #endif } } // namespace // Static access to constant Dim template HOSTDEVICE int64_t get(const Dim &d) { return DimGetter::impl(d); } // Static access to mutable Dim template HOSTDEVICE int64_t &get(Dim &d) { return DimGetter::impl(d); } // Dynamic access to constant Dim template HOSTDEVICE int64_t Dim::operator[](int i) const { // std::cout << "l: " << l << std::endl; return indexer(*this, i); } // Dynamic access to mutable Dim template HOSTDEVICE int64_t &Dim::operator[](int i) { return indexer(*this, i); } // Dynamic access to constant Dim inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); } // Dynamic access to mutable Dim inline HOSTDEVICE int64_t &Dim<0>::operator[](int i) { return indexer(*this, i); } // Dynamic access to constant Dim // without std::enable_if will try to instantiate this on get<0>(d) template HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim &d, int i) { return d[i]; } // Dynamic access to mutable Dim template HOSTDEVICE typename std::enable_if<(l > 0), int64_t &>::type get(Dim &d, int i) { return d[i]; } // Dot product of two dims template HOSTDEVICE int64_t linearize(const Dim &a, const Dim &b) { return a.head * b.head + linearize(a.tail, b.tail); } // Base case dot product of two Dims // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) { return 0; } // Product of a Dim template HOSTDEVICE int64_t product(const Dim &a, int prod = 1) { return prod * a.head * product(a.tail); } // Base case product of a Dim // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline int64_t product(const Dim<0> &a, int prod) { return prod; } // Is 0 <= idx_i < size_i for all i? template HOSTDEVICE bool contained(const Dim &idx, const Dim &size) { return ((0 <= idx.head) && (idx.head < size.head) && contained(idx.tail, size.tail)); } // Base case of is 0 <= idx_i < size_i ? // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline bool contained(const Dim<0> &idx, const Dim<0> &size) { return true; } /** * \brief Compute exclusive prefix-multiply of a Dim. */ template HOSTDEVICE Dim ex_prefix_mul(const Dim &src, int mul = 1) { return Dim(mul, ex_prefix_mul(src.tail, mul * src.head)); } ///\cond HIDDEN // Base case of ex_prefix_mul // Notice it is inline because it is no longer a template template <> HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) { return Dim<0>(); } ///\endcond /** * Add two dimensions together */ template HOSTDEVICE Dim dim_plus(const Dim &a, const Dim &b) { return Dim(a.head + b.head, dim_plus(a.tail, b.tail)); } // Base case template <> HOSTDEVICE inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) { return Dim<0>(); } template HOSTDEVICE Dim operator+(const Dim &lhs, const Dim &rhs) { return dim_plus(lhs, rhs); } /** * Multiply two dimensions together */ template HOSTDEVICE Dim dim_mult(const Dim &a, const Dim &b) { return Dim(a.head * b.head, dim_mult(a.tail, b.tail)); } // Base case template <> HOSTDEVICE inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) { return Dim<0>(); } template HOSTDEVICE Dim operator*(const Dim &lhs, const Dim &rhs) { return dim_mult(lhs, rhs); } /** * \brief Normalize strides to ensure any dimension with extent 1 * has stride 0. * * \param size Dim object containing the size of an array * \param stride Dim object containing stride of an array * \return Dim object the same size as \p size with normalized strides * */ template HOSTDEVICE Dim normalize_strides(const Dim &size, const Dim &stride) { int norm_stride = size.head == 1 ? 0 : stride.head; return Dim(norm_stride, normalize_strides(size.tail, stride.tail)); } ///\cond HIDDEN template <> HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0> &size, const Dim<0> &stride) { return Dim<0>(); } ///\endcond /** * Helper function to create a Dim * * \param idxes The type of Dim constructed depends on the number of * params * */ template HOSTDEVICE Dim make_dim(Args... idxes) { return Dim(idxes...); } // Allows us to output a Dim // XXX For some reason, overloading fails to resolve this correctly template typename std::enable_if<(i > 1), std::ostream &>::type operator<<( std::ostream &os, const Dim &d) { os << d.head << ", " << d.tail; return os; } // Base case that allows us to output a Dim // XXX I wish this could be an overload instead of a template template typename std::enable_if<(i == 1), std::ostream &>::type operator<<( std::ostream &os, const Dim &d) { os << d.head; return os; } inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) { return os; } template HOST std::string Dim::to_string() const { std::stringstream stream; stream << *this; return stream.str(); } template HOSTDEVICE Dim linear_to_dimension(int linear_index, Dim extents) { Dim result; for (int i = 0; i < D - 1; ++i) { result[i] = linear_index % extents[i]; linear_index /= extents[i]; } result[D - 1] = linear_index; return result; } } // namespace framework } // namespace paddle_mobile