dim.h 9.6 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

朔-望's avatar
朔-望 已提交
15 16 17 18 19 20 21 22
#pragma once

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <type_traits>

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
23
namespace framework {
朔-望's avatar
朔-望 已提交
24

朔-望's avatar
朔-望 已提交
25
// Statically sized, statically indexed dimension
朔-望's avatar
朔-望 已提交
26 27
template <int i>
struct Dim {
28
  static constexpr int dimensions = i;
朔-望's avatar
朔-望 已提交
29

30
  template <typename... Args>
L
liuruilong 已提交
31
  Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
32 33 34
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
朔-望's avatar
朔-望 已提交
35

36
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}
朔-望's avatar
朔-望 已提交
37

38
  Dim() : head(0), tail() {}
朔-望's avatar
朔-望 已提交
39

40 41 42 43 44
  /** Construct a Dim from a linear index and size.  Uses Fortran
   * order
   * indexing. */
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
朔-望's avatar
朔-望 已提交
45

46 47
  /** Construct a Dim with each dimension set to the given index */
  Dim(int64_t idx) : head(idx), tail(idx) {}
朔-望's avatar
朔-望 已提交
48

49 50 51
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }
朔-望's avatar
朔-望 已提交
52

53
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }
朔-望's avatar
朔-望 已提交
54

55
  int64_t &operator[](int idx);
L
liuruilong 已提交
56

57
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
58

L
liuruilong 已提交
59
  std::string to_string() const;
朔-望's avatar
朔-望 已提交
60

61 62
  int64_t head;
  Dim<i - 1> tail;
朔-望's avatar
朔-望 已提交
63
};
朔-望's avatar
朔-望 已提交
64

朔-望's avatar
朔-望 已提交
65
// Base case specialization
朔-望's avatar
朔-望 已提交
66 67
template <>
struct Dim<0> {
68
  static constexpr int dimensions = 0;
朔-望's avatar
朔-望 已提交
69

70
  Dim(int64_t _head) {}
朔-望's avatar
朔-望 已提交
71

72
  Dim() {}
朔-望's avatar
朔-望 已提交
73

74
  Dim(int idx, const Dim<0> &size) {
朔-望's avatar
朔-望 已提交
75
#ifndef __CUDA_ARCH__
76 77 78
    if (idx > 0) {
      throw std::invalid_argument("Index out of range.");
    }
朔-望's avatar
朔-望 已提交
79
#else
80
    PADDLE_ASSERT(idx == 0);
朔-望's avatar
朔-望 已提交
81
#endif
82
  }
朔-望's avatar
朔-望 已提交
83

84
  bool operator==(const Dim<0> &o) const { return true; }
朔-望's avatar
朔-望 已提交
85

86
  bool operator!=(const Dim<0> &o) const { return false; }
朔-望's avatar
朔-望 已提交
87

88
  int64_t &operator[](int idx);
L
liuruilong 已提交
89

90
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
91 92 93 94 95
};

