dim.h 10.2 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

朔-望's avatar
朔-望 已提交
15 16 17 18 19 20 21 22 23 24
#pragma once

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <type_traits>

#include "platform/hostdevice.h"

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
25
namespace framework {
朔-望's avatar
朔-望 已提交
26

朔-望's avatar
朔-望 已提交
27
// Statically sized, statically indexed dimension
朔-望's avatar
朔-望 已提交
28 29
template <int i>
struct Dim {
30
  static constexpr int dimensions = i;
朔-望's avatar
朔-望 已提交
31

32 33 34 35 36
  template <typename... Args>
  HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
朔-望's avatar
朔-望 已提交
37

38 39
  HOSTDEVICE
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}
朔-望's avatar
朔-望 已提交
40

41 42
  HOSTDEVICE
  Dim() : head(0), tail() {}
朔-望's avatar
朔-望 已提交
43

44 45 46 47 48 49
  /** Construct a Dim from a linear index and size.  Uses Fortran
   * order
   * indexing. */
  HOSTDEVICE
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
朔-望's avatar
朔-望 已提交
50

51 52 53
  /** Construct a Dim with each dimension set to the given index */
  HOSTDEVICE
  Dim(int64_t idx) : head(idx), tail(idx) {}
朔-望's avatar
朔-望 已提交
54

55 56 57 58
  HOSTDEVICE
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }
朔-望's avatar
朔-望 已提交
59

60 61
  HOSTDEVICE
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }
朔-望's avatar
朔-望 已提交
62

63 64 65 66
  HOSTDEVICE
  int64_t &operator[](int idx);
  HOSTDEVICE
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
67

68
  HOST std::string to_string() const;
朔-望's avatar
朔-望 已提交
69

70 71
  int64_t head;
  Dim<i - 1> tail;
朔-望's avatar
朔-望 已提交
72
};
朔-望's avatar
朔-望 已提交
73

朔-望's avatar
朔-望 已提交
74
// Base case specialization
朔-望's avatar
朔-望 已提交
75 76
template <>
struct Dim<0> {
77
  static constexpr int dimensions = 0;
朔-望's avatar
朔-望 已提交
78

79 80
  HOSTDEVICE
  Dim(int64_t _head) {}
朔-望's avatar
朔-望 已提交
81

82 83
  HOSTDEVICE
  Dim() {}
朔-望's avatar
朔-望 已提交
84

85 86
  HOSTDEVICE
  Dim(int idx, const Dim<0> &size) {
朔-望's avatar
朔-望 已提交
87
#ifndef __CUDA_ARCH__
88 89 90
    if (idx > 0) {
      throw std::invalid_argument("Index out of range.");
    }
朔-望's avatar
朔-望 已提交
91
#else
92
    PADDLE_ASSERT(idx == 0);
朔-望's avatar
朔-望 已提交
93
#endif
94
  }
朔-望's avatar
朔-望 已提交
95

96 97
  HOSTDEVICE
  bool operator==(const Dim<0> &o) const { return true; }
朔-望's avatar
朔-望 已提交
98

99 100
  HOSTDEVICE
  bool operator!=(const Dim<0> &o) const { return false; }
朔-望's avatar
朔-望 已提交
101

102 103 104 105
  HOSTDEVICE
  int64_t &operator[](int idx);
  HOSTDEVICE
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
106 107 108 109 110
};

