dim.h 7.7 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

朔-望's avatar
朔-望 已提交
15 16
#pragma once

W
wangliu 已提交
17
#include "common/enforce.h"
朔-望's avatar
朔-望 已提交
18
namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
19
namespace framework {
朔-望's avatar
朔-望 已提交
20

朔-望's avatar
朔-望 已提交
21
// Statically sized, statically indexed dimension
朔-望's avatar
朔-望 已提交
22 23
template <int i>
struct Dim {
24
  static constexpr int dimensions = i;
朔-望's avatar
朔-望 已提交
25

26
  template <typename... Args>
L
liuruilong 已提交
27
  Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
28 29 30
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
朔-望's avatar
朔-望 已提交
31

32
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}
朔-望's avatar
朔-望 已提交
33

34
  Dim() : head(0), tail() {}
朔-望's avatar
朔-望 已提交
35

36 37 38 39 40
  /** Construct a Dim from a linear index and size.  Uses Fortran
   * order
   * indexing. */
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
朔-望's avatar
朔-望 已提交
41

42 43
  /** Construct a Dim with each dimension set to the given index */
  Dim(int64_t idx) : head(idx), tail(idx) {}
朔-望's avatar
朔-望 已提交
44

45 46 47
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }
朔-望's avatar
朔-望 已提交
48

49
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }
朔-望's avatar
朔-望 已提交
50

51
  int64_t &operator[](int idx);
L
liuruilong 已提交
52

53
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
54

L
liuruilong 已提交
55
  std::string to_string() const;
朔-望's avatar
朔-望 已提交
56

57 58
  int64_t head;
  Dim<i - 1> tail;
朔-望's avatar
朔-望 已提交
59
};
朔-望's avatar
朔-望 已提交
60

朔-望's avatar
朔-望 已提交
61
// Base case specialization
朔-望's avatar
朔-望 已提交
62 63
template <>
struct Dim<0> {
64
  static constexpr int dimensions = 0;
朔-望's avatar
朔-望 已提交
65

66
  Dim(int64_t _head) {}
朔-望's avatar
朔-望 已提交
67

68
  Dim() {}
朔-望's avatar
朔-望 已提交
69

70 71
  Dim(int idx, const Dim<0> &size) {
    if (idx > 0) {
W
wangliu 已提交
72
      PADDLE_MOBILE_THROW_EXCEPTION("Index out of range.")
73 74
    }
  }
朔-望's avatar
朔-望 已提交
75

76
  bool operator==(const Dim<0> &o) const { return true; }
朔-望's avatar
朔-望 已提交
77

78
  bool operator!=(const Dim<0> &o) const { return false; }
朔-望's avatar
朔-望 已提交
79

80
  int64_t &operator[](int idx);
L
liuruilong 已提交
81

82
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
83 84 85 86 87
};

namespace {

// Helper for accessing Dim classes
朔-望's avatar
朔-望 已提交
88 89
template <int i>
struct DimGetter {
90
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
91
  template <typename D>
L
liuruilong 已提交
92
  static int64_t impl(const D &d) {
93 94 95
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
96
  template <typename D>
L
liuruilong 已提交
97
  static int64_t &impl(D &d) {
98 99
    return DimGetter<i - 1>::impl(d.tail);
  }
朔-望's avatar
朔-望 已提交
100 101 102
};

// Eureka! We found the element!
朔-望's avatar
朔-望 已提交
103 104
template <>
struct DimGetter<0> {
105
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
106
  template <typename D>
L
liuruilong 已提交
107
  static int64_t impl(const D &d) {
108 109 110
    return d.head;
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
111
  template <typename D>
L
liuruilong 已提交
112
  static int64_t &impl(D &d) {
朔-望's avatar
朔-望 已提交
113 114
    return d.head;
  }
朔-望's avatar
朔-望 已提交
115 116
};

朔-望's avatar
朔-望 已提交
117
template <int D>
L
liuruilong 已提交
118
int64_t &indexer(Dim<D> &dim, int idx) {
119
  if (idx < 0) {
W
wangliu 已提交
120
    PADDLE_MOBILE_THROW_EXCEPTION("Tried to access a negative dimension")
121
  }
W
wangliu 已提交
122

123 124 125 126
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
127
}
朔-望's avatar
朔-望 已提交
128

朔-望's avatar
朔-望 已提交
129
template <>
L
liuruilong 已提交
130
int64_t &indexer<0>(Dim<0> &dim, int idx) {
W
wangliu 已提交
131
  PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
132
  exit(0);
朔-望's avatar
朔-望 已提交
133
}
朔-望's avatar
朔-望 已提交
134

朔-望's avatar
朔-望 已提交
135
template <int D>
L
liuruilong 已提交
136
int64_t indexer(const Dim<D> &dim, int idx) {
137
  if (idx < 0) {
W
wangliu 已提交
138
    PADDLE_MOBILE_THROW_EXCEPTION("Tried to access a negative dimension")
139 140 141 142 143
  }
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
144 145
}

朔-望's avatar
朔-望 已提交
146
template <>
L
liuruilong 已提交
147
int64_t indexer<0>(const Dim<0> &dim, int idx) {
W
wangliu 已提交
148
  PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
149
  exit(0);
朔-望's avatar
朔-望 已提交
150 151
}

朔-望's avatar
朔-望 已提交
152
}  // namespace
朔-望's avatar
朔-望 已提交
153
// Static access to constant Dim
朔-望's avatar
朔-望 已提交
154
template <int i, int l>
L
liuruilong 已提交
155
int64_t get(const Dim<l> &d) {
156
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
157 158 159
}

// Static access to mutable Dim
朔-望's avatar
朔-望 已提交
160
template <int i, int l>
L
liuruilong 已提交
161
int64_t &get(Dim<l> &d) {
162
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
163 164 165
}

// Dynamic access to constant Dim
朔-望's avatar
朔-望 已提交
166
template <int l>
L
liuruilong 已提交
167
int64_t Dim<l>::operator[](int i) const {
168 169
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
170 171 172
}

// Dynamic access to mutable Dim
朔-望's avatar
朔-望 已提交
173
template <int l>
L
liuruilong 已提交
174
int64_t &Dim<l>::operator[](int i) {
175
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
176 177 178
}

// Dynamic access to constant Dim
L
liuruilong 已提交
179
inline int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
180 181

// Dynamic access to mutable Dim
L
liuruilong 已提交
182
inline int64_t &Dim<0>::operator[](int i) { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
183 184 185 186

// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
L
liuruilong 已提交
187
typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l> &d, int i) {
188
  return d[i];
朔-望's avatar
朔-望 已提交
189 190 191 192
}

// Dynamic access to mutable Dim
template <int l>
L
liuruilong 已提交
193
typename std::enable_if<(l > 0), int64_t &>::type get(Dim<l> &d, int i) {
194
  return d[i];
朔-望's avatar
朔-望 已提交
195 196 197 198
}

// Dot product of two dims
template <int i>
L
liuruilong 已提交
199
int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
200
  return a.head * b.head + linearize(a.tail, b.tail);
朔-望's avatar
朔-望 已提交
201 202 203 204 205
}

// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
206
inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
207
  return 0;
朔-望's avatar
朔-望 已提交
208 209 210
}

