// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "paddle/fluid/framework/array.h" #include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/hostdevice.h" namespace paddle { namespace framework { // Statically sized, statically indexed dimension template class Dim : public Array { public: static_assert(N >= 0, "N must be not less than 0"); static constexpr int kRank = N; using BaseClass = Array; inline Dim(int64_t head, const Dim& tail) { (*this)[0] = head; new (this->GetMutable() + 1) Dim(tail); } template HOSTDEVICE explicit Dim(int64_t head, Args... args) : BaseClass(head, args...) {} /** Construct a Dim from a linear index and size. Uses Fortran order * indexing. */ HOSTDEVICE Dim(int64_t idx, const Dim& size); /** Construct a Dim with each dimension set to the given index */ HOSTDEVICE explicit Dim(int64_t idx) { this->Fill(idx); } HOSTDEVICE Dim() = default; HOSTDEVICE int64_t* data() { return this->GetMutable(); } HOSTDEVICE const int64_t* data() const { return this->Get(); } HOST std::string to_string() const; }; namespace detail { template struct FortranOrderIndexingConstructorFunctor { HOSTDEVICE inline static void Run(const int64_t* in, int64_t* idx, int64_t* out) { out[kStart] = (*idx) % in[kStart]; (*idx) /= in[kStart]; FortranOrderIndexingConstructorFunctor::Run(in, idx, out); } }; template struct FortranOrderIndexingConstructorFunctor { HOSTDEVICE inline static void Run(const int64_t* in, int64_t* idx, int64_t* out) {} }; } // namespace detail template HOSTDEVICE Dim::Dim(int64_t idx, const Dim& size) { detail::FortranOrderIndexingConstructorFunctor<0, N, N == 0>::Run( size.Get(), &idx, this->GetMutable()); } template HOSTDEVICE inline int64_t get(const Dim& dim) { return dim[idx]; } template HOSTDEVICE inline int64_t& get(Dim& dim) { // NOLINT return dim[idx]; } template HOSTDEVICE inline int64_t get(const Dim& dim, int idx) { return dim[idx]; } template HOSTDEVICE inline int64_t& get(Dim& dim, int idx) { // NOLINT return dim[idx]; } // Dot product of two dims template HOSTDEVICE inline int64_t linearize(const Dim& a, const Dim& b) { return UnrollProduct::Run(a.Get(), b.Get()); } // Product of a Dim template HOSTDEVICE inline int64_t product(const Dim& a) { return UnrollProduct::Run(a.Get()); } // Is 0 <= idx_i < size_i for all i? namespace detail { template struct ContainedFunctor { HOSTDEVICE static inline bool Run(const int64_t* idx, const int64_t* size) { return (idx[kStart] >= 0 && idx[kStart] < size[kStart]) && ContainedFunctor::Run(idx, size); } }; template struct ContainedFunctor { HOSTDEVICE static constexpr inline bool Run(const int64_t* idx, const int64_t* size) { return true; } }; } // namespace detail template HOSTDEVICE inline bool contained(const Dim& idx, const Dim& size) { return detail::ContainedFunctor<0, N, N == 0>::Run(idx.Get(), size.Get()); } /** * \brief Compute exclusive prefix-multiply of a Dim. */ namespace detail { template struct ExPrefixMulFunctor { HOSTDEVICE static inline void Run(const int64_t* in, int64_t* out) { kStart == 0 ? out[kStart] = 1 : out[kStart] = out[kStart - 1] * in[kStart - 1]; detail::ExPrefixMulFunctor::Run(in, out); } }; template struct ExPrefixMulFunctor { HOSTDEVICE static inline void Run(const int64_t* in, int64_t* out) {} }; } // namespace detail template HOSTDEVICE inline Dim ex_prefix_mul(const Dim& src) { Dim ret; detail::ExPrefixMulFunctor<0, N, N == 0>::Run(src.Get(), ret.GetMutable()); return ret; } /** * Add two dimensions together */ template HOSTDEVICE inline Dim dim_plus(const Dim& a, const Dim& b) { Dim ret; UnrollAdd::Run(a.Get(), b.Get(), ret.GetMutable()); return ret; } template HOSTDEVICE inline Dim operator+(const Dim& lhs, const Dim& rhs) { return dim_plus(lhs, rhs); } /** * Multiply two dimensions together */ template HOSTDEVICE inline Dim dim_mult(const Dim& a, const Dim& b) { Dim ret; UnrollMul::Run(a.Get(), b.Get(), ret.GetMutable()); return ret; } template HOSTDEVICE Dim operator*(const Dim& lhs, const Dim& rhs) { return dim_mult(lhs, rhs); } /** * \brief Normalize strides to ensure any dimension with extent 1 * has stride 0. * * \param size Dim object containing the size of an array * \param stride Dim object containing stride of an array * \return Dim object the same size as \p size with normalized strides * */ namespace detail { template struct NormalizeStridesFunctor { HOSTDEVICE static void Run(const int64_t* size, const int64_t* stride, int64_t* ret) { ret[kStart] = (size[kStart] == 1 ? 0 : stride[kStart]); NormalizeStridesFunctor::Run( size, stride, ret); } }; template struct NormalizeStridesFunctor { HOSTDEVICE static void Run(const int64_t* size, const int64_t* stride, int64_t* ret) {} }; } // namespace detail template HOSTDEVICE Dim normalize_strides(const Dim& size, const Dim& stride) { Dim ret; detail::NormalizeStridesFunctor<0, N, N == 0>::Run(size.Get(), stride.Get(), ret.GetMutable()); return ret; } /** * Helper function to create a Dim * * \param idxes The type of Dim constructed depends on the number of params * */ template HOSTDEVICE inline Dim make_dim(Args... idxes) { return Dim(idxes...); } // Allows us to output a Dim template inline std::ostream& operator<<(std::ostream& os, const Dim& d) { os << d[0]; for (int i = 1; i < N; ++i) { os << ", " << d[i]; } return os; } inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) { return os; } template HOST std::string Dim::to_string() const { std::stringstream stream; stream << *this; return stream.str(); } template HOSTDEVICE Dim linear_to_dimension(int linear_index, const Dim& extents) { Dim result; for (int i = 0; i < N - 1; ++i) { result[i] = linear_index % extents[i]; linear_index /= extents[i]; } result[N - 1] = linear_index; return result; } } // namespace framework } // namespace paddle