dim.h 9.5 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

朔-望's avatar
朔-望 已提交
15 16 17 18 19 20 21 22
#pragma once

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <type_traits>

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
23
namespace framework {
朔-望's avatar
朔-望 已提交
24

朔-望's avatar
朔-望 已提交
25
// Statically sized, statically indexed dimension
朔-望's avatar
朔-望 已提交
26 27
template <int i>
struct Dim {
28
  static constexpr int dimensions = i;
朔-望's avatar
朔-望 已提交
29

30
  template <typename... Args>
L
liuruilong 已提交
31
  Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
32 33 34
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
朔-望's avatar
朔-望 已提交
35

36
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}
朔-望's avatar
朔-望 已提交
37

38
  Dim() : head(0), tail() {}
朔-望's avatar
朔-望 已提交
39

40 41 42 43 44
  /** Construct a Dim from a linear index and size.  Uses Fortran
   * order
   * indexing. */
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
朔-望's avatar
朔-望 已提交
45

46 47
  /** Construct a Dim with each dimension set to the given index */
  Dim(int64_t idx) : head(idx), tail(idx) {}
朔-望's avatar
朔-望 已提交
48

49 50 51
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }
朔-望's avatar
朔-望 已提交
52

53
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }
朔-望's avatar
朔-望 已提交
54

55
  int64_t &operator[](int idx);
L
liuruilong 已提交
56

57
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
58

L
liuruilong 已提交
59
  std::string to_string() const;
朔-望's avatar
朔-望 已提交
60

61 62
  int64_t head;
  Dim<i - 1> tail;
朔-望's avatar
朔-望 已提交
63
};
朔-望's avatar
朔-望 已提交
64

朔-望's avatar
朔-望 已提交
65
// Base case specialization
朔-望's avatar
朔-望 已提交
66 67
template <>
struct Dim<0> {
68
  static constexpr int dimensions = 0;
朔-望's avatar
朔-望 已提交
69

70
  Dim(int64_t _head) {}
朔-望's avatar
朔-望 已提交
71

72
  Dim() {}
朔-望's avatar
朔-望 已提交
73

74
  Dim(int idx, const Dim<0> &size) {
朔-望's avatar
朔-望 已提交
75
#ifndef __CUDA_ARCH__
76 77 78
    if (idx > 0) {
      throw std::invalid_argument("Index out of range.");
    }
朔-望's avatar
朔-望 已提交
79
#else
80
    PADDLE_ASSERT(idx == 0);
朔-望's avatar
朔-望 已提交
81
#endif
82
  }
朔-望's avatar
朔-望 已提交
83

84
  bool operator==(const Dim<0> &o) const { return true; }
朔-望's avatar
朔-望 已提交
85

86
  bool operator!=(const Dim<0> &o) const { return false; }
朔-望's avatar
朔-望 已提交
87

88
  int64_t &operator[](int idx);
L
liuruilong 已提交
89

90
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
91 92 93 94 95
};

