//  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace framework {
F
fengjiayi 已提交
26 27 28 29 30 31 32

// Statically sized, statically indexed dimension
template <int i>
struct Dim {
  static constexpr int dimensions = i;

  template <typename... Args>
Q
qijun 已提交
33
  HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
F
fengjiayi 已提交
34 35 36 37 38
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }

  HOSTDEVICE
Q
qijun 已提交
39
  Dim(int64_t _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {}
F
fengjiayi 已提交
40 41 42 43 44 45 46

  HOSTDEVICE
  Dim() : head(0), tail() {}

  /** Construct a Dim from a linear index and size.  Uses Fortran order
   * indexing. */
  HOSTDEVICE
Q
qijun 已提交
47
  Dim(int64_t idx, const Dim<i>& size)
F
fengjiayi 已提交
48 49 50 51
      : head(idx % size.head), tail(idx / size.head, size.tail) {}

  /** Construct a Dim with each dimension set to the given index */
  HOSTDEVICE
Q
qijun 已提交
52
  Dim(int64_t idx) : head(idx), tail(idx) {}
F
fengjiayi 已提交
53 54 55 56 57 58 59 60 61 62

  HOSTDEVICE
  bool operator==(const Dim<i>& o) const {
    return (head == o.head) && (tail == o.tail);
  }

  HOSTDEVICE
  bool operator!=(const Dim<i>& o) const { return !(*this == o); }

  HOSTDEVICE
Q
qijun 已提交
63
  int64_t& operator[](int idx);
F
fengjiayi 已提交
64
  HOSTDEVICE
Q
qijun 已提交
65
  int64_t operator[](int idx) const;
F
fengjiayi 已提交
66 67 68

  HOST std::string to_string() const;

Q
qijun 已提交
69
  int64_t head;
F
fengjiayi 已提交
70 71 72 73 74
  Dim<i - 1> tail;
};

// Base case specialization
template <>
X
xuwei06 已提交
75 76
struct Dim<0> {
  static constexpr int dimensions = 0;
F
fengjiayi 已提交
77 78

  HOSTDEVICE
X
xuwei06 已提交
79
  Dim(int64_t _head) {}
F
fengjiayi 已提交
80 81

  HOSTDEVICE
X
xuwei06 已提交
82
  Dim() {}
F
fengjiayi 已提交
83 84

  HOSTDEVICE
X
xuwei06 已提交
85
  Dim(int idx, const Dim<0>& size) {
F
fengjiayi 已提交
86
#ifndef __CUDA_ARCH__
X
xuwei06 已提交
87
    if (idx > 0) {
F
fengjiayi 已提交
88 89 90
      throw std::invalid_argument("Index out of range.");
    }
#else
X
xuwei06 已提交
91
    PADDLE_ASSERT(idx == 0);
F
fengjiayi 已提交
92 93 94 95
#endif
  }

  HOSTDEVICE
X
xuwei06 已提交
96
  bool operator==(const Dim<0>& o) const { return true; }
F
fengjiayi 已提交
97 98

  HOSTDEVICE
X
xuwei06 已提交
99
  bool operator!=(const Dim<0>& o) const { return false; }
F
fengjiayi 已提交
100 101

  HOSTDEVICE
Q
qijun 已提交
102
  int64_t& operator[](int idx);
F
fengjiayi 已提交
103
  HOSTDEVICE
Q
qijun 已提交
104
  int64_t operator[](int idx) const;
F
fengjiayi 已提交
105 106 107 108 109 110 111 112 113
};

namespace {

// Helper for accessing Dim classes
template <int i>
struct DimGetter {
  // Return a copy if Dim is const
  template <typename D>
Q
qijun 已提交
114
  HOSTDEVICE static int64_t impl(const D& d) {
F
fengjiayi 已提交
115 116 117 118
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
  template <typename D>
Q
qijun 已提交
119
  HOSTDEVICE static int64_t& impl(D& d) {
F
fengjiayi 已提交
120 121 122 123 124 125 126 127 128
    return DimGetter<i - 1>::impl(d.tail);
  }
};

// Eureka! We found the element!
template <>
struct DimGetter<0> {
  // Return a copy if Dim is const
  template <typename D>
Q
qijun 已提交
129
  HOSTDEVICE static int64_t impl(const D& d) {
F
fengjiayi 已提交
130 131 132 133
    return d.head;
  }
  // Return a reference if Dim is mutable
  template <typename D>
Q
qijun 已提交
134
  HOSTDEVICE static int64_t& impl(D& d) {
F
fengjiayi 已提交
135 136 137 138 139
    return d.head;
  }
};

template <int D>
Q
qijun 已提交
140
HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
F
fengjiayi 已提交
141 142 143 144 145
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
#else
146
  PADDLE_ASSERT(idx >= 0);
F
fengjiayi 已提交
147 148 149 150 151 152 153 154
#endif
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
}

template <>
X
xuwei06 已提交
155
HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) {
F
fengjiayi 已提交
156
#ifndef __CUDA_ARCH__
X
xuwei06 已提交
157
  throw std::invalid_argument("Invalid index");
F
fengjiayi 已提交
158
#else
X
xuwei06 已提交
159
  PADDLE_ASSERT(false);
F
fengjiayi 已提交
160
#endif
161 162 163 164 165
#if (defined __CUDA_ARCH__) && (CUDA_VERSION < 8000)
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
#else
X
xuwei06 已提交
166
  static int64_t head = 0;
167
#endif
X
xuwei06 已提交
168
  return head;
F
fengjiayi 已提交
169 170 171
}

