dim.h 7.7 KB
Newer Older
W
wangliu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

朔-望's avatar
朔-望 已提交
15 16
#pragma once

L
liuruilong 已提交
17
#include <cstdlib>
W
wangliu 已提交
18
#include "common/enforce.h"
朔-望's avatar
朔-望 已提交
19
namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
20
namespace framework {
朔-望's avatar
朔-望 已提交
21

朔-望's avatar
朔-望 已提交
22
// Statically sized, statically indexed dimension
朔-望's avatar
朔-望 已提交
23 24
template <int i>
struct Dim {
25
  static constexpr int dimensions = i;
朔-望's avatar
朔-望 已提交
26

27
  template <typename... Args>
L
liuruilong 已提交
28
  Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
29 30 31
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
朔-望's avatar
朔-望 已提交
32

33
  Dim(int64_t _head, const Dim<i - 1> &_tail) : head(_head), tail(_tail) {}
朔-望's avatar
朔-望 已提交
34

35
  Dim() : head(0), tail() {}
朔-望's avatar
朔-望 已提交
36

37 38 39 40 41
  /** Construct a Dim from a linear index and size.  Uses Fortran
   * order
   * indexing. */
  Dim(int64_t idx, const Dim<i> &size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
朔-望's avatar
朔-望 已提交
42

43 44
  /** Construct a Dim with each dimension set to the given index */
  Dim(int64_t idx) : head(idx), tail(idx) {}
朔-望's avatar
朔-望 已提交
45

46 47 48
  bool operator==(const Dim<i> &o) const {
    return (head == o.head) && (tail == o.tail);
  }
朔-望's avatar
朔-望 已提交
49

50
  bool operator!=(const Dim<i> &o) const { return !(*this == o); }
朔-望's avatar
朔-望 已提交
51

52
  int64_t &operator[](int idx);
L
liuruilong 已提交
53

54
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
55

L
liuruilong 已提交
56
  std::string to_string() const;
朔-望's avatar
朔-望 已提交
57

58 59
  int64_t head;
  Dim<i - 1> tail;
朔-望's avatar
朔-望 已提交
60
};
朔-望's avatar
朔-望 已提交
61

朔-望's avatar
朔-望 已提交
62
// Base case specialization
朔-望's avatar
朔-望 已提交
63 64
template <>
struct Dim<0> {
65
  static constexpr int dimensions = 0;
朔-望's avatar
朔-望 已提交
66

67
  Dim(int64_t _head) {}
朔-望's avatar
朔-望 已提交
68

69
  Dim() {}
朔-望's avatar
朔-望 已提交
70

71 72
  Dim(int idx, const Dim<0> &size) {
    if (idx > 0) {
W
wangliu 已提交
73
      PADDLE_MOBILE_THROW_EXCEPTION("Index out of range.")
74 75
    }
  }
朔-望's avatar
朔-望 已提交
76

77
  bool operator==(const Dim<0> &o) const { return true; }
朔-望's avatar
朔-望 已提交
78

79
  bool operator!=(const Dim<0> &o) const { return false; }
朔-望's avatar
朔-望 已提交
80

81
  int64_t &operator[](int idx);
L
liuruilong 已提交
82

83
  int64_t operator[](int idx) const;
朔-望's avatar
朔-望 已提交
84 85 86 87 88
};

namespace {

// Helper for accessing Dim classes
朔-望's avatar
朔-望 已提交
89 90
template <int i>
struct DimGetter {
91
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
92
  template <typename D>
L
liuruilong 已提交
93
  static int64_t impl(const D &d) {
94 95 96
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
97
  template <typename D>
L
liuruilong 已提交
98
  static int64_t &impl(D &d) {
99 100
    return DimGetter<i - 1>::impl(d.tail);
  }
朔-望's avatar
朔-望 已提交
101 102 103
};

// Eureka! We found the element!
朔-望's avatar
朔-望 已提交
104 105
template <>
struct DimGetter<0> {
106
  // Return a copy if Dim is const
朔-望's avatar
朔-望 已提交
107
  template <typename D>
L
liuruilong 已提交
108
  static int64_t impl(const D &d) {
109 110 111
    return d.head;
  }
  // Return a reference if Dim is mutable
朔-望's avatar
朔-望 已提交
112
  template <typename D>
L
liuruilong 已提交
113
  static int64_t &impl(D &d) {
朔-望's avatar
朔-望 已提交
114 115
    return d.head;
  }
朔-望's avatar
朔-望 已提交
116 117
};

朔-望's avatar
朔-望 已提交
118
template <int D>
L
liuruilong 已提交
119
int64_t &indexer(Dim<D> &dim, int idx) {
120
  if (idx < 0) {
W
wangliu 已提交
121
    PADDLE_MOBILE_THROW_EXCEPTION("Tried to access a negative dimension")
122
  }
W
wangliu 已提交
123

124 125 126 127
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
128
}
朔-望's avatar
朔-望 已提交
129

朔-望's avatar
朔-望 已提交
130
template <>
L
liuruilong 已提交
131
int64_t &indexer<0>(Dim<0> &dim, int idx) {
W
wangliu 已提交
132
  PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
133
  exit(0);
朔-望's avatar
朔-望 已提交
134
}
朔-望's avatar
朔-望 已提交
135

朔-望's avatar
朔-望 已提交
136
template <int D>
L
liuruilong 已提交
137
int64_t indexer(const Dim<D> &dim, int idx) {
138
  if (idx < 0) {
W
wangliu 已提交
139
    PADDLE_MOBILE_THROW_EXCEPTION("Tried to access a negative dimension")
140 141 142 143 144
  }
  if (idx == 0) {
    return dim.head;
  }
  return indexer(dim.tail, idx - 1);
朔-望's avatar
朔-望 已提交
145 146
}

朔-望's avatar
朔-望 已提交
147
template <>
L
liuruilong 已提交
148
int64_t indexer<0>(const Dim<0> &dim, int idx) {
W
wangliu 已提交
149
  PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
150
  exit(0);
朔-望's avatar
朔-望 已提交
151 152
}

朔-望's avatar
朔-望 已提交
153
}  // namespace
朔-望's avatar
朔-望 已提交
154
// Static access to constant Dim
朔-望's avatar
朔-望 已提交
155
template <int i, int l>
L
liuruilong 已提交
156
int64_t get(const Dim<l> &d) {
157
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
158 159 160
}

// Static access to mutable Dim
朔-望's avatar
朔-望 已提交
161
template <int i, int l>
L
liuruilong 已提交
162
int64_t &get(Dim<l> &d) {
163
  return DimGetter<i>::impl(d);
朔-望's avatar
朔-望 已提交
164 165 166
}

// Dynamic access to constant Dim
朔-望's avatar
朔-望 已提交
167
template <int l>
L
liuruilong 已提交
168
int64_t Dim<l>::operator[](int i) const {
169 170
  //  std::cout << "l: " << l << std::endl;
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
171 172 173
}

// Dynamic access to mutable Dim
朔-望's avatar
朔-望 已提交
174
template <int l>
L
liuruilong 已提交
175
int64_t &Dim<l>::operator[](int i) {
176
  return indexer(*this, i);
朔-望's avatar
朔-望 已提交
177 178 179
}

// Dynamic access to constant Dim
L
liuruilong 已提交
180
inline int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
181 182

// Dynamic access to mutable Dim
L
liuruilong 已提交
183
inline int64_t &Dim<0>::operator[](int i) { return indexer(*this, i); }
朔-望's avatar
朔-望 已提交
184 185 186 187

// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
L
liuruilong 已提交
188
typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l> &d, int i) {
189
  return d[i];
朔-望's avatar
朔-望 已提交
190 191 192 193
}

// Dynamic access to mutable Dim
template <int l>
L
liuruilong 已提交
194
typename std::enable_if<(l > 0), int64_t &>::type get(Dim<l> &d, int i) {
195
  return d[i];
朔-望's avatar
朔-望 已提交
196 197 198 199
}

// Dot product of two dims
template <int i>
L
liuruilong 已提交
200
int64_t linearize(const Dim<i> &a, const Dim<i> &b) {
201
  return a.head * b.head + linearize(a.tail, b.tail);
朔-望's avatar
朔-望 已提交
202 203 204 205 206
}

// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
207
inline int64_t linearize(const Dim<0> &a, const Dim<0> &b) {
208
  return 0;
朔-望's avatar
朔-望 已提交
209 210 211
}

// Product of a Dim
朔-望's avatar
朔-望 已提交
212
template <int i>
L
liuruilong 已提交
213
int64_t product(const Dim<i> &a, int prod = 1) {
214
  return prod * a.head * product(a.tail);
朔-望's avatar
朔-望 已提交
215 216 217 218
}

// Base case product of a Dim
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
219
template <>
L
liuruilong 已提交
220
inline int64_t product(const Dim<0> &a, int prod) {
221
  return prod;
朔-望's avatar
朔-望 已提交
222 223 224 225
}

// Is 0 <= idx_i < size_i for all i?
template <int i>
L
liuruilong 已提交
226
bool contained(const Dim<i> &idx, const Dim<i> &size) {
227 228
  return ((0 <= idx.head) && (idx.head < size.head) &&
          contained(idx.tail, size.tail));
朔-望's avatar
朔-望 已提交
229 230 231 232 233
}

// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
L
liuruilong 已提交
234
inline bool contained(const Dim<0> &idx, const Dim<0> &size) {
235
  return true;
朔-望's avatar
朔-望 已提交
236 237 238 239 240 241
}

/**
 * \brief Compute exclusive prefix-multiply of a Dim.
 */
template <int i>
L
liuruilong 已提交
242
Dim<i> ex_prefix_mul(const Dim<i> &src, int mul = 1) {
243
  return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head));
朔-望's avatar
朔-望 已提交
244 245 246 247 248
}