namespace {

// Helper for accessing Dim classes
朔-望's avatar
朔-望 已提交
96 97
template <int i>
struct DimGetter {
98
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
99
  template <typename D>
L
liuruilong 已提交
100
  static int64_t impl(const D &d) {
101 102 103
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
104
  template <typename D>
L
liuruilong 已提交
105
  static int64_t &impl(D &d) {
106 107
    return DimGetter<i - 1>::impl(d.tail);
  }
朔-望's avatar
朔-望 已提交
108 109 110
};

// Eureka! We found the element!
朔-望's avatar
朔-望 已提交
111 112
template <>
struct DimGetter<0> {
113
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
114
  template <typename D>
L
liuruilong 已提交
115
  static int64_t impl(const D &d) {
116 117 118
    return d.head;
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
119
  template <typename D>
L
liuruilong 已提交
120
  static int64_t &impl(D &d) {
朔-望's avatar
朔-望 已提交
121 122
    return d.head;
  }
朔-望's avatar
朔-望 已提交
123 124
};

朔-望's avatar
朔-望 已提交
125
template <int D>
L
liuruilong 已提交
126
int64_t &indexer(Dim<D> &dim, int idx) {
朔-望's avatar
朔-望 已提交
127
#ifndef __CUDA_ARCH__
128 129 130
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
朔-望's avatar
朔-望 已提交
131
#else
132
  PADDLE_ASSERT(idx >= 0);
朔-望's avatar
朔-望 已提交
133
#endif
134 135 136 137
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
138
}
朔-望's avatar
朔-望 已提交
139

朔-望's avatar
朔-望 已提交
140
template <>
L
liuruilong 已提交
141
int64_t &indexer<0>(Dim<0> &dim, int idx) {
朔-望's avatar
朔-望 已提交
142
#ifndef __CUDA_ARCH__
143
  throw std::invalid_argument("Invalid index");
朔-望's avatar
朔-望 已提交
144
#else
145
  PADDLE_ASSERT(false);
朔-望's avatar
朔-望 已提交
146
#if CUDA_VERSION < 8000
147 148 149
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
朔-望's avatar
朔-望 已提交
150
#else
151
  static int64_t head = 0;
朔-望's avatar
朔-望 已提交
152
#endif
153
  return head;
朔-望's avatar
朔-望 已提交
154
#endif
朔-望's avatar
朔-望 已提交
155
}
朔-望's avatar
朔-望 已提交
156

朔-望's avatar
朔-望 已提交
157
template <int D>
L
liuruilong 已提交
158
int64_t indexer(const Dim<D> &dim, int idx) {
朔-望's avatar
朔-望 已提交
159
#ifndef __CUDA_ARCH__
160 161 162
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
朔-望's avatar
朔-望 已提交
163
#else
164
  PADDLE_ASSERT(idx >= 0);
朔-望's avatar
朔-望 已提交
165
#endif
166 167 168 169
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
170 171
}

朔-望's avatar
朔-望 已提交
172
template <>
L
liuruilong 已提交
173
int64_t indexer<0>(const Dim<0> &dim, int idx) {
朔-望's avatar
朔-望 已提交
174
#ifndef __CUDA_ARCH__
175
  throw std::invalid_argument("Invalid index");
朔-望's avatar
朔-望 已提交
176
#else
177
  PADDLE_ASSERT(false);
朔-望's avatar
朔-望 已提交
178
#if CUDA_VERSION < 8000
179 180 181
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
朔-望's avatar
朔-望 已提交
182
#else
183
  static int64_t head = 0;
朔-望's avatar
朔-望 已提交
184
#endif
185
  return head;
朔-望's avatar
朔-望 已提交
186
#endif
朔-望's avatar
朔-望 已提交
187 188
}

朔-望's avatar
朔-望 已提交
189
}  // namespace
朔-望's avatar
朔-望 已提交
190
// Static access to constant Dim
朔-望's avatar
朔-望 已提交
191
template <int i, int l>
L
liuruilong 已提交
192
int64_t get(const Dim<l> &d) {
193
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
194 195 196
}

// Static access to mutable Dim
朔-望's avatar
朔-望 已提交
197
template <int i, int l>
L
liuruilong 已提交
198
int64_t &get(Dim<l> &d) {
199
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
200 201 202
}

// Dynamic access to constant Dim
朔-望's avatar
朔-望 已提交
203
template <int l>
L
liuruilong 已提交
204
int64_t Dim<l>::operator[](int i) const {
205 206
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
207 208 209
}

// Dynamic access to mutable Dim
朔-望's avatar
朔-望 已提交
210
template <int l>
L
liuruilong 已提交
211
int64_t &Dim<l>::operator[](int i) {
212
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
213 214 215
}

// Dynamic access to constant Dim
L
liuruilong 已提交
216
inline int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
217 218

// Dynamic access to mutable Dim
L
liuruilong 已提交
219
inline int64_t &Dim<0>::operator[](int i) { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
220 221 222 223

// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
L
liuruilong 已提交
224
typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l> &d, int i) {
225
  return d[i];
朔-望's avatar
朔-望 已提交
226 227 228 229
}

// Dynamic access to mutable Dim
template <int l>
L
liuruilong 已提交
230
typename std::enable_if<(l > 0), int64_t &>::type get(Dim<l> &d, int i) {
231
  return d[i];
朔-望's avatar
朔-望 已提交
232 233 234 235
}

// Dot product of two dims
template <int i>
L
liuruilong 已提交
236
int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
237
  return a.head * b.head + linearize(a.tail, b.tail);
朔-望's avatar
朔-望 已提交
238 239 240 241 242
}

// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
243
inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
244
  return 0;
朔-望's avatar
朔-望 已提交
245 246 247
}

