// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/utils/dim.h" namespace phi { #define PADDLE_VISIT_DDIM_BASE(rank, callback) \ case (rank): { \ constexpr auto kRank = (rank); \ return (callback); \ } #define PADDLE_VISIT_DDIM(rank, callback) \ switch (rank) { \ PADDLE_VISIT_DDIM_BASE(0, callback); \ PADDLE_VISIT_DDIM_BASE(1, callback); \ PADDLE_VISIT_DDIM_BASE(2, callback); \ PADDLE_VISIT_DDIM_BASE(3, callback); \ PADDLE_VISIT_DDIM_BASE(4, callback); \ PADDLE_VISIT_DDIM_BASE(5, callback); \ PADDLE_VISIT_DDIM_BASE(6, callback); \ PADDLE_VISIT_DDIM_BASE(7, callback); \ PADDLE_VISIT_DDIM_BASE(8, callback); \ PADDLE_VISIT_DDIM_BASE(9, callback); \ default: \ PADDLE_THROW(phi::errors::Unimplemented( \ "Invalid dimension to be accessed. Now only supports access to " \ "dimension 0 to 9, but received dimension is %d.", \ rank)); \ } template inline void dynamic_dim_assign(const T1* in, T2* out, int n) { PADDLE_VISIT_DDIM(n, (static_dim_assign(in, out))); } /** * \brief A dynamically sized dimension. * * The number of dimensions must be between [1, 9]. */ class DDim { public: constexpr static int kMaxRank = 9; DDim() : rank_(1) { dim_[0] = 0; } DDim(const DDim& ddim) : dim_() { CopyFrom(ddim); } DDim(const int* d, int n) : rank_(n) { dynamic_dim_assign(d, dim_.GetMutable(), n); } DDim(const int64_t* d, int n) : rank_(n) { dynamic_dim_assign(d, dim_.GetMutable(), n); } template /*implicit*/ DDim(const Dim& in) : rank_(D) { // NOLINT UnsafeCast() = in; } /*implicit*/ DDim(std::initializer_list init_list) : DDim(init_list.begin(), init_list.size()) {} inline DDim& operator=(const DDim& ddim) { return CopyFrom(ddim); } template inline DDim& operator=(const Dim& dim) { rank_ = D; UnsafeCast() = dim; return *this; } inline int64_t& operator[](int idx) { return dim_[idx]; } inline int64_t operator[](int idx) const { return dim_[idx]; } int64_t& at(int idx) { PADDLE_ENFORCE_GE(idx, 0, phi::errors::InvalidArgument( "Invalid DDim index to be accessed. The valid index " "is between 0 and %d, but received index is %d.", rank_, idx)); PADDLE_ENFORCE_LT(idx, rank_, phi::errors::InvalidArgument( "Invalid DDim index to be accessed. The valid index " "is between 0 and %d, but received index is %d.", rank_, idx)); return dim_[idx]; } int64_t at(int idx) const { PADDLE_ENFORCE_GE(idx, 0, phi::errors::InvalidArgument( "Invalid DDim index to be accessed. The valid index " "is between 0 and %d, but received index is %d.", rank_, idx)); PADDLE_ENFORCE_LT(idx, rank_, phi::errors::InvalidArgument( "Invalid DDim index to be accessed. The valid index " "is between 0 and %d, but received index is %d.", rank_, idx)); return dim_[idx]; } template typename std::result_of&)>::type apply_visitor( Visitor&& visitor) { PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast())); } template typename std::result_of&)>::type apply_visitor( Visitor&& visitor) const { PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast())); } bool operator==(const DDim& d) const; bool operator!=(const DDim& d) const; inline const int64_t* Get() const { return dim_.Get(); } inline int64_t* GetMutable() { return dim_.GetMutable(); } inline int size() const { return rank_; } std::string to_str() const; DDim reshape(const std::vector& shape) const; DDim transpose(const std::vector& axis) const; private: template inline Dim& UnsafeCast() { static_assert(D >= 0 && D <= kMaxRank, "Invalid rank"); auto* p = static_cast(&dim_); return *reinterpret_cast*>(p); } template inline const Dim& UnsafeCast() const { static_assert(D >= 0 && D <= kMaxRank, "Invalid rank"); auto* p = static_cast(&dim_); return *reinterpret_cast*>(p); } inline DDim& CopyFrom(const DDim& ddim) { PADDLE_VISIT_DDIM(ddim.rank_, (*this = ddim.UnsafeCast())); } friend DDim stride(const DDim& ddim); friend DDim stride_numel(const DDim& ddim); private: Dim dim_; int rank_; }; #undef PADDLE_VISIT_DDIM_BASE #undef PADDLE_VISIT_DDIM /** * \brief Make a DDim from std::vector * * \param dims An vector of ints. Must be sized between [1, 9] */ DDim make_ddim(const std::vector& dims); DDim make_ddim(const std::vector& dims); /** * \brief Make a DDim from an initializer list * * \param dims An initializer list of ints. Must be sized between [1, 9] * */ DDim make_ddim(std::initializer_list dims); template std::vector vectorize(const DDim& ddim) { std::vector result(DDim::kMaxRank); dynamic_dim_assign(ddim.Get(), result.data(), ddim.size()); result.resize(ddim.size()); return result; } int64_t product(const DDim& ddim); bool contain_unknown_dim(const DDim& ddim); /** * \brief Slice a ddim * * Slice dim with [begin, end). * e.g. DDim d = make_ddim({1,2,3,4,5}); * slice_ddim(d, 1, 3); ====> {2,3} */ DDim slice_ddim(const DDim& dim, int begin, int end); /** * \brief What is the length of this dimension? * * \param Dynamic dimension to inspect */ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); /** * \brief Flatten dim to 3d * e.g., DDim d = mak_ddim({1, 2, 3, 4, 5, 6}) * flatten_to_3d(d, 2, 4); ===> {1*2, 3*4, 5*6} ===> {2, 12, 30} */ DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims); // Reshape a tensor to a matrix. The matrix's first dimension(column length) // will be the product of tensor's first `num_col_dims` dimensions. DDim flatten_to_2d(const DDim& src, int num_col_dims); DDim flatten_to_1d(const DDim& src); DDim stride(const DDim& ddim); DDim stride_numel(const DDim& ddim); } // namespace phi namespace paddle { namespace framework { using DDim = phi::DDim; } // namespace framework } // namespace paddle