namespace {

// Helper for accessing Dim classes
朔-望's avatar
朔-望 已提交
111 112
template <int i>
struct DimGetter {
113
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
114 115
  template <typename D>
  HOSTDEVICE static int64_t impl(const D &d) {
116 117 118
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
119 120
  template <typename D>
  HOSTDEVICE static int64_t &impl(D &d) {
121 122
    return DimGetter<i - 1>::impl(d.tail);
  }
朔-望's avatar
朔-望 已提交
123 124 125
};

// Eureka! We found the element!
朔-望's avatar
朔-望 已提交
126 127
template <>
struct DimGetter<0> {
128
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
129 130
  template <typename D>
  HOSTDEVICE static int64_t impl(const D &d) {
131 132 133
    return d.head;
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
134 135 136 137
  template <typename D>
  HOSTDEVICE static int64_t &impl(D &d) {
    return d.head;
  }
朔-望's avatar
朔-望 已提交
138 139
};

朔-望's avatar
朔-望 已提交
140 141
// Runtime access to entry `idx` of a mutable Dim<D>, returned by reference.
// Each step peels off one head. An index past the last entry reaches the
// Dim<0> specialization, which reports the error. A negative index throws
// std::invalid_argument on host and trips PADDLE_ASSERT on device.
template <int D>
HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
#else
  PADDLE_ASSERT(idx >= 0);
#endif
  return (idx == 0) ? dim.head : indexer(dim.tail, idx - 1);
}
朔-望's avatar
朔-望 已提交
154

朔-望's avatar
朔-望 已提交
155 156
// Terminal case for mutable access. A Dim<0> has no entries, so reaching it
// means the requested index was not smaller than the rank of the original Dim.
// On host this throws std::invalid_argument("Invalid index"). On device,
// PADDLE_ASSERT(false) fires, and a dummy int64_t exists only so the function
// has something to return.
// NOTE(review): for CUDA_VERSION < 8000 `head` is an ordinary local, so the
// returned reference dangles. This matters only if PADDLE_ASSERT(false) does
// not stop execution; confirm PADDLE_ASSERT's behavior.
template <>
HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
#ifdef __CUDA_ARCH__
  PADDLE_ASSERT(false);
#if CUDA_VERSION < 8000
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
#else
  static int64_t head = 0;
#endif
  return head;
#else
  throw std::invalid_argument("Invalid index");
#endif
}
朔-望's avatar
朔-望 已提交
171

朔-望's avatar
朔-望 已提交
172 173
// Runtime access to entry `idx` of a const Dim<D>, returned by value.
// Each step peels off one head. An index past the last entry reaches the
// Dim<0> specialization, which reports the error. A negative index throws
// std::invalid_argument on host and trips PADDLE_ASSERT on device.
template <int D>
HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
#else
  PADDLE_ASSERT(idx >= 0);
#endif
  return (idx == 0) ? dim.head : indexer(dim.tail, idx - 1);
}

朔-望's avatar
朔-望 已提交
187 188
// Terminal case for const access. A Dim<0> has no entries, so reaching it
// means the requested index was not smaller than the rank of the original Dim.
// On host this throws std::invalid_argument("Invalid index"). On device,
// PADDLE_ASSERT(false) fires and 0 is returned as a placeholder.
// This overload returns by value, so it does not need a (static) local to
// return. The previous version kept one that was always 0; returning the
// literal is equivalent and removes the CUDA-version fork.
template <>
HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
#ifndef __CUDA_ARCH__
  throw std::invalid_argument("Invalid index");
#else
  PADDLE_ASSERT(false);
  return 0;
#endif
}

朔-望's avatar
朔-望 已提交
204
}  // namespace
朔-望's avatar
朔-望 已提交
205
// Static access to constant Dim
朔-望's avatar
朔-望 已提交
206 207
template <int i, int l>
HOSTDEVICE int64_t get(const Dim<l> &d) {
208
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
209 210 211
}

// Static access to mutable Dim
朔-望's avatar
朔-望 已提交
212 213
template <int i, int l>
HOSTDEVICE int64_t &get(Dim<l> &d) {
214
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
215 216 217
}

// Dynamic access to constant Dim
朔-望's avatar
朔-望 已提交
218 219
template <int l>
HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
220 221
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
222 223 224
}

// Dynamic access to mutable Dim
朔-望's avatar
朔-望 已提交
225 226
template <int l>
HOSTDEVICE int64_t &Dim<l>::operator[](int i) {
227
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
228 229 230 231
}

