dim.h 7.7 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

朔-望's avatar
朔-望 已提交
15 16
#pragma once

W
wangliu 已提交
17
#include "common/enforce.h"
朔-望's avatar
朔-望 已提交
18
namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
19
namespace framework {
朔-望's avatar
朔-望 已提交
20

朔-望's avatar
朔-望 已提交
21
// Statically sized, statically indexed dimension
朔-望's avatar
朔-望 已提交
22 23
template <int i>
struct Dim {
24
  static constexpr int dimensions = i;
朔-望's avatar
朔-望 已提交
25

26
  template <typename... Args>
L
liuruilong 已提交
27
  Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
28 29 30
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
朔-望's avatar
朔-望 已提交
31

32
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}
朔-望's avatar
朔-望 已提交
33

34
  Dim() : head(0), tail() {}
朔-望's avatar
朔-望 已提交
35

36 37 38 39 40
  /** Construct a Dim from a linear index and size.  Uses Fortran
   * order
   * indexing. */
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
朔-望's avatar
朔-望 已提交
41

42 43
  /** Construct a Dim with each dimension set to the given index */
  Dim(int64_t idx) : head(idx), tail(idx) {}
朔-望's avatar
朔-望 已提交
44

45 46 47
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }
朔-望's avatar
朔-望 已提交
48

49
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }
朔-望's avatar
朔-望 已提交
50

51
  int64_t &operator[](int idx);
L
liuruilong 已提交
52

53
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
54

L
liuruilong 已提交
55
  std::string to_string() const;
朔-望's avatar
朔-望 已提交
56

57 58
  int64_t head;
  Dim<i - 1> tail;
朔-望's avatar
朔-望 已提交
59
};
朔-望's avatar
朔-望 已提交
60

朔-望's avatar
朔-望 已提交
61
// Base case specialization
朔-望's avatar
朔-望 已提交
62 63
template <>
struct Dim<0> {
64
  static constexpr int dimensions = 0;
朔-望's avatar
朔-望 已提交
65

66
  Dim(int64_t _head) {}
朔-望's avatar
朔-望 已提交
67

68
  Dim() {}
朔-望's avatar
朔-望 已提交
69

70 71
  Dim(int idx, const Dim<0> &size) {
    if (idx > 0) {
W
wangliu 已提交
72
      PADDLE_MOBILE_THROW_EXCEPTION("Index out of range.")
73 74
    }
  }
朔-望's avatar
朔-望 已提交
75

76
  bool operator==(const Dim<0> &o) const { return true; }
朔-望's avatar
朔-望 已提交
77

78
  bool operator!=(const Dim<0> &o) const { return false; }
朔-望's avatar
朔-望 已提交
79

80
  int64_t &operator[](int idx);
L
liuruilong 已提交
81

82
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
83 84 85 86 87
};

namespace {

// Helper for accessing Dim classes
朔-望's avatar
朔-望 已提交
88 89
template <int i>
struct DimGetter {
90
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
91
  template <typename D>
L
liuruilong 已提交
92
  static int64_t impl(const D &d) {
93 94 95
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
96
  template <typename D>
L
liuruilong 已提交
97
  static int64_t &impl(D &d) {
98 99
    return DimGetter<i - 1>::impl(d.tail);
  }
朔-望's avatar
朔-望 已提交
100 101 102
};

// Eureka! We found the element!
朔-望's avatar
朔-望 已提交
103 104
template <>
struct DimGetter<0> {
105
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
106
  template <typename D>
L
liuruilong 已提交
107
  static int64_t impl(const D &d) {
108 109 110
    return d.head;
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
111
  template <typename D>
L
liuruilong 已提交
112
  static int64_t &impl(D &d) {
朔-望's avatar
朔-望 已提交
113 114
    return d.head;
  }
朔-望's avatar
朔-望 已提交
115 116
};

朔-望's avatar
朔-望 已提交
117
template <int D>
L
liuruilong 已提交
118
int64_t &indexer(Dim<D> &dim, int idx) {
119
  if (idx < 0) {
W
wangliu 已提交
120
    PADDLE_MOBILE_THROW_EXCEPTION("Tried to access a negative dimension")
121
  }
W
wangliu 已提交
122

123 124 125 126
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
127
}
朔-望's avatar
朔-望 已提交
128

朔-望's avatar
朔-望 已提交
129
template <>
L
liuruilong 已提交
130
int64_t &indexer<0>(Dim<0> &dim, int idx) {
W
wangliu 已提交
131
  PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
朔-望's avatar
朔-望 已提交
132
}
朔-望's avatar
朔-望 已提交
133

朔-望's avatar
朔-望 已提交
134
template <int D>
L
liuruilong 已提交
135
int64_t indexer(const Dim<D> &dim, int idx) {
136
  if (idx < 0) {
W
wangliu 已提交
137
    PADDLE_MOBILE_THROW_EXCEPTION("Tried to access a negative dimension")
138 139 140 141 142
  }
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
143 144
}

朔-望's avatar
朔-望 已提交
145
template <>
L
liuruilong 已提交
146
int64_t indexer<0>(const Dim<0> &dim, int idx) {
W
wangliu 已提交
147
  PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
朔-望's avatar
朔-望 已提交
148 149
}

朔-望's avatar
朔-望 已提交
150
}  // namespace
朔-望's avatar
朔-望 已提交
151
// Static access to constant Dim
朔-望's avatar
朔-望 已提交
152
template <int i, int l>
L
liuruilong 已提交
153
int64_t get(const Dim<l> &d) {
154
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
155 156 157
}

// Static access to mutable Dim
朔-望's avatar
朔-望 已提交
158
template <int i, int l>
L
liuruilong 已提交
159
int64_t &get(Dim<l> &d) {
160
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
161 162 163
}

// Dynamic access to constant Dim
朔-望's avatar
朔-望 已提交
164
template <int l>
L
liuruilong 已提交
165
int64_t Dim<l>::operator[](int i) const {
166 167
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
168 169 170
}

// Dynamic access to mutable Dim
朔-望's avatar
朔-望 已提交
171
template <int l>
L
liuruilong 已提交
172
int64_t &Dim<l>::operator[](int i) {
173
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
174 175 176
}

// Dynamic access to constant Dim
L
liuruilong 已提交
177
inline int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
178 179

// Dynamic access to mutable Dim
L
liuruilong 已提交
180
inline int64_t &Dim<0>::operator[](int i) { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
181 182 183 184

// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
L
liuruilong 已提交
185
typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l> &d, int i) {
186
  return d[i];
朔-望's avatar
朔-望 已提交
187 188 189 190
}

// Dynamic access to mutable Dim
template <int l>
L
liuruilong 已提交
191
typename std::enable_if<(l > 0), int64_t &>::type get(Dim<l> &d, int i) {
192
  return d[i];
朔-望's avatar
朔-望 已提交
193 194 195 196
}

// Dot product of two dims
template <int i>
L
liuruilong 已提交
197
int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
198
  return a.head * b.head + linearize(a.tail, b.tail);
朔-望's avatar
朔-望 已提交
199 200 201 202 203
}

// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
204
inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
205
  return 0;
朔-望's avatar
朔-望 已提交
206 207 208
}

