/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "ddim.h"
#include <algorithm>

namespace paddle_mobile {
namespace framework {

/// @cond HIDDEN

/// Builds a Dim<i> from the i values starting at `d`, head-first:
/// d[0] becomes the head and the remaining i - 1 values form the tail.
template <int i>
Dim<i> make_dim(const int64_t *d) {
  return Dim<i>(*d, make_dim<i - 1>(d + 1));
}

/// Recursion base: a rank-0 Dim carries no extents.
/// The pointer that reaches this case is one past the last of the values
/// consumed by make_dim<n> (and for n == 0 may point into an empty range),
/// so it must not be dereferenced. The 0 passed below is a placeholder for
/// the rank-0 constructor; presumably ignored by Dim<0> (confirm in dim.h).
template <>
Dim<0> make_dim<0>(const int64_t *) {
  return Dim<0>(int64_t{0});
}

/// Stores into `ddim` a Dim of rank `n` built from dims[0], ..., dims[n-1].
/// Dim<N> needs its rank at compile time, so each supported runtime rank is
/// mapped to its own instantiation. Ranks outside [0, 9] are not supported
/// and leave `ddim` unmodified.
void make_ddim(DDim &ddim, const int64_t *dims, int n) {
  switch (n) {
    case 0: ddim = make_dim<0>(dims); return;
    case 1: ddim = make_dim<1>(dims); return;
    case 2: ddim = make_dim<2>(dims); return;
    case 3: ddim = make_dim<3>(dims); return;
    case 4: ddim = make_dim<4>(dims); return;
    case 5: ddim = make_dim<5>(dims); return;
    case 6: ddim = make_dim<6>(dims); return;
    case 7: ddim = make_dim<7>(dims); return;
    case 8: ddim = make_dim<8>(dims); return;
    case 9: ddim = make_dim<9>(dims); return;
    default:
      // Unsupported rank: intentionally a no-op.
      return;
  }
}

/// @endcond

/// Creates a DDim whose extents are the values of `dims`, in order.
DDim make_ddim(std::initializer_list<int64_t> dims) {
  // Seed with a rank-1 placeholder; the rank dispatcher overwrites it for
  // every supported rank.
  DDim ddim(make_dim(0));
  const int rank = static_cast<int>(dims.size());
  make_ddim(ddim, dims.begin(), rank);
  return ddim;
}

/// Creates a DDim whose extents are the elements of `dims`, in order.
/// Uses dims.data() rather than &dims[0]: indexing element 0 of an empty
/// vector is undefined behavior, whereas data() is always valid to take.
DDim make_ddim(const std::vector<int64_t> &dims) {
  DDim result(make_dim(0));
  make_ddim(result, dims.data(), static_cast<int>(dims.size()));
  return result;
}

/// Convenience overload for 32-bit extents: widens every value to int64_t
/// and forwards to the int64_t overload.
DDim make_ddim(const std::vector<int> &dims) {
  const std::vector<int64_t> widened(dims.begin(), dims.end());
  return make_ddim(widened);
}

/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes
// errors
/// Visitor used by the non-const DDim::operator[] to obtain a mutable
/// reference to the extent stored at one axis.
struct DynamicMutableIndexer : Vistor<int64_t &> {
 public:
  /// @param idx axis whose extent is returned by reference.
  explicit DynamicMutableIndexer(int idx) : axis_(idx) {}

  /// Returns a reference into `dim` at the stored axis.
  template <int D>
  int64_t &operator()(Dim<D> &dim) const {
    int64_t &extent = dim[axis_];
    return extent;
  }

 private:
  int axis_;
};

/// Visitor used by the const DDim::operator[] to read the extent stored at
/// one axis by value.
struct DynamicConstIndexer : public Vistor<int64_t> {
 public:
  /// @param idx axis whose extent is returned.
  explicit DynamicConstIndexer(int idx) : axis_(idx) {}

  /// Returns a copy of the extent of `dim` at the stored axis.
  template <int D>
  int64_t operator()(const Dim<D> &dim) const {
    const int64_t extent = dim[axis_];
    return extent;
  }

