dim.h 9.4 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

朔-望's avatar
朔-望 已提交
15 16 17 18 19 20 21
#pragma once

#include <sstream>
#include <stdexcept>
#include <type_traits>

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
22
namespace framework {
朔-望's avatar
朔-望 已提交
23

朔-望's avatar
朔-望 已提交
24
// Statically sized, statically indexed dimension
朔-望's avatar
朔-望 已提交
25 26
template <int i>
struct Dim {
27
  static constexpr int dimensions = i;
朔-望's avatar
朔-望 已提交
28

29
  template <typename... Args>
L
liuruilong 已提交
30
  Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
31 32 33
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
朔-望's avatar
朔-望 已提交
34

35
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}
朔-望's avatar
朔-望 已提交
36

37
  Dim() : head(0), tail() {}
朔-望's avatar
朔-望 已提交
38

39 40 41 42 43
  /** Construct a Dim from a linear index and size.  Uses Fortran
   * order
   * indexing. */
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
朔-望's avatar
朔-望 已提交
44

45 46
  /** Construct a Dim with each dimension set to the given index */
  Dim(int64_t idx) : head(idx), tail(idx) {}
朔-望's avatar
朔-望 已提交
47

48 49 50
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }
朔-望's avatar
朔-望 已提交
51

52
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }
朔-望's avatar
朔-望 已提交
53

54
  int64_t &operator[](int idx);
L
liuruilong 已提交
55

56
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
57

L
liuruilong 已提交
58
  std::string to_string() const;
朔-望's avatar
朔-望 已提交
59

60 61
  int64_t head;
  Dim<i - 1> tail;
朔-望's avatar
朔-望 已提交
62
};
朔-望's avatar
朔-望 已提交
63

朔-望's avatar
朔-望 已提交
64
// Base case specialization
朔-望's avatar
朔-望 已提交
65 66
template <>
struct Dim<0> {
67
  static constexpr int dimensions = 0;
朔-望's avatar
朔-望 已提交
68

69
  Dim(int64_t _head) {}
朔-望's avatar
朔-望 已提交
70

71
  Dim() {}
朔-望's avatar
朔-望 已提交
72

73
  Dim(int idx, const Dim<0> &size) {
朔-望's avatar
朔-望 已提交
74
#ifndef __CUDA_ARCH__
75 76 77
    if (idx > 0) {
      throw std::invalid_argument("Index out of range.");
    }
朔-望's avatar
朔-望 已提交
78
#else
79
    PADDLE_ASSERT(idx == 0);
朔-望's avatar
朔-望 已提交
80
#endif
81
  }
朔-望's avatar
朔-望 已提交
82

83
  bool operator==(const Dim<0> &o) const { return true; }
朔-望's avatar
朔-望 已提交
84

85
  bool operator!=(const Dim<0> &o) const { return false; }
朔-望's avatar
朔-望 已提交
86

87
  int64_t &operator[](int idx);
L
liuruilong 已提交
88

89
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
90 91 92 93 94
};

