// ddim.h
#pragma once

#include <boost/variant.hpp>
#include <initializer_list>
#include <stdexcept>
#include <vector>

#include "paddle/framework/dim.h"
namespace paddle {
namespace framework {

namespace {
14 15
typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>,
                       Dim<8>, Dim<9>>
F
fengjiayi 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28 29
    DDimVar;
}

/**
 * \brief A dynamically sized dimension.
 *
 * The number of dimensions must be between [1, 9].
 */
struct DDim {
  DDimVar var;

  DDim() : var(Dim<1>()) {}

  template <int D>
L
liaogang 已提交
30
  explicit DDim(const Dim<D>& in) : var(in) {}
F
fengjiayi 已提交
31

32 33
  /*implicit*/ DDim(std::initializer_list<int> init_list);

F
fengjiayi 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
  template <int D>
  DDim& operator=(const Dim<D>& in) {
    var = in;
    return *this;
  }

  int& operator[](int idx);
  int operator[](int idx) const;

  template <typename Visitor>
  typename Visitor::result_type apply_visitor(Visitor& visitor) {
    return var.apply_visitor(visitor);
  }

  template <typename Visitor>
  typename Visitor::result_type apply_visitor(Visitor& visitor) const {
    return var.apply_visitor(visitor);
  }

  DDimVar getVar() { return var; }

F
fengjiayi 已提交
55 56
  ssize_t size() const;

F
fengjiayi 已提交
57 58 59 60 61 62 63
  bool operator==(DDim d) const;

  bool operator!=(DDim d) const;

  DDim operator+(DDim d) const;

  DDim operator*(DDim d) const;
64 65

  ssize_t size() const;
F
fengjiayi 已提交
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
};

/**
 * \brief Make a DDim from std::vector<int>
 *
 * \param dims A vector of ints. Must be sized between [1, 9]
 * \return The resulting DDim, by value.
 */
DDim make_ddim(const std::vector<int>& dims);

/**
 * \brief Make a DDim from an initializer list
 *
 * \param dims An initializer list of ints. Must be sized between [1, 9]
 * \return The resulting DDim, by value.
 */
DDim make_ddim(std::initializer_list<int> dims);

/**
 * \brief Free-function element accessors for a DDim.
 *
 * Only the declarations are visible in this header.
 * NOTE(review): presumably get() reads and set() writes the entry at
 * position idx -- confirm against the definitions.
 */
int get(const DDim& dim, int idx);
void set(DDim& dim, int idx, int val);

/**
 * \brief Convert a DDim into a std::vector<int>.
 *
 * NOTE(review): presumably one element per dimension, in order -- the
 * definition is not visible in this header; confirm there.
 */
std::vector<int> vectorize(const DDim& ddim);

/**
 * \brief Reduce a DDim to a single ssize_t.
 *
 * NOTE(review): judging by the name, presumably the product of all
 * entries of ddim -- confirm against the definition.
 */
ssize_t product(const DDim& ddim);

/**
 * \brief Slice a ddim
 *
 * Slice dim with [begin, end).
 * e.g.  DDim d = make_ddim({1,2,3,4,5});
 *       slice_ddim(d, 1, 3); ====> {2,3}
 */
97 98
DDim slice_ddim(const DDim& dim, int begin, int end);

/**
 * \brief What is the length of this dimension?
 *
 * NOTE(review): "length" presumably means the number of entries held by
 * ddim -- confirm against the definition.
 *
 * \param ddim Dynamic dimension to inspect
 */

int arity(const DDim& ddim);

/**
 * \brief Stream-insertion operator for DDim.
 *
 * Declared here only; the output format is determined by the definition,
 * which is not visible in this header.
 */
std::ostream& operator<<(std::ostream&, const DDim&);
}  // namespace framework
}  // namespace paddle

namespace boost {

template <typename T>
115
T get(const paddle::framework::DDim& in) {
F
fengjiayi 已提交
116 117 118 119
  return boost::get<T>(in.var);
}

}  // namespace boost