// Product of a Dim
朔-望's avatar
朔-望 已提交
209
template <int i>
L
liuruilong 已提交
210
int64_t product(const Dim<i> &a, int prod = 1) {
211
  return prod * a.head * product(a.tail);
朔-望's avatar
朔-望 已提交
212 213 214 215
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
216
template <>
L
liuruilong 已提交
217
inline int64_t product(const Dim<0> &a, int prod) {
218
  return prod;
朔-望's avatar
朔-望 已提交
219 220 221 222
}

// Is 0 <= idx_i < size_i for all i?
template <int i>
L
liuruilong 已提交
223
bool contained(const Dim<i> &idx, const Dim<i> &size) {
224 225
  return ((0 <= idx.head) && (idx.head < size.head) &&
          contained(idx.tail, size.tail));
朔-望's avatar
朔-望 已提交
226 227 228 229 230
}

// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
231
inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
232
  return true;
朔-望's avatar
朔-望 已提交
233 234 235 236 237 238
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 */
template <int i>
L
liuruilong 已提交
239
Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
240
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
朔-望's avatar
朔-望 已提交
241 242 243 244 245
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
246
template <>
L
liuruilong 已提交
247
inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
248
  return Dim<0>();
朔-望's avatar
朔-望 已提交
249 250 251 252 253 254
}
///\endcond

/**
 * Add two dimensions together
 */
朔-望's avatar
朔-望 已提交
255
template <int i>
L
liuruilong 已提交
256
Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
257
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
258 259 260 261
}

// Base case
template <>
L
liuruilong 已提交
262
inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) {
263
  return Dim<0>();
朔-望's avatar
朔-望 已提交
264 265 266
}

template <int i>
L
liuruilong 已提交
267
Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
268
  return dim_plus(lhs, rhs);
朔-望's avatar
朔-望 已提交
269 270 271 272 273
}

/**
 * Multiply two dimensions together
 */
朔-望's avatar
朔-望 已提交
274
template <int i>
L
liuruilong 已提交
275
Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
276
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
277 278 279 280
}

// Base case
template <>
L
liuruilong 已提交
281
inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) {
282
  return Dim<0>();
朔-望's avatar
朔-望 已提交
283 284 285
}

template <int i>
L
liuruilong 已提交
286
Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
287
  return dim_mult(lhs, rhs);
朔-望's avatar
朔-望 已提交
288 289 290 291 292 293 294 295 296 297 298 299 300
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 */

template <int i>
L
liuruilong 已提交
301
Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
302 303
  int norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
朔-望's avatar
朔-望 已提交
304 305 306 307 308
}

///\cond HIDDEN

template <>
L
liuruilong 已提交
309
inline Dim<0> normalize_strides(const Dim<0> &size, const Dim<0> &stride) {
310
  return Dim<0>();
朔-望's avatar
朔-望 已提交
311 312 313 314 315 316 317 318 319 320 321 322 323
}

///\endcond

/**
 * Helper function to create a Dim
 *
 * \param idxes The type of Dim constructed depends on the number of
 * params
 *
 */

template <typename... Args>
L
liuruilong 已提交
324
Dim<sizeof...(Args)> make_dim(Args... idxes) {
325
  return Dim<sizeof...(Args)>(idxes...);
朔-望's avatar
朔-望 已提交
326 327
}

朔-望's avatar
朔-望 已提交
328 329
}  // namespace framework
}  // namespace paddle_mobile