// Dynamic access to constant Dim<0>. A rank-0 Dim has no entries, so this
// always ends in the indexer<0> error path.
inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
  const Dim<0> &self = *this;
  return indexer(self, i);
}

// Dynamic access to mutable Dim<0>. A rank-0 Dim has no entries, so this
// always ends in the indexer<0> error path.
inline HOSTDEVICE int64_t &Dim<0>::operator[](int i) {
  Dim<0> &self = *this;
  return indexer(self, i);
}

// Dynamic access to constant Dim: get(d, i) returns a copy of entry i.
// The std::enable_if constraint keeps the compiler from trying to instantiate
// this overload on get<0>(d).
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l> &d,
                                                               int i) {
  const int64_t entry = d[i];
  return entry;
}

// Dynamic access to mutable Dim: get(d, i) returns a reference to entry i.
// Constrained to l > 0 for the same reason as the const overload.
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t &>::type get(Dim<l> &d,
                                                                 int i) {
  int64_t &entry = d[i];
  return entry;
}

// Dot product of two Dims of equal rank: sum over k of a[k] * b[k].
template <int i>
HOSTDEVICE int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
  const int64_t rest = linearize(a.tail, b.tail);
  return a.head * b.head + rest;
}

// Base case dot product of two Dims: the empty sum is 0.
// Declared inline because a full specialization is an ordinary (non-template)
// function, and this definition lives in a header.
template <>
HOSTDEVICE inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
  return static_cast<int64_t>(0);
}

// Product of a Dim
朔-望's avatar
朔-望 已提交
269 270
template <int i>
HOSTDEVICE int64_t product(const Dim<i> &a, int prod = 1) {
271
  return prod * a.head * product(a.tail);
朔-望's avatar
朔-望 已提交
272 273 274 275
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
276 277
template <>
HOSTDEVICE inline int64_t product(const Dim<0> &a, int prod) {
278
  return prod;
朔-望's avatar
朔-望 已提交
279 280 281 282 283
}

// Is 0 <= idx[k] < size[k] for every k?
template <int i>
HOSTDEVICE bool contained(const Dim<i> &idx, const Dim<i> &size) {
  if (idx.head < 0 || idx.head >= size.head) {
    return false;
  }
  return contained(idx.tail, size.tail);
}

// Base case of the bounds check: a rank-0 index lies inside a rank-0 extent.
// Declared inline because a full specialization is an ordinary (non-template)
// function, and this definition lives in a header.
template <>
HOSTDEVICE inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
  const bool in_bounds = true;
  return in_bounds;
}

// Base case of the int64_t prefix-product recursion: no entries remain.
HOSTDEVICE inline Dim<0> ex_prefix_mul_impl(const Dim<0> &, int64_t) {
  return Dim<0>();
}

// Stores the running product in head, then recurses on tail with the product
// extended by this level's entry. The running product is kept in int64_t.
template <int i>
HOSTDEVICE Dim<i> ex_prefix_mul_impl(const Dim<i> &src, int64_t mul) {
  return Dim<i>(mul, ex_prefix_mul_impl(src.tail, mul * src.head));
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 *
 * Entry k of the result is mul * src[0] * ... * src[k - 1]; entry 0 is `mul`
 * itself. With the default mul = 1 and `src` holding extents, these are the
 * strides of an array whose dimension 0 varies fastest.
 *
 * The running product is carried as int64_t. The previous version recursed
 * through the `int mul` parameter, which truncated any partial product that
 * did not fit in int.
 *
 * \param src Dim whose entries are multiplied together.
 * \param mul Initial value of the running product.
 * \return Dim of the same rank holding the exclusive prefix products.
 */
template <int i>
HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
  return ex_prefix_mul_impl(src, static_cast<int64_t>(mul));
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
306 307
template <>
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
308
  return Dim<0>();
朔-望's avatar
朔-望 已提交
309 310 311 312 313 314
}
///\endcond

/**
 * Add two dimensions together
 */
朔-望's avatar
朔-望 已提交
315 316
template <int i>
HOSTDEVICE Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
317
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
318 319 320 321 322
}