template <int D>
Q
qijun 已提交
172
HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
F
fengjiayi 已提交
173 174 175 176 177
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
  }
#else
178
  PADDLE_ASSERT(idx >= 0);
F
fengjiayi 已提交
179 180 181 182 183 184 185 186
#endif
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
}

template <>
X
xuwei06 已提交
187
HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) {
F
fengjiayi 已提交
188
#ifndef __CUDA_ARCH__
X
xuwei06 已提交
189
  throw std::invalid_argument("Invalid index");
F
fengjiayi 已提交
190
#else
X
xuwei06 已提交
191
  PADDLE_ASSERT(false);
F
fengjiayi 已提交
192
#endif
193 194 195 196 197
#if (defined __CUDA_ARCH__) && (CUDA_VERSION < 8000)
  // On CUDA versions previous to 8.0, only __shared__ variables
  // could be declared as static in the device code.
  int64_t head = 0;
#else
X
xuwei06 已提交
198
  static int64_t head = 0;
199
#endif
X
xuwei06 已提交
200
  return head;
F
fengjiayi 已提交
201 202 203 204 205
}

}  // namespace
// Static access to constant Dim
template <int i, int l>
Q
qijun 已提交
206
HOSTDEVICE int64_t get(const Dim<l>& d) {
F
fengjiayi 已提交
207 208 209 210 211
  return DimGetter<i>::impl(d);
}

// Static access to mutable Dim
template <int i, int l>
Q
qijun 已提交
212
HOSTDEVICE int64_t& get(Dim<l>& d) {
F
fengjiayi 已提交
213 214 215 216 217
  return DimGetter<i>::impl(d);
}

// Dynamic access to constant Dim
template <int l>
Q
qijun 已提交
218
HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
F
fengjiayi 已提交
219 220 221 222 223
  return indexer(*this, i);
}

// Dynamic access to mutable Dim
template <int l>
Q
qijun 已提交
224
HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
F
fengjiayi 已提交
225 226 227 228
  return indexer(*this, i);
}

// Dynamic access to constant Dim
X
xuwei06 已提交
229
inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
F
fengjiayi 已提交
230 231 232 233
  return indexer(*this, i);
}

// Dynamic access to mutable Dim
X
xuwei06 已提交
234
inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) {
Q
qijun 已提交
235 236
  return indexer(*this, i);
}
F
fengjiayi 已提交
237 238 239 240

// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
Q
qijun 已提交
241 242
HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l>& d,
                                                               int i) {
F
fengjiayi 已提交
243 244 245 246 247
  return d[i];
}

// Dynamic access to mutable Dim
template <int l>
Q
qijun 已提交
248 249
HOSTDEVICE typename std::enable_if<(l > 0), int64_t&>::type get(Dim<l>& d,
                                                                int i) {
F
fengjiayi 已提交
250 251 252 253 254
  return d[i];
}

// Dot product of two dims
template <int i>
Q
qijun 已提交
255
HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
F
fengjiayi 已提交
256 257 258 259 260 261
  return a.head * b.head + linearize(a.tail, b.tail);
}

// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
X
xuwei06 已提交
262 263
HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) {
  return 0;
F
fengjiayi 已提交
264 265 266 267
}

// Product of a Dim
template <int i>
Q
qijun 已提交
268
HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
F
fengjiayi 已提交
269 270 271 272 273 274
  return prod * a.head * product(a.tail);
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
template <>
X
xuwei06 已提交
275 276
HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) {
  return prod;
F
fengjiayi 已提交
277 278 279 280 281 282 283 284 285 286 287 288
}