 private:
  int axis_;
};

/// @endcond

/// Mutable access to the extent at axis `idx`.
int64_t &DDim::operator[](int idx) {
  DynamicMutableIndexer indexer(idx);
  return DDim::ApplyVistor(indexer, *this);
}

/// Read-only access to the extent at axis `idx`.
int64_t DDim::operator[](int idx) const {
  DynamicConstIndexer indexer(idx);
  return DDim::ApplyVistor(indexer, *this);
}

/// Number of axes (rank) of this DDim.
int DDim::size() const {
  const int rank = arity(*this);
  return rank;
}

/// Two DDims are equal when they have the same rank and identical extents
/// on every axis. std::vector's operator== checks the sizes first, so DDims
/// of different rank never compare equal and no element is read out of
/// bounds (a loop over only this DDim's extents would allow both).
bool DDim::operator==(DDim d) const {
  return vectorize(*this) == vectorize(d);
}

/// Logical negation of operator==.
bool DDim::operator!=(DDim d) const {
  const bool same = (*this == d);
  return !same;
}

/// Element-wise sum of the extents of two DDims of equal rank
/// (rank equality is checked with assert only).
DDim DDim::operator+(DDim d) const {
  const std::vector<int64_t> lhs = vectorize(*this);
  const std::vector<int64_t> rhs = vectorize(d);
  assert(lhs.size() == rhs.size());

  std::vector<int64_t> sums(lhs.size());
  std::transform(lhs.begin(), lhs.end(), rhs.begin(), sums.begin(),
                 [](int64_t a, int64_t b) { return a + b; });
  return make_ddim(sums);
}

/// Element-wise product of the extents of two DDims of equal rank
/// (rank equality is checked with assert only).
DDim DDim::operator*(DDim d) const {
  const std::vector<int64_t> lhs = vectorize(*this);
  const std::vector<int64_t> rhs = vectorize(d);
  assert(lhs.size() == rhs.size());

  std::vector<int64_t> products(lhs.size());
  std::transform(lhs.begin(), lhs.end(), rhs.begin(), products.begin(),
                 [](int64_t a, int64_t b) { return a * b; });
  return make_ddim(products);
}

/// Returns the extent of `ddim` at axis `idx`.
int64_t get(const DDim &ddim, int idx) {
  const int64_t extent = ddim[idx];
  return extent;
}

/// Overwrites the extent of `ddim` at axis `idx` with `value`.
void set(DDim &ddim, int idx, int value) {
  int64_t &extent = ddim[idx];
  extent = value;
}

/// @cond HIDDEN
/// Visitor that appends every extent of the visited Dim, head-first, to a
/// caller-owned vector.
struct VectorizeVisitor : Vistor<void> {
  std::vector<int64_t> &vector;  // output buffer (not owned)

  explicit VectorizeVisitor(std::vector<int64_t> &out) : vector(out) {}

  /// Appends the head extent, then walks the tail.
  template <typename T>
  void operator()(const T &t) {
    vector.push_back(t.head);
    (*this)(t.tail);
  }

  /// Rank-0 tail: nothing left to append.
  void operator()(const Dim<0> &) {}
};
/// @endcond

std::vector<int64_t> vectorize(const DDim &ddim) {
205 206 207 208
  std::vector<int64_t> result;
  VectorizeVisitor visitor(result);
  DDim::ApplyVistor(visitor, ddim);
  return result;
朔-望's avatar
朔-望 已提交
209 210 211 212 213
}

// NOTE: framework::vectorize converts to type int64_t
//       which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim &ddim) {
214 215 216
  std::vector<int64_t> temp = vectorize(ddim);
  std::vector<int> result(temp.begin(), temp.end());
  return result;
朔-望's avatar
朔-望 已提交
217 218 219
}

/// Visitor that forwards the visited Dim<D> to the Dim overload of
/// product() and returns its result.
struct ProductVisitor : Vistor<int64_t> {
  template <int D>
  int64_t operator()(const Dim<D> &dim) {
    const int64_t total = product(dim);
    return total;
  }
};

int64_t product(const DDim &ddim) {
227 228
  ProductVisitor visitor;
  return DDim::ApplyVistor(visitor, ddim);
朔-望's avatar
朔-望 已提交
229 230 231
}

/// Visitor that appends the extents of axes [begin, end) of the visited Dim
/// to a caller-owned vector. `begin` and `end` are consumed as counters while
/// walking the Dim head-first.
struct SliceVectorizeVisitor : Vistor<void> {
  std::vector<int64_t> &vector;  // output buffer (not owned)
  int begin;  // axes still to skip before collecting starts
  int end;    // axes still to visit before stopping

  // Expected (not enforced here): 0 <= begin < end.
  SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
      : vector(v), begin(b), end(e) {}

