ddim.h 2.1 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7
#pragma once

#include <boost/variant.hpp>
#include <initializer_list>
#include <stdexcept>
#include <vector>

8
#include "paddle/majel/dim.h"
F
fengjiayi 已提交
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109

namespace majel {

/**
 * \brief Storage type used by DDim: a variant over Dim<1> through Dim<9>.
 *
 * Declared directly in namespace majel. This is a header, so the alias is
 * deliberately not wrapped in an unnamed namespace; it remains reachable as
 * majel::DDimVar exactly as before.
 */
using DDimVar = boost::variant<Dim<1>,
                               Dim<2>,
                               Dim<3>,
                               Dim<4>,
                               Dim<5>,
                               Dim<6>,
                               Dim<7>,
                               Dim<8>,
                               Dim<9>>;

/**
 * \brief A dynamically sized dimension.
 *
 * Wraps a DDimVar, i.e. a variant that holds exactly one Dim<D>.
 * The number of dimensions must be between [1, 9].
 */
struct DDim {
  /// The statically sized Dim<D> currently held.
  DDimVar var;

  /// Default state: holds a default-constructed Dim<1>.
  DDim() : var(Dim<1>()) {}

  /// Converting constructor from a statically sized Dim<D> (intentionally
  /// not explicit, so a Dim<D> can be passed where a DDim is expected).
  template <int D>
  DDim(const Dim<D>& in) : var(in) {}

  /// Replace the held value with \p in and return *this.
  template <int D>
  DDim& operator=(const Dim<D>& in) {
    var = in;
    return *this;
  }

  // Element access by index; both overloads are defined out of line.
  // NOTE(review): behavior for an out-of-range idx is not visible in this
  // header -- confirm against the implementation.
  int& operator[](int idx);
  int operator[](int idx) const;

  /// Forward \p visitor to the held variant (mutable access).
  /// The visitor type must expose a nested result_type.
  template <typename Visitor>
  typename Visitor::result_type apply_visitor(Visitor& visitor) {
    return var.apply_visitor(visitor);
  }

  /// Forward \p visitor to the held variant (read-only access).
  /// The visitor type must expose a nested result_type.
  template <typename Visitor>
  typename Visitor::result_type apply_visitor(Visitor& visitor) const {
    return var.apply_visitor(visitor);
  }

  /// Return a copy of the underlying variant.
  /// Const-qualified so it can also be called on a const DDim; it only reads.
  DDimVar getVar() const { return var; }

  // Comparison and arithmetic operators; defined out of line.
  // NOTE(review): presumably element-wise on the held extents, and
  // presumably require both operands to have the same number of
  // dimensions -- confirm against the implementation.
  bool operator==(DDim d) const;

  bool operator!=(DDim d) const;

  DDim operator+(DDim d) const;

  DDim operator*(DDim d) const;
};

/**
 * \brief Make a DDim from a std::vector<int>
 *
 * \param dims A vector of ints. Must be sized between [1, 9]
 * \return The DDim built from \p dims
 */
DDim make_ddim(const std::vector<int>& dims);

/**
 * \brief Make a DDim from an initializer list
 *
 * \param dims An initializer list of ints. Must be sized between [1, 9]
 * \return The DDim built from \p dims
 */
DDim make_ddim(std::initializer_list<int> dims);

// Free-function element accessors for DDim; defined out of line.
// NOTE(review): presumably get() returns the value stored at position idx
// and set() overwrites it with val -- confirm against the implementation.
int get(const DDim& dim, int idx);
void set(DDim& dim, int idx, int val);

// Convert a DDim into a std::vector<int>; defined out of line.
// NOTE(review): presumably one element per dimension, in index order --
// confirm against the implementation.
std::vector<int> vectorize(const DDim& ddim);

// Defined out of line.
// NOTE(review): presumably the product of all extents of ddim -- confirm
// against the implementation.
// NOTE(review): ssize_t is a POSIX type and this header does not include
// <sys/types.h> itself; it is presumably provided transitively -- confirm.
ssize_t product(const DDim& ddim);

/**
 * \brief What is the length of this dimension?
 *
 * \param ddim Dynamic dimension to inspect
 */

int arity(const DDim& ddim);

// Stream insertion for DDim; the output format is determined by the
// out-of-line definition.
// NOTE(review): std::ostream is named here, but this header includes neither
// <iosfwd> nor <ostream> itself; presumably provided transitively -- confirm.
std::ostream& operator<<(std::ostream&, const majel::DDim&);

}  // namespace majel

namespace boost {

/**
 * \brief boost::get overload that accepts a majel::DDim directly.
 *
 * Forwards to boost::get<T> on the variant wrapped by \p in and returns the
 * result by value. See the Boost.Variant documentation of boost::get for
 * what happens when the held type is not T.
 *
 * \tparam T The type to extract from the DDim's variant
 * \param in The DDim whose variant is queried
 */
template <typename T>
T get(const majel::DDim& in) {
  const auto& held = in.var;
  return boost::get<T>(held);
}

}  // namespace boost