namespace {

// Helper for accessing Dim classes
朔-望's avatar
朔-望 已提交
96 97
template <int i>
struct DimGetter {
98
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
99
  template <typename D>
L
liuruilong 已提交
100
  static int64_t impl(const D &d) {
101 102 103
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
104
  template <typename D>
L
liuruilong 已提交
105
  static int64_t &impl(D &d) {
106 107
    return DimGetter<i - 1>::impl(d.tail);
  }
朔-望's avatar
朔-望 已提交
108 109 110
};

// Eureka! We found the element!
朔-望's avatar
朔-望 已提交
111 112
template <>
struct DimGetter<0> {
113
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
114
  template <typename D>
L
liuruilong 已提交
115
  static int64_t impl(const D &d) {
116 117 118
    return d.head;
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
119
  template <typename D>
L
liuruilong 已提交
120
  static int64_t &impl(D &d) {
朔-望's avatar
朔-望 已提交
121 122
    return d.head;
  }
朔-望's avatar
朔-望 已提交
123 124
};

朔-望's avatar
朔-望 已提交
125
template <int D>
L
liuruilong 已提交
126
int64_t &indexer(Dim<D> &dim, int idx) {
朔-望's avatar
朔-望 已提交
127
#ifndef __CUDA_ARCH__
128 129 130
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
朔-望's avatar
朔-望 已提交
131
#else
132
  PADDLE_ASSERT(idx >= 0);
朔-望's avatar
朔-望 已提交
133
#endif
134 135 136 137
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
138
}
朔-望's avatar
朔-望 已提交
139

朔-望's avatar
朔-望 已提交
140
template <>
L
liuruilong 已提交
141
int64_t &indexer<0>(Dim<0> &dim, int idx) {
朔-望's avatar
朔-望 已提交
142
#ifndef __CUDA_ARCH__
143
  throw std::invalid_argument("Invalid index");
朔-望's avatar
朔-望 已提交
144
#else
145
  PADDLE_ASSERT(false);
朔-望's avatar
朔-望 已提交
146
#if CUDA_VERSION < 8000
147 148 149
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
朔-望's avatar
朔-望 已提交
150
#else
151
  static int64_t head = 0;
朔-望's avatar
朔-望 已提交
152
#endif
153
  return head;
朔-望's avatar
朔-望 已提交
154
#endif
朔-望's avatar
朔-望 已提交
155
}
朔-望's avatar
朔-望 已提交
156

朔-望's avatar
朔-望 已提交
157
template <int D>
L
liuruilong 已提交
158
int64_t indexer(const Dim<D> &dim, int idx) {
朔-望's avatar
朔-望 已提交
159
#ifndef __CUDA_ARCH__
160 161 162
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
朔-望's avatar
朔-望 已提交
163
#else
164
  PADDLE_ASSERT(idx >= 0);
朔-望's avatar
朔-望 已提交
165
#endif
166 167 168 169
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
170 171
}

朔-望's avatar
朔-望 已提交
172
template <>
L
liuruilong 已提交
173
int64_t indexer<0>(const Dim<0> &dim, int idx) {
朔-望's avatar
朔-望 已提交
174
#ifndef __CUDA_ARCH__
175
  throw std::invalid_argument("Invalid index");
朔-望's avatar
朔-望 已提交
176
#else
177
  PADDLE_ASSERT(false);
朔-望's avatar
朔-望 已提交
178
#if CUDA_VERSION < 8000
179 180 181
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
朔-望's avatar
朔-望 已提交
182
#else
183
  static int64_t head = 0;
朔-望's avatar
朔-望 已提交
184
#endif
185
  return head;
朔-望's avatar
朔-望 已提交
186
#endif
朔-望's avatar
朔-望 已提交
187 188
}

朔-望's avatar
朔-望 已提交
189
}  // namespace
朔-望's avatar
朔-望 已提交
190
// Static access to constant Dim
朔-望's avatar
朔-望 已提交
191
template <int i, int l>
L
liuruilong 已提交
192
int64_t get(const Dim<l> &d) {
193
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
194 195 196
}

// Static access to mutable Dim
朔-望's avatar
朔-望 已提交
197
template <int i, int l>
L
liuruilong 已提交
198
int64_t &get(Dim<l> &d) {
199
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
200 201 202
}

// Dynamic access to constant Dim
朔-望's avatar
朔-望 已提交
203
template <int l>
L
liuruilong 已提交
204
int64_t Dim<l>::operator[](int i) const {
205 206
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
207 208 209
}

// Dynamic access to mutable Dim
朔-望's avatar
朔-望 已提交
210
template <int l>
L
liuruilong 已提交
211
int64_t &Dim<l>::operator[](int i) {
212
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
213 214 215
}

// Dynamic access to constant Dim
L
liuruilong 已提交
216
inline int64_t Dim<0>::operator[](int i) const {
217
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
218 219 220
}

// Dynamic access to mutable Dim
L
liuruilong 已提交
221
inline int64_t &Dim<0>::operator[](int i) {
222
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
223 224 225 226 227
}

// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
L
liuruilong 已提交
228
typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l> &d,
朔-望's avatar
朔-望 已提交
229
                                                               int i) {
230
  return d[i];
朔-望's avatar
朔-望 已提交
231 232 233 234
}

// Dynamic access to mutable Dim
template <int l>
L
liuruilong 已提交
235
typename std::enable_if<(l > 0), int64_t &>::type get(Dim<l> &d,
朔-望's avatar
朔-望 已提交
236
                                                                 int i) {
237
  return d[i];
朔-望's avatar
朔-望 已提交
238 239 240 241
}

// Dot product of two dims
template <int i>
L
liuruilong 已提交
242
int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
243
  return a.head * b.head + linearize(a.tail, b.tail);
朔-望's avatar
朔-望 已提交
244 245 246 247 248
}

// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
249
inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
250
  return 0;
朔-望's avatar
朔-望 已提交
251 252 253
}

