/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/ddim.h" #include "paddle/platform/enforce.h" namespace paddle { namespace framework { /// @cond HIDDEN template Dim make_dim(const int64_t* d) { return Dim(*d, make_dim(d + 1)); } template <> Dim<1> make_dim<1>(const int64_t* d) { return Dim<1>(*d); } void make_ddim(DDim& ddim, const int64_t* dims, int n) { switch (n) { case 1: ddim = make_dim<1>(dims); break; case 2: ddim = make_dim<2>(dims); break; case 3: ddim = make_dim<3>(dims); break; case 4: ddim = make_dim<4>(dims); break; case 5: ddim = make_dim<5>(dims); break; case 6: ddim = make_dim<6>(dims); break; case 7: ddim = make_dim<7>(dims); break; case 8: ddim = make_dim<8>(dims); break; case 9: ddim = make_dim<9>(dims); break; default: throw std::invalid_argument( "Dynamic dimensions must have between [1, 9] dimensions."); } } /// @endcond DDim make_ddim(std::initializer_list dims) { DDim result(make_dim(0)); make_ddim(result, dims.begin(), dims.size()); return result; } DDim make_ddim(const std::vector& dims) { DDim result(make_dim(0)); make_ddim(result, &dims[0], dims.size()); return result; } /// @cond HIDDEN // XXX For some reason, putting this in an anonymous namespace causes errors class DynamicMutableIndexer : public boost::static_visitor { public: explicit DynamicMutableIndexer(int idx) : idx_(idx) {} template int64_t& operator()(Dim& dim) const { return dim[idx_]; } private: int idx_; }; class DynamicConstIndexer : public boost::static_visitor { public: explicit DynamicConstIndexer(int idx) : idx_(idx) {} template int64_t operator()(const Dim& dim) const { return dim[idx_]; } private: int idx_; }; /// @endcond int64_t& DDim::operator[](int idx) { return boost::apply_visitor(DynamicMutableIndexer(idx), var); } int64_t DDim::operator[](int idx) const { return boost::apply_visitor(DynamicConstIndexer(idx), var); } int64_t DDim::size() const { return arity(*this); } bool DDim::operator==(DDim d) const { if (var.which() != d.getVar().which()) { return false; } else { std::vector v1 = vectorize(*this); std::vector v2 = vectorize(d); for (unsigned int i = 0; i < v1.size(); i++) { if (v1[i] != v2[i]) { return false; } } return true; } } bool DDim::operator!=(DDim d) const { return !(*this == d); } DDim DDim::operator+(DDim d) const { std::vector v1 = vectorize(*this); std::vector v2 = vectorize(d); std::vector v3; assert(v1.size() == v2.size()); for (unsigned int i = 0; i < v1.size(); i++) { v3.push_back(v1[i] + v2[i]); } return make_ddim(v3); } DDim DDim::operator*(DDim d) const { std::vector v1 = vectorize(*this); std::vector v2 = vectorize(d); std::vector v3; assert(v1.size() == v2.size()); for (unsigned int i = 0; i < v1.size(); i++) { v3.push_back(v1[i] * v2[i]); } return make_ddim(v3); } int64_t get(const DDim& ddim, int idx) { return ddim[idx]; } void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } /// @cond HIDDEN struct VectorizeVisitor : public boost::static_visitor<> { std::vector& vector; explicit VectorizeVisitor(std::vector& v) : vector(v) {} template void operator()(const T& t) { vector.push_back(t.head); this->operator()(t.tail); } void operator()(const Dim<1>& t) { vector.push_back(t.head); } }; /// @endcond std::vector vectorize(const DDim& ddim) { std::vector result; VectorizeVisitor visitor(result); boost::apply_visitor(visitor, ddim); return result; } struct ProductVisitor : public boost::static_visitor { template int64_t operator()(const Dim& dim) { return product(dim); } }; int64_t product(const DDim& ddim) { ProductVisitor visitor; return boost::apply_visitor(visitor, ddim); } struct SliceVectorizeVisitor : public boost::static_visitor<> { std::vector& vector; int begin; int end; SliceVectorizeVisitor(std::vector& v, int b, int e) : vector(v), begin(b), end(e) { PADDLE_ENFORCE(begin < end, "Begin index must be less than end index in ddim slice."); PADDLE_ENFORCE(begin >= 0, "Begin index can't be less than zero in ddim slice."); } template void operator()(const Dim& dim) { if (begin == 0) { vector.push_back(dim.head); } else { --begin; } --end; if (end > 0) { this->operator()(dim.tail); } } void operator()(const Dim<1>& dim) { PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound."); vector.push_back(dim.head); } }; DDim slice_ddim(const DDim& dim, int begin, int end) { std::vector vec; vec.reserve(end - begin); SliceVectorizeVisitor visitor(vec, begin, end); boost::apply_visitor(visitor, dim); return make_ddim(vec); } /// \cond HIDDEN struct ArityVisitor : boost::static_visitor { template int operator()(Dim) const { return D; } }; /// \endcond int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); } /// \cond HIDDEN struct DDimPrinter : boost::static_visitor { std::ostream& os; explicit DDimPrinter(std::ostream& os_) : os(os_) {} template void operator()(const T& t) { os << t; } }; /// \endcond std::ostream& operator<<(std::ostream& os, const DDim& ddim) { DDimPrinter printer(os); boost::apply_visitor(printer, ddim); return os; } DDim::DDim(std::initializer_list init_list) { *this = make_ddim(init_list); } DDim flatten_to_2d(const DDim& src, int num_col_dims) { int rank = src.size(); return make_ddim({product(slice_ddim(src, 0, num_col_dims)), product(slice_ddim(src, num_col_dims, rank))}); } DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } DDim stride(const DDim& ddim) { std::vector strides(ddim.size()); strides[ddim.size() - 1] = 1; for (int i = ddim.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * ddim[i + 1]; } return framework::make_ddim(strides); } } // namespace framework } // namespace paddle