///\cond HIDDEN
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
朔-望's avatar
朔-望 已提交
249
template <>
L
liuruilong 已提交
250
inline Dim<0> ex_prefix_mul(const Dim<0> &src, int mul) {
251
  return Dim<0>();
朔-望's avatar
朔-望 已提交
252 253 254 255 256 257
}
///\endcond

/**
 * Add two dimensions together
 */
朔-望's avatar
朔-望 已提交
258
template <int i>
L
liuruilong 已提交
259
Dim<i> dim_plus(const Dim<i> &a, const Dim<i> &b) {
260
  return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
261 262 263 264
}

// Base case
template <>
L
liuruilong 已提交
265
inline Dim<0> dim_plus(const Dim<0> &a, const Dim<0> &b) {
266
  return Dim<0>();
朔-望's avatar
朔-望 已提交
267 268 269
}

template <int i>
L
liuruilong 已提交
270
Dim<i> operator+(const Dim<i> &lhs, const Dim<i> &rhs) {
271
  return dim_plus(lhs, rhs);
朔-望's avatar
朔-望 已提交
272 273 274 275 276
}

/**
 * Multiply two dimensions together
 */
朔-望's avatar
朔-望 已提交
277
template <int i>
L
liuruilong 已提交
278
Dim<i> dim_mult(const Dim<i> &a, const Dim<i> &b) {
279
  return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail));
