#include "paddle/framework/ddim.h"

#include <cstddef>
#include <initializer_list>
#include <ostream>
#include <stdexcept>
#include <vector>

#include "paddle/framework/enforce.h"

namespace paddle {
namespace framework {

///@cond HIDDEN

// Builds a Dim<i> from the first i values pointed to by `d`: the head is
// d[0] and the tail is built recursively from d + 1.
template <int i>
Dim<i> make_dim(const int* d) {
  return Dim<i>(*d, make_dim<i - 1>(d + 1));
}

// Recursion terminator: a rank-1 Dim is built from a single value.
template <>
Dim<1> make_dim<1>(const int* d) {
  return Dim<1>(*d);
}

// Stores into `ddim` a Dim of rank `n` built from dims[0..n).
// Throws std::invalid_argument when n is outside [1, 9], the ranks handled
// by the switch below.
void make_ddim(DDim& ddim, const int* dims, int n) {
  switch (n) {
    case 1:
      ddim = make_dim<1>(dims);
      break;
    case 2:
      ddim = make_dim<2>(dims);
      break;
    case 3:
      ddim = make_dim<3>(dims);
      break;
    case 4:
      ddim = make_dim<4>(dims);
      break;
    case 5:
      ddim = make_dim<5>(dims);
      break;
    case 6:
      ddim = make_dim<6>(dims);
      break;
    case 7:
      ddim = make_dim<7>(dims);
      break;
    case 8:
      ddim = make_dim<8>(dims);
      break;
    case 9:
      ddim = make_dim<9>(dims);
      break;
    default:
      throw std::invalid_argument(
          "Dynamic dimensions must have between [1, 9] dimensions.");
  }
}

///@endcond

/// Builds a DDim from a brace-enclosed list of extents.
/// @throws std::invalid_argument if the list size is outside [1, 9].
DDim make_ddim(std::initializer_list<int> dims) {
  DDim result(make_dim(0));
  make_ddim(result, dims.begin(), static_cast<int>(dims.size()));
  return result;
}

/// Builds a DDim from a vector of extents.
/// @throws std::invalid_argument if the vector size is outside [1, 9].
DDim make_ddim(const std::vector<int>& dims) {
  DDim result(make_dim(0));
  // data() is well-defined for an empty vector (unlike &dims[0]); the size
  // check in make_ddim then rejects the empty case with an exception.
  make_ddim(result, dims.data(), static_cast<int>(dims.size()));
  return result;
}

///@cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes errors

// Visitor returning a mutable reference to extent `idx` of whichever Dim<D>
// the variant currently holds.
class DynamicMutableIndexer : public boost::static_visitor<int&> {
 public:
  explicit DynamicMutableIndexer(int idx) : idx_(idx) {}

  template <int D>
  int& operator()(Dim<D>& dim) const {
    return dim[idx_];
  }

 private:
  int idx_;
};

// Visitor returning extent `idx` by value from whichever Dim<D> the variant
// currently holds.
class DynamicConstIndexer : public boost::static_visitor<int> {
 public:
  explicit DynamicConstIndexer(int idx) : idx_(idx) {}

  template <int D>
  int operator()(const Dim<D>& dim) const {
    return dim[idx_];
  }