namespace {

// Helper for accessing Dim classes
朔-望's avatar
朔-望 已提交
95 96
template <int i>
struct DimGetter {
97
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
98
  template <typename D>
L
liuruilong 已提交
99
  static int64_t impl(const D &d) {
100 101 102
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
103
  template <typename D>
L
liuruilong 已提交
104
  static int64_t &impl(D &d) {
105 106
    return DimGetter<i - 1>::impl(d.tail);
  }
朔-望's avatar
朔-望 已提交
107 108 109
};

// Eureka! We found the element!
朔-望's avatar
朔-望 已提交
110 111
template <>
struct DimGetter<0> {
112
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
113
  template <typename D>
L
liuruilong 已提交
114
  static int64_t impl(const D &d) {
115 116 117
    return d.head;
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
118
  template <typename D>
L
liuruilong 已提交
119
  static int64_t &impl(D &d) {
朔-望's avatar
朔-望 已提交
120 121
    return d.head;
  }
朔-望's avatar
朔-望 已提交
122 123
};

朔-望's avatar
朔-望 已提交
124
template <int D>
L
liuruilong 已提交
125
int64_t &indexer(Dim<D> &dim, int idx) {
朔-望's avatar
朔-望 已提交
126
#ifndef __CUDA_ARCH__
127 128 129
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
朔-望's avatar
朔-望 已提交
130
#else
131
  PADDLE_ASSERT(idx >= 0);
朔-望's avatar
朔-望 已提交
132
#endif
133 134 135 136
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
137
}
朔-望's avatar
朔-望 已提交
138

朔-望's avatar
朔-望 已提交
139
template <>
L
liuruilong 已提交
140
int64_t &indexer<0>(Dim<0> &dim, int idx) {
朔-望's avatar
朔-望 已提交
141
#ifndef __CUDA_ARCH__
142
  throw std::invalid_argument("Invalid index");
朔-望's avatar
朔-望 已提交
143
#else
144
  PADDLE_ASSERT(false);
朔-望's avatar
朔-望 已提交
145
#if CUDA_VERSION < 8000
146 147 148
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
朔-望's avatar
朔-望 已提交
149
#else
150
  static int64_t head = 0;
朔-望's avatar
朔-望 已提交
151
#endif
152
  return head;
朔-望's avatar
朔-望 已提交
153
#endif
朔-望's avatar
朔-望 已提交
154
}
朔-望's avatar
朔-望 已提交
155

朔-望's avatar
朔-望 已提交
156
template <int D>
L
liuruilong 已提交
157
int64_t indexer(const Dim<D> &dim, int idx) {
朔-望's avatar
朔-望 已提交
158
#ifndef __CUDA_ARCH__
159 160 161
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
朔-望's avatar
朔-望 已提交
162
#else
163
  PADDLE_ASSERT(idx >= 0);
朔-望's avatar
朔-望 已提交
164
#endif
165 166 167 168
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
169 170
}

朔-望's avatar
朔-望 已提交
171
template <>
L
liuruilong 已提交
172
int64_t indexer<0>(const Dim<0> &dim, int idx) {
朔-望's avatar
朔-望 已提交
173
#ifndef __CUDA_ARCH__
174
  throw std::invalid_argument("Invalid index");
朔-望's avatar
朔-望 已提交
175
#else
176
  PADDLE_ASSERT(false);
朔-望's avatar
朔-望 已提交
177
#if CUDA_VERSION < 8000
178 179 180
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
朔-望's avatar
朔-望 已提交
181
#else
182
  static int64_t head = 0;
朔-望's avatar
朔-望 已提交
183
#endif
184
  return head;
朔-望's avatar
朔-望 已提交
185
#endif
朔-望's avatar
朔-望 已提交
186 187
}

朔-望's avatar
朔-望 已提交
188
}  // namespace
朔-望's avatar
朔-望 已提交
189
// Static access to constant Dim
朔-望's avatar
朔-望 已提交
190
template <int i, int l>
L
liuruilong 已提交
191
int64_t get(const Dim<l> &d) {
192
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
193 194 195
}

// Static access to mutable Dim
朔-望's avatar
朔-望 已提交
196
template <int i, int l>
L
liuruilong 已提交
197
int64_t &get(Dim<l> &d) {
198
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
199 200 201
}

// Dynamic access to constant Dim
朔-望's avatar
朔-望 已提交
202
template <int l>
L
liuruilong 已提交
203
int64_t Dim<l>::operator[](int i) const {
204 205
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
206 207 208
}

// Dynamic access to mutable Dim
朔-望's avatar
朔-望 已提交
209
template <int l>
L
liuruilong 已提交
210
int64_t &Dim<l>::operator[](int i) {
211
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
212 213 214
}

// Dynamic access to constant Dim
L
liuruilong 已提交
215
inline int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
216 217

// Dynamic access to mutable Dim
L
liuruilong 已提交
218
inline int64_t &Dim<0>::operator[](int i) { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
219 220 221 222

// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
L
liuruilong 已提交
223
typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l> &d, int i) {
224
  return d[i];
朔-望's avatar
朔-望 已提交
225 226 227 228
}

// Dynamic access to mutable Dim
template <int l>
L
liuruilong 已提交
229
typename std::enable_if<(l > 0), int64_t &>::type get(Dim<l> &d, int i) {
230
  return d[i];
朔-望's avatar
朔-望 已提交
231 232 233 234
}

// Dot product of two dims
template <int i>
L
liuruilong 已提交
235
int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
236
  return a.head * b.head + linearize(a.tail, b.tail);
朔-望's avatar
朔-望 已提交
237 238 239 240 241
}

// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
242
inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
243
  return 0;
朔-望's avatar
朔-望 已提交
244 245 246
}