朔-望's avatar
朔-望 已提交
280 281 282 283
}

// Base case
template <>
L
liuruilong 已提交
284
inline Dim<0> dim_mult(const Dim<0> &a, const Dim<0> &b) {
285
  return Dim<0>();
朔-望's avatar
朔-望 已提交
286 287 288
}

template <int i>
L
liuruilong 已提交
289
Dim<i> operator*(const Dim<i> &lhs, const Dim<i> &rhs) {
290
  return dim_mult(lhs, rhs);
朔-望's avatar
朔-望 已提交
291 292 293 294 295 296 297 298 299 300 301 302 303
}

/**
 * \brief Normalize strides to ensure any dimension with extent 1
 * has stride 0.
 *
 * \param size Dim object containing the size of an array
 * \param stride Dim object containing stride of an array
 * \return Dim object the same size as \p size with normalized strides
 *
 */

template <int i>
L
liuruilong 已提交
304
Dim<i> normalize_strides(const Dim<i> &size, const Dim<i> &stride) {
305 306
  int norm_stride = size.head == 1 ? 0 : stride.head;
  return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail));
朔-望's avatar
朔-望 已提交
307 308 309 310 311
}

///\cond HIDDEN

template <>
L
liuruilong 已提交
312
inline Dim<0> normalize_strides(const Dim<0> &size, const Dim<0> &stride) {
313
  return Dim<0>();
朔-望's avatar
朔-望 已提交
314 315 316 317 318 319 320 321 322 323 324 325 326
}

///\endcond

/**
 * Helper function to create a Dim
 *
 * \param idxes The type of Dim constructed depends on the number of
 * params
 *
 */

template <typename... Args>
L
liuruilong 已提交
327
Dim<sizeof...(Args)> make_dim(Args... idxes) {
328
  return Dim<sizeof...(Args)>(idxes...);
朔-望's avatar
朔-望 已提交
329 330
}

朔-望's avatar
朔-望 已提交
331 332
}  // namespace framework
}  // namespace paddle_mobile