//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <type_traits>

#include "platform/hostdevice.h"

namespace paddle_mobile {
namespace framework {

// Statically sized, statically indexed dimension
朔-望's avatar
朔-望 已提交
27 28
template <int i>
struct Dim {
29
  static constexpr int dimensions = i;
朔-望's avatar
朔-望 已提交
30

31 32 33 34 35
  template <typename... Args>
  HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
朔-望's avatar
朔-望 已提交
36

37 38
  HOSTDEVICE
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}
朔-望's avatar
朔-望 已提交
39

40 41
  HOSTDEVICE
  Dim() : head(0), tail() {}
朔-望's avatar
朔-望 已提交
42

43 44 45 46 47 48
  /** Construct a Dim from a linear index and size.  Uses Fortran
   * order
   * indexing. */
  HOSTDEVICE
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
朔-望's avatar
朔-望 已提交
49

50 51 52
  /** Construct a Dim with each dimension set to the given index */
  HOSTDEVICE
  Dim(int64_t idx) : head(idx), tail(idx) {}
朔-望's avatar
朔-望 已提交
53

54 55 56 57
  HOSTDEVICE
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }
朔-望's avatar
朔-望 已提交
58

59 60
  HOSTDEVICE
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }
朔-望's avatar
朔-望 已提交
61

62 63 64 65
  HOSTDEVICE
  int64_t &operator[](int idx);
  HOSTDEVICE
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
66

67
  HOST std::string to_string() const;
朔-望's avatar
朔-望 已提交
68

69 70
  int64_t head;
  Dim<i - 1> tail;
朔-望's avatar
朔-望 已提交
71
};
朔-望's avatar
朔-望 已提交
72

朔-望's avatar
朔-望 已提交
73
// Base case specialization
朔-望's avatar
朔-望 已提交
74 75
template <>
struct Dim<0> {
76
  static constexpr int dimensions = 0;
朔-望's avatar
朔-望 已提交
77

78 79
  HOSTDEVICE
  Dim(int64_t _head) {}
朔-望's avatar
朔-望 已提交
80

81 82
  HOSTDEVICE
  Dim() {}
朔-望's avatar
朔-望 已提交
83

84 85
  HOSTDEVICE
  Dim(int idx, const Dim<0> &size) {
朔-望's avatar
朔-望 已提交
86
#ifndef __CUDA_ARCH__
87 88 89
    if (idx > 0) {
      throw std::invalid_argument("Index out of range.");
    }
朔-望's avatar
朔-望 已提交
90
#else
91
    PADDLE_ASSERT(idx == 0);
朔-望's avatar
朔-望 已提交
92
#endif
93
  }
朔-望's avatar
朔-望 已提交
94

95 96
  HOSTDEVICE
  bool operator==(const Dim<0> &o) const { return true; }
朔-望's avatar
朔-望 已提交
97

98 99
  HOSTDEVICE
  bool operator!=(const Dim<0> &o) const { return false; }
朔-望's avatar
朔-望 已提交
100

101 102 103 104
  HOSTDEVICE
  int64_t &operator[](int idx);
  HOSTDEVICE
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
105 106 107 108 109
};