// Product of a Dim
朔-望's avatar
朔-望 已提交
254
template <int i>
L
liuruilong 已提交
255
int64_t product(const Dim<i> &a, int prod = 1) {
256
  return prod * a.head * product(a.tail);
朔-望's avatar
朔-望 已提交
257 258 259 260
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
261
template <>
L
liuruilong 已提交
262
inline int64_t product(const Dim<0> &a, int prod) {
263
  return prod;
朔-望's avatar
朔-望 已提交
264 265 266 267
}

// Is 0 <= idx_i < size_i for all i?
template <int i>
L
liuruilong 已提交
268
bool contained(const Dim<i> &idx, const Dim<i> &size) {
269 270
  return ((0 <= idx.head) && (idx.head < size.head) &&
          contained(idx.tail, size.tail));
朔-望's avatar
朔-望 已提交
271 272 273 274 275
}

// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
276
inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
277
  return true;
朔-望's avatar
朔-望 已提交
278 279 280 281 282 283
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 */
template <int i>
L
liuruilong 已提交
284
Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
285
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
朔-望's avatar
朔-望 已提交
286 287 288 289 290
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
291
template <>
L
liuruilong 已提交
292
inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
293
  return Dim<0>();
朔-望's avatar
朔-望 已提交
294 295 296 297 298 299
}
///\endcond

/**
 * Add two dimensions together
 */
朔-望's avatar
朔-望 已提交
300
template <int i>
L
liuruilong 已提交
301
Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
302
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
303 304 305 306
}

// Base case
template <>
L
liuruilong 已提交
307
inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) {
308
  return Dim<0>();
朔-望's avatar
朔-望 已提交
309 310 311
}

template <int i>
L
liuruilong 已提交
312
Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
313
  return dim_plus(lhs, rhs);
朔-望's avatar
朔-望 已提交
314 315 316 317 318
}

/**
 * Multiply two dimensions together
 */
朔-望's avatar
朔-望 已提交
319
template <int i>
L
liuruilong 已提交
320
Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
321
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
322 323 324 325
}

// Base case
template <>
L
liuruilong 已提交
326
inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) {
327
  return Dim<0>();
朔-望's avatar
朔-望 已提交
328 329 330
}

template <int i>
L
liuruilong 已提交
331
Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
332
  return dim_mult(lhs, rhs);
朔-望's avatar
朔-望 已提交
333 334 335 336 337 338 339 340 341 342 343 344 345
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 */

template <int i>
L
liuruilong 已提交
346
Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
347 348
  int norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
朔-望's avatar
朔-望 已提交
349 350 351 352 353
}

///\cond HIDDEN

template <>
L
liuruilong 已提交
354
inline Dim<0> normalize_strides(const Dim<0> &size,
朔-望's avatar
朔-望 已提交
355
                                           const Dim<0> &stride) {
356
  return Dim<0>();
朔-望's avatar
朔-望 已提交
357 358 359 360 361 362 363 364 365 366 367 368 369
}

///\endcond

/**
 * Helper function to create a Dim
 *
 * \param idxes The type of Dim constructed depends on the number of
 * params
 *
 */

template <typename... Args>
L
liuruilong 已提交
370
Dim<sizeof...(Args)> make_dim(Args... idxes) {
371
  return Dim<sizeof...(Args)>(idxes...);
朔-望's avatar
朔-望 已提交
372 373 374 375 376
}

// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
朔-望's avatar
朔-望 已提交
377 378
typename std::enable_if<(i > 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
379 380
  os << d.head << ", " << d.tail;
  return os;
朔-望's avatar
朔-望 已提交
381 382 383 384 385
}

// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
朔-望's avatar
朔-望 已提交
386 387
typename std::enable_if<(i == 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
388 389
  os << d.head;
  return os;
朔-望's avatar
朔-望 已提交
390 391 392
}

inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) {
393
  return os;
朔-望's avatar
朔-望 已提交
394 395
}

朔-望's avatar
朔-望 已提交
396
template <int i>
L
liuruilong 已提交
397
std::string Dim<i>::to_string() const {
398
  std::stringstream stream;
朔-望's avatar
朔-望 已提交
399

400
  stream << *this;
朔-望's avatar
朔-望 已提交
401

402
  return stream.str();
朔-望's avatar
朔-望 已提交
403 404 405
}

template <int D>
L
liuruilong 已提交
406
Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
407
  Dim<D> result;
朔-望's avatar
朔-望 已提交
408

409 410 411 412
  for (int i = 0; i < D - 1; ++i) {
    result[i] = linear_index % extents[i];
    linear_index /= extents[i];
  }
朔-望's avatar
朔-望 已提交
413

414
  result[D - 1] = linear_index;
朔-望's avatar
朔-望 已提交
415

416
  return result;
朔-望's avatar
朔-望 已提交
417 418
}

朔-望's avatar
朔-望 已提交
419 420
}  // namespace framework
}  // namespace paddle_mobile