// Product of a Dim
朔-望's avatar
朔-望 已提交
248
template <int i>
L
liuruilong 已提交
249
int64_t product(const Dim<i> &a, int prod = 1) {
250
  return prod * a.head * product(a.tail);
朔-望's avatar
朔-望 已提交
251 252 253 254
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
255
template <>
L
liuruilong 已提交
256
inline int64_t product(const Dim<0> &a, int prod) {
257
  return prod;
朔-望's avatar
朔-望 已提交
258 259 260 261
}

// Is 0 <= idx_i < size_i for all i?
template <int i>
L
liuruilong 已提交
262
bool contained(const Dim<i> &idx, const Dim<i> &size) {
263 264
  return ((0 <= idx.head) && (idx.head < size.head) &&
          contained(idx.tail, size.tail));
朔-望's avatar
朔-望 已提交
265 266 267 268 269
}

// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
270
inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
271
  return true;
朔-望's avatar
朔-望 已提交
272 273 274 275 276 277
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 */
template <int i>
L
liuruilong 已提交
278
Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
279
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
朔-望's avatar
朔-望 已提交
280 281 282 283 284
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
285
template <>
L
liuruilong 已提交
286
inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
287
  return Dim<0>();
朔-望's avatar
朔-望 已提交
288 289 290 291 292 293
}
///\endcond

/**
 * Add two dimensions together
 */
朔-望's avatar
朔-望 已提交
294
template <int i>
L
liuruilong 已提交
295
Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
296
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
297 298 299 300
}

// Base case
template <>
L
liuruilong 已提交
301
inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) {
302
  return Dim<0>();
朔-望's avatar
朔-望 已提交
303 304 305
}

template <int i>
L
liuruilong 已提交
306
Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
307
  return dim_plus(lhs, rhs);
朔-望's avatar
朔-望 已提交
308 309 310 311 312
}

/**
 * Multiply two dimensions together
 */
朔-望's avatar
朔-望 已提交
313
template <int i>
L
liuruilong 已提交
314
Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
315
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
316 317 318 319
}

// Base case
template <>
L
liuruilong 已提交
320
inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) {
321
  return Dim<0>();
朔-望's avatar
朔-望 已提交
322 323 324
}

template <int i>
L
liuruilong 已提交
325
Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
326
  return dim_mult(lhs, rhs);
朔-望's avatar
朔-望 已提交
327 328 329 330 331 332 333 334 335 336 337 338 339
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 */

template <int i>
L
liuruilong 已提交
340
Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
341 342
  int norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
朔-望's avatar
朔-望 已提交
343 344 345 346 347
}

///\cond HIDDEN

template <>
L
liuruilong 已提交
348
inline Dim<0> normalize_strides(const Dim<0> &size, const Dim<0> &stride) {
349
  return Dim<0>();
朔-望's avatar
朔-望 已提交
350 351 352 353 354 355 356 357 358 359 360 361 362
}

///\endcond

/**
 * Helper function to create a Dim
 *
 * \param idxes The type of Dim constructed depends on the number of
 * params
 *
 */

template <typename... Args>
L
liuruilong 已提交
363
Dim<sizeof...(Args)> make_dim(Args... idxes) {
364
  return Dim<sizeof...(Args)>(idxes...);
朔-望's avatar
朔-望 已提交
365 366 367 368 369
}

// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
朔-望's avatar
朔-望 已提交
370 371
typename std::enable_if<(i > 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
372 373
  os << d.head << ", " << d.tail;
  return os;
朔-望's avatar
朔-望 已提交
374 375 376 377 378
}

// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
朔-望's avatar
朔-望 已提交
379 380
typename std::enable_if<(i == 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
381 382
  os << d.head;
  return os;
朔-望's avatar
朔-望 已提交
383 384 385
}

inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) {
386
  return os;
朔-望's avatar
朔-望 已提交
387 388
}

朔-望's avatar
朔-望 已提交
389
template <int i>
L
liuruilong 已提交
390
std::string Dim<i>::to_string() const {
391
  std::stringstream stream;
朔-望's avatar
朔-望 已提交
392

393
  stream << *this;
朔-望's avatar
朔-望 已提交
394

395
  return stream.str();
朔-望's avatar
朔-望 已提交
396 397 398
}

template <int D>
L
liuruilong 已提交
399
Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
400
  Dim<D> result;
朔-望's avatar
朔-望 已提交
401

402 403 404 405
  for (int i = 0; i < D - 1; ++i) {
    result[i] = linear_index % extents[i];
    linear_index /= extents[i];
  }
朔-望's avatar
朔-望 已提交
406

407
  result[D - 1] = linear_index;
朔-望's avatar
朔-望 已提交
408

409
  return result;
朔-望's avatar
朔-望 已提交
410 411
}

朔-望's avatar
朔-望 已提交
412 413
}  // namespace framework
}  // namespace paddle_mobile