namespace {

// Helper for accessing Dim classes
朔-望's avatar
朔-望 已提交
110 111
template <int i>
struct DimGetter {
112
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
113 114
  template <typename D>
  HOSTDEVICE static int64_t impl(const D &d) {
115 116 117
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
118 119
  template <typename D>
  HOSTDEVICE static int64_t &impl(D &d) {
120 121
    return DimGetter<i - 1>::impl(d.tail);
  }
朔-望's avatar
朔-望 已提交
122 123 124
};

// Eureka! We found the element!
朔-望's avatar
朔-望 已提交
125 126
template <>
struct DimGetter<0> {
127
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
128 129
  template <typename D>
  HOSTDEVICE static int64_t impl(const D &d) {
130 131 132
    return d.head;
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
133 134 135 136
  template <typename D>
  HOSTDEVICE static int64_t &impl(D &d) {
    return d.head;
  }
朔-望's avatar
朔-望 已提交
137 138
};

朔-望's avatar
朔-望 已提交
139 140
// Runtime element access for a mutable Dim: descends `idx` levels through
// the head/tail chain and returns a reference to the value found there.
// Negative indices are rejected at each level; an index >= D runs off the
// end and is rejected by the Dim<0> specialization below.
template <int D>
HOSTDEVICE int64_t &indexer(Dim<D> &dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
#else
  PADDLE_ASSERT(idx >= 0);
#endif
  return idx == 0 ? dim.head : indexer(dim.tail, idx - 1);
}

// A zero-length Dim has no elements: reached when indexing a Dim<0>
// directly, or when the recursive indexer is asked for idx >= rank.
template <>
HOSTDEVICE int64_t &indexer<0>(Dim<0> &dim, int idx) {
#ifndef __CUDA_ARCH__
  throw std::invalid_argument("Invalid index");
#else
  PADDLE_ASSERT(false);
#if CUDA_VERSION < 8000
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  // NOTE(review): this returns a reference to a local, which dangles if
  // PADDLE_ASSERT does not stop execution — confirm it always traps.
  int64_t head = 0;
#else
  static int64_t head = 0;
#endif
  return head;
#endif
}

// Runtime element access for a const Dim: descends `idx` levels through the
// head/tail chain and returns a copy of the value found there. Negative
// indices are rejected at each level; an index >= D is rejected by the
// Dim<0> specialization below.
template <int D>
HOSTDEVICE int64_t indexer(const Dim<D> &dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
#else
  PADDLE_ASSERT(idx >= 0);
#endif
  return idx == 0 ? dim.head : indexer(dim.tail, idx - 1);
}

// A zero-length Dim has no elements: reached when indexing a const Dim<0>
// directly, or when the recursive indexer is asked for idx >= rank.
template <>
HOSTDEVICE int64_t indexer<0>(const Dim<0> &dim, int idx) {
#ifndef __CUDA_ARCH__
  throw std::invalid_argument("Invalid index");
#else
  PADDLE_ASSERT(false);
#if CUDA_VERSION < 8000
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
#else
  static int64_t head = 0;
#endif
  // Returned by value, so the local above is safe here.
  return head;
#endif
}

}  // namespace
// Static access to constant Dim
朔-望's avatar
朔-望 已提交
205 206
template <int i, int l>
HOSTDEVICE int64_t get(const Dim<l> &d) {
207
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
208 209 210
}

// Static access to mutable Dim
朔-望's avatar
朔-望 已提交
211 212
template <int i, int l>
HOSTDEVICE int64_t &get(Dim<l> &d) {
213
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
214 215 216
}

// Dynamic access to constant Dim
朔-望's avatar
朔-望 已提交
217 218
template <int l>
HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
219 220
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
221 222 223
}

// Dynamic access to mutable Dim
朔-望's avatar
朔-望 已提交
224 225
template <int l>
HOSTDEVICE int64_t &Dim<l>::operator[](int i) {
226
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
227 228 229 230
}

// Dynamic access to constant Dim<0>. A zero-length Dim has no elements, so
// indexer() rejects every index.
inline HOSTDEVICE int64_t Dim<0>::operator[](int idx) const {
  return indexer(*this, idx);
}

// Dynamic access to mutable Dim<0>; every index is likewise rejected.
inline HOSTDEVICE int64_t &Dim<0>::operator[](int idx) {
  return indexer(*this, idx);
}

// Dynamic access to constant Dim: get(d, i) is the same as d[i].
// The enable_if limits this overload to l > 0; per the original note,
// without it the compiler tries to instantiate it on get<0>(d).
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(
    const Dim<l> &d, int i) {
  return d.operator[](i);
}

// Dynamic access to mutable Dim: returns a reference to element i of d.
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t &>::type get(
    Dim<l> &d, int i) {
  return d.operator[](i);
}

// Dot product of two Dims of equal rank: the sum of a[k] * b[k] over all k.
template <int i>
HOSTDEVICE int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
  const int64_t rest = linearize(a.tail, b.tail);
  return a.head * b.head + rest;
}

// Base case: the dot product of two empty Dims is 0.
// Marked inline because a full specialization is no longer a template and
// is defined in a header.
template <>
HOSTDEVICE inline int64_t linearize(const Dim<0> &, const Dim<0> &) {
  return 0;
}

// Product of a Dim
朔-望's avatar
朔-望 已提交
268 269
template <int i>
HOSTDEVICE int64_t product(const Dim<i> &a, int prod = 1) {
270
  return prod * a.head * product(a.tail);
朔-望's avatar
朔-望 已提交
271 272 273 274
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
275 276
template <>
HOSTDEVICE inline int64_t product(const Dim<0> &a, int prod) {
277
  return prod;
朔-望's avatar
朔-望 已提交
278 279 280 281 282
}

// True iff 0 <= idx[k] < size[k] holds for every dimension k.
template <int i>
HOSTDEVICE bool contained(const Dim<i> &idx, const Dim<i> &size) {
  if (idx.head < 0 || idx.head >= size.head) {
    return false;
  }
  return contained(idx.tail, size.tail);
}

// Base case: an empty index is trivially inside an empty size.
// Marked inline because a full specialization is no longer a template and
// is defined in a header.
template <>
HOSTDEVICE inline bool contained(const Dim<0> &, const Dim<0> &) {
  return true;
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 *
 * Element k of the result is `mul` times the product of src[0..k-1]; element
 * 0 is `mul` itself. For example, src = (2, 3, 4) with mul = 1 gives
 * (1, 2, 6).
 *
 * \param src Dim whose values are multiplied.
 * \param mul Starting multiplier. It is int64_t because each recursive call
 *            receives `mul * src.head`, an int64_t; the previous `int`
 *            parameter truncated running products larger than INT_MAX.
 * \return Dim of the same rank holding the exclusive prefix products.
 */
template <int i>
HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i> &src, int64_t mul = 1) {
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
}

///\cond HIDDEN
// Base case of ex_prefix_mul: nothing left to multiply.
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0> & /*src*/,
                                       int64_t /*mul*/) {
  return Dim<0>();
}
///\endcond