// Product of a Dim
朔-望's avatar
朔-望 已提交
247
template <int i>
L
liuruilong 已提交
248
int64_t product(const Dim<i> &a, int prod = 1) {
249
  return prod * a.head * product(a.tail);
朔-望's avatar
朔-望 已提交
250 251 252 253
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
254
template <>
L
liuruilong 已提交
255
inline int64_t product(const Dim<0> &a, int prod) {
256
  return prod;
朔-望's avatar
朔-望 已提交
257 258 259 260
}

// Is 0 <= idx_i < size_i for all i?
template <int i>
L
liuruilong 已提交
261
bool contained(const Dim<i> &idx, const Dim<i> &size) {
262 263
  return ((0 <= idx.head) && (idx.head < size.head) &&
          contained(idx.tail, size.tail));
朔-望's avatar
朔-望 已提交
264 265 266 267 268
}

// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
269
inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
270
  return true;
朔-望's avatar
朔-望 已提交
271 272 273 274 275 276
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 */
template <int i>
L
liuruilong 已提交
277
Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
278
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
朔-望's avatar
朔-望 已提交
279 280 281 282 283
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
284
template <>
L
liuruilong 已提交
285
inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
286
  return Dim<0>();
朔-望's avatar
朔-望 已提交
287 288 289 290 291 292
}
///\endcond

/**
 * Add two dimensions together
 */
朔-望's avatar
朔-望 已提交
293
template <int i>
L
liuruilong 已提交
294
Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
295
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
296 297 298 299
}

// Base case
template <>
L
liuruilong 已提交
300
inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) {
301
  return Dim<0>();
朔-望's avatar
朔-望 已提交
302 303 304
}

template <int i>
L
liuruilong 已提交
305
Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
306
  return dim_plus(lhs, rhs);
朔-望's avatar
朔-望 已提交
307 308 309 310 311
}

/**
 * Multiply two dimensions together
 */
朔-望's avatar
朔-望 已提交
312
template <int i>
L
liuruilong 已提交
313
Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
314
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
315 316 317 318
}

// Base case
template <>
L
liuruilong 已提交
319
inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) {
320
  return Dim<0>();
朔-望's avatar
朔-望 已提交
321 322 323
}

template <int i>
L
liuruilong 已提交
324
Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
325
  return dim_mult(lhs, rhs);
朔-望's avatar
朔-望 已提交
326 327 328 329 330 331 332 333 334 335 336 337 338
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 */

template <int i>
L
liuruilong 已提交
339
Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
340 341
  int norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
朔-望's avatar
朔-望 已提交
342 343 344 345 346
}

///\cond HIDDEN

template <>
L
liuruilong 已提交
347
inline Dim<0> normalize_strides(const Dim<0> &size, const Dim<0> &stride) {
348
  return Dim<0>();
朔-望's avatar
朔-望 已提交
349 350 351 352 353 354 355 356 357 358 359 360 361
}

///\endcond

/**
 * Helper function to create a Dim
 *
 * \param idxes The type of Dim constructed depends on the number of
 * params
 *
 */

template <typename... Args>
L
liuruilong 已提交
362
Dim<sizeof...(Args)> make_dim(Args... idxes) {
363
  return Dim<sizeof...(Args)>(idxes...);
朔-望's avatar
朔-望 已提交
364 365 366 367 368
}

// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
朔-望's avatar
朔-望 已提交
369 370
typename std::enable_if<(i > 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
371 372
  os << d.head << ", " << d.tail;
  return os;
朔-望's avatar
朔-望 已提交
373 374 375 376 377
}

// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
朔-望's avatar
朔-望 已提交
378 379
typename std::enable_if<(i == 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
380 381
  os << d.head;
  return os;
朔-望's avatar
朔-望 已提交
382 383 384
}

inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) {
385
  return os;
朔-望's avatar
朔-望 已提交
386 387
}

朔-望's avatar
朔-望 已提交
388
template <int i>
L
liuruilong 已提交
389
std::string Dim<i>::to_string() const {
390
  std::stringstream stream;
朔-望's avatar
朔-望 已提交
391

392
  stream << *this;
朔-望's avatar
朔-望 已提交
393

394
  return stream.str();
朔-望's avatar
朔-望 已提交
395 396 397
}

template <int D>
L
liuruilong 已提交
398
Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
399
  Dim<D> result;
朔-望's avatar
朔-望 已提交
400

401 402 403 404
  for (int i = 0; i < D - 1; ++i) {
    result[i] = linear_index % extents[i];
    linear_index /= extents[i];
  }
朔-望's avatar
朔-望 已提交
405

406
  result[D - 1] = linear_index;
朔-望's avatar
朔-望 已提交
407

408
  return result;
朔-望's avatar
朔-望 已提交
409 410
}

朔-望's avatar
朔-望 已提交
411 412
}  // namespace framework
}  // namespace paddle_mobile