 private:
  int idx_;
};

///@endcond

int& DDim::operator[](int idx) {
  return boost::apply_visitor(DynamicMutableIndexer(idx), var);
}

int DDim::operator[](int idx) const {
  return boost::apply_visitor(DynamicConstIndexer(idx), var);
}

/// Two DDims are equal when they hold the same variant alternative (same
/// rank) and all extents match.
bool DDim::operator==(DDim d) const {
  if (var.which() != d.getVar().which()) {
    return false;
  }
  // Same alternative implies same rank, so the vectors have equal length and
  // std::vector equality is an element-wise comparison.
  return vectorize(*this) == vectorize(d);
}

bool DDim::operator!=(DDim d) const { return !(*this == d); }

/// Element-wise sum of two DDims of equal rank.
/// @throws std::invalid_argument if the ranks differ.
DDim DDim::operator+(DDim d) const {
  std::vector<int> v1 = vectorize(*this);
  std::vector<int> v2 = vectorize(d);

  // An explicit check (rather than assert) keeps release builds from reading
  // past the end of the shorter vector.
  if (v1.size() != v2.size()) {
    throw std::invalid_argument(
        "DDim operator+ requires operands of the same rank.");
  }

  std::vector<int> v3;
  v3.reserve(v1.size());
  for (std::size_t i = 0; i < v1.size(); ++i) {
    v3.push_back(v1[i] + v2[i]);
  }

  return make_ddim(v3);
}

/// Element-wise product of two DDims of equal rank.
/// @throws std::invalid_argument if the ranks differ.
DDim DDim::operator*(DDim d) const {
  std::vector<int> v1 = vectorize(*this);
  std::vector<int> v2 = vectorize(d);

  // An explicit check (rather than assert) keeps release builds from reading
  // past the end of the shorter vector.
  if (v1.size() != v2.size()) {
    throw std::invalid_argument(
        "DDim operator* requires operands of the same rank.");
  }

  std::vector<int> v3;
  v3.reserve(v1.size());
  for (std::size_t i = 0; i < v1.size(); ++i) {
    v3.push_back(v1[i] * v2[i]);
  }

  return make_ddim(v3);
}

int get(const DDim& ddim, int idx) { return ddim[idx]; }

void set(DDim& ddim, int idx, int value) { ddim[idx] = value; }

///@cond HIDDEN
// Visitor that appends every extent of a Dim, head first, to `vector`.
struct VectorizeVisitor : public boost::static_visitor<> {
  std::vector<int>& vector;

  explicit VectorizeVisitor(std::vector<int>& v) : vector(v) {}

  template <typename T>
  void operator()(const T& t) {
    vector.push_back(t.head);
    this->operator()(t.tail);
  }

  // Recursion terminator for the innermost Dim<1>.
  void operator()(const Dim<1>& t) { vector.push_back(t.head); }
};
///@endcond

/// Returns the extents of `ddim` as a vector, outermost first.
std::vector<int> vectorize(const DDim& ddim) {
  std::vector<int> result;
  VectorizeVisitor visitor(result);
  boost::apply_visitor(visitor, ddim);
  return result;
}

/// Returns the product of all extents of `ddim`.
ssize_t product(const DDim& ddim) {
  ssize_t result = 1;
  std::vector<int> v = vectorize(ddim);
  for (auto i : v) {
    result *= i;
  }
  return result;
}

///\cond HIDDEN

// Visitor returning the compile-time rank D of the held Dim<D>.
struct ArityVisitor : boost::static_visitor<int> {
  template <int D>
  int operator()(Dim<D>) const {
    return D;
  }
};

///\endcond

/// Returns the rank (number of dimensions) of `d`.
int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); }

///\cond HIDDEN

// Visitor that streams the held Dim<D> to `os`.
struct DDimPrinter : boost::static_visitor<void> {
  std::ostream& os;
  explicit DDimPrinter(std::ostream& os_) : os(os_) {}

  template <typename T>
  void operator()(const T& t) {
    os << t;
  }
};

///\endcond

std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
  DDimPrinter printer(os);
  boost::apply_visitor(printer, ddim);
  return os;
}

/// Converts `dims` to an Eigen::DSizes of rank NDIMS.
/// Fails PADDLE_ENFORCE if the rank of `dims` differs from NDIMS.
// A free function cannot be cv-qualified, so the former trailing `const`
// made this definition ill-formed; it has been removed.
// NOTE(review): a function template defined only in a .cc file cannot be
// instantiated from other translation units; it presumably belongs in a
// header (or needs explicit instantiations) -- TODO confirm.
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(DDim dims) {
  int rank = paddle::framework::arity(dims);
  PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same");
  Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
  for (int d = 0; d < rank; d++) {
    dsizes[d] = dims[d];
  }
  return dsizes;
}

}  // namespace framework
}  // namespace paddle