// Base case of dim_plus: the sum of two rank-0 Dims is an empty Dim.
template <>
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) {
  Dim<0> empty;
  return empty;
}

// Entry-wise sum; `lhs + rhs` is shorthand for dim_plus(lhs, rhs).
template <int i>
HOSTDEVICE Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
  Dim<i> sum = dim_plus(lhs, rhs);
  return sum;
}

/**
 * Multiply two dimensions together
 */
朔-望's avatar
朔-望 已提交
334 335
template <int i>
HOSTDEVICE Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
336
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
337 338 339 340 341
}

// Base case of dim_mult: the product of two rank-0 Dims is an empty Dim.
template <>
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) {
  Dim<0> empty;
  return empty;
}

// Entry-wise product; `lhs * rhs` is shorthand for dim_mult(lhs, rhs).
template <int i>
HOSTDEVICE Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
  Dim<i> prod = dim_mult(lhs, rhs);
  return prod;
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 * The per-dimension stride is held in int64_t. It previously went through an
 * `int` local, which truncated strides that do not fit in int.
 */
template <int i>
HOSTDEVICE Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
  // An extent-1 dimension only ever has coordinate 0, so its stride never
  // contributes to an offset; canonicalize it to 0.
  const int64_t norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
}

///\cond HIDDEN

// Base case of normalize_strides: nothing to normalize for a rank-0 Dim.
template <>
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0> &size,
                                           const Dim<0> &stride) {
  Dim<0> empty;
  return empty;
}

///\endcond

/**
 * Helper function to create a Dim whose rank equals the number of arguments,
 * e.g. make_dim(2, 3, 4) builds a Dim<3> holding 2, 3, 4.
 *
 * \param idxes The entries of the Dim, in order. The type of Dim constructed
 * depends on the number of params.
 */
template <typename... Args>
HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
  typedef Dim<sizeof...(Args)> Result;
  return Result(idxes...);
}

// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
朔-望's avatar
朔-望 已提交
392 393
typename std::enable_if<(i > 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
394 395
  os << d.head << ", " << d.tail;
  return os;
朔-望's avatar
朔-望 已提交
396 397 398 399 400
}

// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
朔-望's avatar
朔-望 已提交
401 402
typename std::enable_if<(i == 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
403 404
  os << d.head;
  return os;
朔-望's avatar
朔-望 已提交
405 406 407
}

// Streaming a rank-0 Dim writes nothing; the stream is returned unchanged so
// that chained insertions keep working.
inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) {
  std::ostream &unchanged = os;
  return unchanged;
}

朔-望's avatar
朔-望 已提交
411 412
// Render this Dim to a string by streaming it through its operator<< overload.
template <int i>
HOST std::string Dim<i>::to_string() const {
  std::ostringstream out;
  out << *this;
  return out.str();
}

/**
 * Convert a linear index into per-dimension coordinates for an array with the
 * given extents, in Fortran order (dimension 0 varies fastest).
 *
 * The last coordinate receives whatever remains after the first D - 1 extents
 * have been divided out. extents[D - 1] is therefore never read, and an
 * out-of-range linear index is not detected here.
 *
 * \param linear_index Linear offset to decompose. It is int64_t (previously
 *        int) so offsets into arrays with more than INT_MAX elements are not
 *        truncated; existing int arguments still convert implicitly.
 * \param extents Size of each dimension. Each of the first D - 1 extents is
 *        used as a divisor and must be non-zero.
 * \return Dim holding the coordinates.
 */
template <int D>
HOSTDEVICE Dim<D> linear_to_dimension(int64_t linear_index, Dim<D> extents) {
  Dim<D> result;

  for (int i = 0; i < D - 1; ++i) {
    result[i] = linear_index % extents[i];
    linear_index /= extents[i];
  }

  result[D - 1] = linear_index;

  return result;
}

朔-望's avatar
朔-望 已提交
434 435
}  // namespace framework
}  // namespace paddle_mobile