/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include "common/enforce.h" namespace paddle_mobile { namespace framework { // Statically sized, statically indexed dimension template struct Dim { static constexpr int dimensions = i; template Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) { static_assert(sizeof...(_tail) == i - 1, "Dim initialized with the wrong number of parameters"); } Dim(int64_t _head, const Dim &_tail) : head(_head), tail(_tail) {} Dim() : head(0), tail() {} /** Construct a Dim from a linear index and size. Uses Fortran * order * indexing. */ Dim(int64_t idx, const Dim &size) : head(idx % size.head), tail(idx / size.head, size.tail) {} /** Construct a Dim with each dimension set to the given index */ explicit Dim(int64_t idx) : head(idx), tail(idx) {} bool operator==(const Dim &o) const { return (head == o.head) && (tail == o.tail); } bool operator!=(const Dim &o) const { return !(*this == o); } int64_t &operator[](int idx); int64_t operator[](int idx) const; std::string to_string() const; int64_t head; Dim tail; }; // Base case specialization template <> struct Dim<0> { static constexpr int dimensions = 0; explicit Dim(int64_t _head) {} Dim() {} Dim(int idx, const Dim<0> &size) { if (idx > 0) { PADDLE_MOBILE_THROW_EXCEPTION("Index out of range.") } } bool operator==(const Dim<0> &o) const { return true; } bool operator!=(const Dim<0> &o) const { return false; } int64_t &operator[](int idx); int64_t operator[](int idx) const; }; namespace { // Helper for accessing Dim classes template struct DimGetter { // Return a copy if Dim is const template static int64_t impl(const D &d) { return DimGetter::impl(d.tail); } // Return a reference if Dim is mutable template static int64_t &impl(D &d) { return DimGetter::impl(d.tail); } }; // Eureka! We found the element! template <> struct DimGetter<0> { // Return a copy if Dim is const template static int64_t impl(const D &d) { return d.head; } // Return a reference if Dim is mutable template static int64_t &impl(D &d) { return d.head; } }; template int64_t &indexer(Dim &dim, int idx) { if (idx < 0) { PADDLE_MOBILE_THROW_EXCEPTION("Tried to access a negative dimension") } if (idx == 0) { return dim.head; } return indexer(dim.tail, idx - 1); } template <> int64_t &indexer<0>(Dim<0> &dim, int idx) { PADDLE_MOBILE_THROW_EXCEPTION("Invalid index") } template int64_t indexer(const Dim &dim, int idx) { if (idx < 0) { PADDLE_MOBILE_THROW_EXCEPTION("Tried to access a negative dimension") } if (idx == 0) { return dim.head; } return indexer(dim.tail, idx - 1); } template <> int64_t indexer<0>(const Dim<0> &dim, int idx) { PADDLE_MOBILE_THROW_EXCEPTION("Invalid index") } } // namespace // Static access to constant Dim template int64_t get(const Dim &d) { return DimGetter::impl(d); } // Static access to mutable Dim template int64_t &get(Dim &d) { return DimGetter::impl(d); } // Dynamic access to constant Dim template int64_t Dim::operator[](int i) const { // std::cout << "l: " << l << std::endl; return indexer(*this, i); } // Dynamic access to mutable Dim template int64_t &Dim::operator[](int i) { return indexer(*this, i); } // Dynamic access to constant Dim inline int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); } // Dynamic access to mutable Dim inline int64_t &Dim<0>::operator[](int i) { return indexer(*this, i); } // Dynamic access to constant Dim // without std::enable_if will try to instantiate this on get<0>(d) template typename std::enable_if<(l > 0), int64_t>::type get(const Dim &d, int i) { return d[i]; } // Dynamic access to mutable Dim template typename std::enable_if<(l > 0), int64_t &>::type get(Dim &d, int i) { return d[i]; } // Dot product of two dims template int64_t linearize(const Dim &a, const Dim &b) { return a.head * b.head + linearize(a.tail, b.tail); } // Base case dot product of two Dims // Notice it is inline because it is no longer a template template <> inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) { return 0; } // Product of a Dim template int64_t product(const Dim &a, int prod = 1) { return prod * a.head * product(a.tail); } // Base case product of a Dim // Notice it is inline because it is no longer a template template <> inline int64_t product(const Dim<0> &a, int prod) { return prod; } // Is 0 <= idx_i < size_i for all i? template bool contained(const Dim &idx, const Dim &size) { return ((0 <= idx.head) && (idx.head < size.head) && contained(idx.tail, size.tail)); } // Base case of is 0 <= idx_i < size_i ? // Notice it is inline because it is no longer a template template <> inline bool contained(const Dim<0> &idx, const Dim<0> &size) { return true; } /** * \brief Compute exclusive prefix-multiply of a Dim. */ template Dim ex_prefix_mul(const Dim &src, int mul = 1) { return Dim(mul, ex_prefix_mul(src.tail, mul * src.head)); } ///\cond HIDDEN // Base case of ex_prefix_mul // Notice it is inline because it is no longer a template template <> inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) { return Dim<0>(); } ///\endcond /** * Add two dimensions together */ template Dim dim_plus(const Dim &a, const Dim &b) { return Dim(a.head + b.head, dim_plus(a.tail, b.tail)); } // Base case template <> inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) { return Dim<0>(); } template Dim operator+(const Dim &lhs, const Dim &rhs) { return dim_plus(lhs, rhs); } /** * Multiply two dimensions together */ template Dim dim_mult(const Dim &a, const Dim &b) { return Dim(a.head * b.head, dim_mult(a.tail, b.tail)); } // Base case template <> inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) { return Dim<0>(); } template Dim operator*(const Dim &lhs, const Dim &rhs) { return dim_mult(lhs, rhs); } /** * \brief Normalize strides to ensure any dimension with extent 1 * has stride 0. * * \param size Dim object containing the size of an array * \param stride Dim object containing stride of an array * \return Dim object the same size as \p size with normalized strides * */ template Dim normalize_strides(const Dim &size, const Dim &stride) { int norm_stride = size.head == 1 ? 0 : stride.head; return Dim(norm_stride, normalize_strides(size.tail, stride.tail)); } ///\cond HIDDEN template <> inline Dim<0> normalize_strides(const Dim<0> &size, const Dim<0> &stride) { return Dim<0>(); } ///\endcond /** * Helper function to create a Dim * * \param idxes The type of Dim constructed depends on the number of * params * */ template Dim make_dim(Args... idxes) { return Dim(idxes...); } } // namespace framework } // namespace paddle_mobile