  template <int S>
  void operator()(const Dim<S> &dim) {
    // Collect the head once the skip counter has reached exactly zero.
    if (begin != 0) {
      --begin;
    } else {
      vector.push_back(dim.head);
    }
    --end;
    if (end <= 0) {
      return;  // the requested range is exhausted
    }
    (*this)(dim.tail);
  }

  /// Rank-0 tail: no axes left to visit.
  void operator()(const Dim<0> &) {}
};

/// Returns a DDim made of the extents of axes [begin, end) of `ddim`.
/// NOTE(review): begin == end == 0 yields the first extent rather than an
/// empty shape, and end < begin makes reserve() receive a huge size.
/// Callers are expected to pass 0 <= begin < end.
DDim slice_ddim(const DDim &ddim, int begin, int end) {
  std::vector<int64_t> picked;
  picked.reserve(end - begin);
  SliceVectorizeVisitor slicer(picked, begin, end);
  DDim::ApplyVistor(slicer, ddim);
  return make_ddim(picked);
}

/// \cond HIDDEN

/// Visitor that reports the compile-time rank D of the visited Dim<D>.
struct ArityVisitor : Vistor<int> {
  template <int D>
  int operator()(Dim<D>) const {
    constexpr int kRank = D;
    return kRank;
  }
};

/// \endcond

int arity(const DDim &d) {
288 289 290 291
  ArityVisitor arityVisitor = ArityVisitor();
  return DDim::ApplyVistor(arityVisitor, d);
  //  return arityVisitor(d.var.Get<Dim<4>>());
  //  return boost::apply_visitor(ArityVisitor(), d); }
朔-望's avatar
朔-望 已提交
292 293 294 295 296 297
}
/// \cond HIDDEN

/// \endcond

/// Visitor that streams the visited Dim<D> into a caller-owned ostream.
/// It stores a reference to the stream, so it must not outlive that stream.
struct OSVistor : Vistor<std::ostream &> {
  /// `explicit` prevents an accidental implicit std::ostream -> OSVistor
  /// conversion.
  explicit OSVistor(std::ostream &os) : os_(os) {}

  template <int D>
  std::ostream &operator()(Dim<D> dim) const {
    return os_ << dim;
  }

 private:
  std::ostream &os_;
};

std::ostream &operator<<(std::ostream &os, const DDim &ddim) {
310 311 312
  auto vistor = OSVistor(os);
  DDim::ApplyVistor(vistor, ddim);
  return os;
朔-望's avatar
朔-望 已提交
313 314 315
}

/// Builds a DDim from a brace-enclosed list of extents.
DDim::DDim(std::initializer_list<int64_t> init_list) {
  const DDim built = make_ddim(init_list);
  *this = built;
}

/// Collapses `src` into a rank-2 DDim: the first `num_col_dims` axes fold
/// into the first extent and the remaining axes into the second.
DDim flatten_to_2d(const DDim &src, int num_col_dims) {
  const int rank = src.size();
  const int64_t outer = product(slice_ddim(src, 0, num_col_dims));
  const int64_t inner = product(slice_ddim(src, num_col_dims, rank));
  return make_ddim({outer, inner});
}

/// Collapses `src` into a rank-1 DDim holding the product of its extents.
DDim flatten_to_1d(const DDim &src) {
  const int64_t numel = product(src);
  return make_ddim({numel});
}

/// Computes per-axis strides: result[i] is the product of the extents of
/// all axes after i, so the last entry is 1.
/// A rank-0 DDim has no strides and is returned unchanged; without this
/// guard the code below would write strides[-1].
DDim stride(const DDim &ddim) {
  const int rank = ddim.size();
  if (rank == 0) {
    return ddim;
  }
  std::vector<int64_t> strides(rank);
  strides[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return framework::make_ddim(strides);
}

/// Computes suffix products of the extents: result[i] is the product of the
/// extents of axes i..rank-1. result[0] is therefore the product of all
/// extents, and the last entry equals the last extent.
/// A rank-0 DDim is returned unchanged; without this guard the code below
/// would read ddim[-1] and write strides[-1].
DDim stride_numel(const framework::DDim &ddim) {
  const int rank = ddim.size();
  if (rank == 0) {
    return ddim;
  }
  std::vector<int64_t> strides(rank);
  strides[rank - 1] = ddim[rank - 1];
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return framework::make_ddim(strides);
}

}  // namespace framework
}  // namespace paddle_mobile