/**
 * Add two dimensions together, value by value.
 */
template <int i>
HOSTDEVICE Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
  const int64_t sum = a.head + b.head;
  return Dim<i>(sum, dim_plus(a.tail, b.tail));
}

// Base case: the sum of two empty Dims is empty.
template <>
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0> &, const Dim<0> &) {
  return Dim<0>();
}

// Element-wise addition operator; forwards to dim_plus.
template <int i>
HOSTDEVICE Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
  return dim_plus(lhs, rhs);
}

/**
 * Multiply two dimensions together, value by value.
 */
template <int i>
HOSTDEVICE Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
  const int64_t prod = a.head * b.head;
  return Dim<i>(prod, dim_mult(a.tail, b.tail));
}

// Base case: the product of two empty Dims is empty.
template <>
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0> &, const Dim<0> &) {
  return Dim<0>();
}

// Element-wise multiplication operator; forwards to dim_mult.
template <int i>
HOSTDEVICE Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
  return dim_mult(lhs, rhs);
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 */

template <int i>
HOSTDEVICE Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
  // int64_t (not int) so that large int64_t strides are not truncated.
  const int64_t norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
}

///\cond HIDDEN

// Base case: no dimensions left to normalize.
template <>
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0> &size,
                                           const Dim<0> &stride) {
  return Dim<0>();
}

///\endcond

/**
 * Helper function to create a Dim whose rank equals the number of
 * arguments, e.g. make_dim(2, 3, 4) yields a Dim<3> holding (2, 3, 4).
 *
 * \param idxes One value per dimension; each is passed to a Dim
 * constructor taking int64_t.
 */
template <typename... Args>
HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
  using Result = Dim<sizeof...(Args)>;
  return Result(idxes...);
}

// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
朔-望's avatar
朔-望 已提交
391 392
typename std::enable_if<(i > 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
393 394
  os << d.head << ", " << d.tail;
  return os;
朔-望's avatar
朔-望 已提交
395 396 397 398 399
}

// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
朔-望's avatar
朔-望 已提交
400 401
typename std::enable_if<(i == 1), std::ostream &>::type operator<<(
    std::ostream &os, const Dim<i> &d) {
402 403
  os << d.head;
  return os;
朔-望's avatar
朔-望 已提交
404 405 406
}

inline std::ostream &operator<<(std::ostream &os, const Dim<0> &d) {
407
  return os;
朔-望's avatar
朔-望 已提交
408 409
}

朔-望's avatar
朔-望 已提交
410 411
// Renders this Dim via operator<<, e.g. "2, 3, 4" for a Dim<3>.
template <int i>
HOST std::string Dim<i>::to_string() const {
  std::ostringstream out;
  out << *this;
  return out.str();
}

// Decomposes a flat index into per-dimension coordinates, with dimension 0
// varying fastest. Coordinates 0..D-2 are reduced modulo their extents; the
// remaining quotient is stored unreduced in coordinate D-1, so
// extents[D - 1] is never read.
//
// \param linear_index Flat index to decompose. It is int64_t (widened from
//        int) to match the int64_t extents; int arguments still convert
//        losslessly.
// \param extents Size of each dimension (taken by const reference; it is
//        only read).
// \return Coordinates of linear_index within extents.
//
// NOTE(review): D == 0 is not meaningful; result[D - 1] is then rejected at
// runtime by the Dim<0> index check.
template <int D>
HOSTDEVICE Dim<D> linear_to_dimension(int64_t linear_index,
                                      const Dim<D> &extents) {
  Dim<D> result;

  for (int i = 0; i < D - 1; ++i) {
    result[i] = linear_index % extents[i];
    linear_index /= extents[i];
  }

  result[D - 1] = linear_index;

  return result;
}

}  // namespace framework
}  // namespace paddle_mobile