// Product of a Dim
朔-望's avatar
朔-望 已提交
211
template <int i>
L
liuruilong 已提交
212
int64_t product(const Dim<i> &a, int prod = 1) {
213
  return prod * a.head * product(a.tail);
朔-望's avatar
朔-望 已提交
214 215 216 217
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
218
template <>
L
liuruilong 已提交
219
inline int64_t product(const Dim<0> &a, int prod) {
220
  return prod;
朔-望's avatar
朔-望 已提交
221 222 223 224
}

// Is 0 <= idx_i < size_i for all i?
template <int i>
L
liuruilong 已提交
225
bool contained(const Dim<i> &idx, const Dim<i> &size) {
226 227
  return ((0 <= idx.head) && (idx.head < size.head) &&
          contained(idx.tail, size.tail));
朔-望's avatar
朔-望 已提交
228 229 230 231 232
}

// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
233
inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
234
  return true;
朔-望's avatar
朔-望 已提交
235 236 237 238 239 240
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 */
template <int i>
L
liuruilong 已提交
241
Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
242
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
朔-望's avatar
朔-望 已提交
243 244 245 246 247
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
248
template <>
L
liuruilong 已提交
249
inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
250
  return Dim<0>();
朔-望's avatar
朔-望 已提交
251 252 253 254 255 256
}
///\endcond

/**
 * Add two dimensions together
 */
朔-望's avatar
朔-望 已提交
257
template <int i>
L
liuruilong 已提交
258
Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
259
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
260 261 262 263
}

// Base case
template <>
L
liuruilong 已提交
264
inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) {
265
  return Dim<0>();
朔-望's avatar
朔-望 已提交
266 267 268
}

template <int i>
L
liuruilong 已提交
269
Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
270
  return dim_plus(lhs, rhs);
朔-望's avatar
朔-望 已提交
271 272 273 274 275
}

/**
 * Multiply two dimensions together
 */
朔-望's avatar
朔-望 已提交
276
template <int i>
L
liuruilong 已提交
277
Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
278
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
279 280 281 282
}

// Base case
template <>
L
liuruilong 已提交
283
inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) {
284
  return Dim<0>();
朔-望's avatar
朔-望 已提交
285 286 287
}

template <int i>
L
liuruilong 已提交
288
Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
289
  return dim_mult(lhs, rhs);
朔-望's avatar
朔-望 已提交
290 291 292 293 294 295 296 297 298 299 300 301 302
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 */

template <int i>
L
liuruilong 已提交
303
Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
304 305
  int norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
朔-望's avatar
朔-望 已提交
306 307 308 309 310
}

///\cond HIDDEN

template <>
L
liuruilong 已提交
311
inline Dim<0> normalize_strides(const Dim<0> &size, const Dim<0> &stride) {
312
  return Dim<0>();
朔-望's avatar
朔-望 已提交
313 314 315 316 317 318 319 320 321 322 323 324 325
}

///\endcond

/**
 * Helper function to create a Dim
 *
 * \param idxes The type of Dim constructed depends on the number of
 * params
 *
 */

template <typename... Args>
L
liuruilong 已提交
326
Dim<sizeof...(Args)> make_dim(Args... idxes) {
327
  return Dim<sizeof...(Args)>(idxes...);
朔-望's avatar
朔-望 已提交
328 329
}

朔-望's avatar
朔-望 已提交
330 331
}  // namespace framework
}  // namespace paddle_mobile