// Is 0 <= idx_i < size_i for all i?
// The leading position is tested first; the answer for the remaining
// positions comes from recursing on the tails.
template <int i>
HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
  if (idx.head < 0 || idx.head >= size.head) {
    return false;
  }
  return contained(idx.tail, size.tail);
}

// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
X
xuwei06 已提交
289 290
HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) {
  return true;
F
fengjiayi 已提交
291 292 293 294 295 296 297 298 299 300 301 302 303 304
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 *
 * Position k of the result holds mul times the product of the extents of
 * \p src that come before position k, so position 0 is mul itself.
 *
 * \param src Dim whose extents are multiplied
 * \param mul Running product carried into this level (defaults to 1).  It is
 *            an int64_t so that the accumulated product of int64_t extents
 *            is not truncated while it is passed down the recursion.
 * \return Dim of the same rank holding the exclusive prefix products
 */
template <int i>
HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int64_t mul = 1) {
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int64_t mul) {
  return Dim<0>();
}
///\endcond

/**
 * Add two dimensions together, position by position: the heads are summed
 * here and the tails are summed recursively.
 */
template <int i>
HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
  const int64_t head_sum = a.head + b.head;
  const Dim<i - 1> tail_sum = dim_plus(a.tail, b.tail);
  return Dim<i>(head_sum, tail_sum);
}

// Base case
template <>
X
xuwei06 已提交
320 321
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) {
  return Dim<0>();
F
fengjiayi 已提交
322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
}

// Operator form of dim_plus for two Dims of the same rank.
template <int i>
HOSTDEVICE Dim<i> operator+(const Dim<i>& lhs, const Dim<i>& rhs) {
  return dim_plus<i>(lhs, rhs);
}

/**
 * Multiply two dimensions together, position by position: the heads are
 * multiplied here and the tails are multiplied recursively.
 */
template <int i>
HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
  const int64_t head_product = a.head * b.head;
  const Dim<i - 1> tail_product = dim_mult(a.tail, b.tail);
  return Dim<i>(head_product, tail_product);
}

// Base case
template <>
X
xuwei06 已提交
339 340
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) {
  return Dim<0>();
F
fengjiayi 已提交
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
}

// Operator form of dim_mult for two Dims of the same rank.
template <int i>
HOSTDEVICE Dim<i> operator*(const Dim<i>& lhs, const Dim<i>& rhs) {
  return dim_mult<i>(lhs, rhs);
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 */

template <int i>
HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
  // Kept as int64_t, the type of Dim::head, so that a large stride is not
  // truncated before it is stored in the result.
  const int64_t norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
}

///\cond HIDDEN

template <>
X
xuwei06 已提交
367 368 369
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
                                           const Dim<0>& stride) {
  return Dim<0>();
F
fengjiayi 已提交
370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389
}

///\endcond

/**
 * Helper function to create a Dim
 *
 * \param idxes Values forwarded to the Dim constructor; the rank of the
 *              returned Dim equals the number of arguments supplied
 *
 */

template <typename... Args>
HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) {
  using Result = Dim<sizeof...(Args)>;
  return Result(idxes...);
}

// Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly
template <int i>
typename std::enable_if<(i > 1), std::ostream&>::type operator<<(
390
    std::ostream& os, const Dim<i>& d) {
F
fengjiayi 已提交
391 392 393 394 395 396 397 398
  os << d.head << ", " << d.tail;
  return os;
}

// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
399
    std::ostream& os, const Dim<i>& d) {
F
fengjiayi 已提交
400 401 402 403
  os << d.head;
  return os;
}

X
xuwei06 已提交
404 405 406 407
inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
  return os;
}

F
fengjiayi 已提交
408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
template <int i>
HOST std::string Dim<i>::to_string() const {
  std::stringstream stream;

  stream << *this;

  return stream.str();
}

// Converts a linear index into a per-dimension index for the given extents.
// Dimension 0 varies fastest: each of the first D - 1 positions takes the
// remainder modulo its extent and passes the quotient on.  The last position
// receives whatever quotient is left, without a range check.
// linear_index is an int64_t, matching the extents, so that indices beyond
// the range of int are not truncated.
template <int D>
HOSTDEVICE Dim<D> linear_to_dimension(int64_t linear_index, Dim<D> extents) {
  Dim<D> result;

  for (int i = 0; i < D - 1; ++i) {
    // Look the extent up once; each operator[] call walks the Dim again.
    const int64_t extent = extents[i];
    result[i] = linear_index % extent;
    linear_index /= extent;
  }

  result[D - 1] = linear_index;

  return result;
}

}  // namespace framework
}  // namespace paddle