ddim.cc 7.9 KB
Newer Older
W
wangliu 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "ddim.h"

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
18 19 20 21
namespace framework {

/// @cond HIDDEN

朔-望's avatar
朔-望 已提交
22 23
/// Recursively builds a Dim<i> from the first i entries of `d`:
/// d[0] becomes the head and the following i - 1 entries form the tail.
template <int i>
Dim<i> make_dim(const int64_t *d) {
  const int64_t head = d[0];
  const int64_t *rest = d + 1;
  return Dim<i>(head, make_dim<i - 1>(rest));
}

朔-望's avatar
朔-望 已提交
27 28 29 30
/// Recursion terminator for make_dim<i>: builds the rank-0 Dim.
///
/// When make_dim<i>(d) recurses down to here, `d` points one past the last
/// extent (and for rank 0 it is whatever pointer an empty container yields),
/// so it must NOT be dereferenced. The previous version passed `*d`, which is
/// an out-of-bounds read. A rank-0 Dim has no extents, so a placeholder value
/// is passed instead.
/// NOTE(review): assumes Dim<0>'s int64_t constructor ignores its argument,
/// as the original code's use of an out-of-range value implies; confirm in dim.h.
template <>
Dim<0> make_dim<0>(const int64_t * /*d*/) {
  return Dim<0>(static_cast<int64_t>(0));
}
朔-望's avatar
朔-望 已提交
31 32

/// Assigns to `ddim` a Dim of rank `n` built from the first `n` entries of
/// `dims`. Ranks 0 through 9 are supported. For any other `n`, `ddim` is left
/// unmodified and no error is reported.
void make_ddim(DDim &ddim, const int64_t *dims, int n) {
  switch (n) {
    case 0: ddim = make_dim<0>(dims); return;
    case 1: ddim = make_dim<1>(dims); return;
    case 2: ddim = make_dim<2>(dims); return;
    case 3: ddim = make_dim<3>(dims); return;
    case 4: ddim = make_dim<4>(dims); return;
    case 5: ddim = make_dim<5>(dims); return;
    case 6: ddim = make_dim<6>(dims); return;
    case 7: ddim = make_dim<7>(dims); return;
    case 8: ddim = make_dim<8>(dims); return;
    case 9: ddim = make_dim<9>(dims); return;
    default:
      // Unsupported rank. Dynamic dimensions must have between [0, 9]
      // dimensions; the error report is currently disabled, so this is a
      // silent no-op.
      return;
  }
}

/// @endcond

/// Builds a DDim from a brace-enclosed list of extents, e.g. make_ddim({2, 3}).
/// If the list length is not a supported rank, the initial placeholder value
/// is returned unchanged.
DDim make_ddim(std::initializer_list<int64_t> dims) {
  const int rank = static_cast<int>(dims.size());
  DDim out(make_dim(0));
  make_ddim(out, dims.begin(), rank);
  return out;
}

/// Builds a DDim whose extents are the elements of `dims`.
///
/// Uses dims.data() instead of &dims[0]: calling operator[] on an empty
/// vector is undefined behavior, while data() is always valid to call. Its
/// result just must not be dereferenced when the vector is empty.
DDim make_ddim(const std::vector<int64_t> &dims) {
  DDim result(make_dim(0));
  make_ddim(result, dims.data(), static_cast<int>(dims.size()));
  return result;
}

/// Builds a DDim from 32-bit extents by widening each to int64_t and
/// delegating to the std::vector<int64_t> overload.
DDim make_ddim(const std::vector<int> &dims) {
  std::vector<int64_t> widened;
  widened.reserve(dims.size());
  for (int extent : dims) {
    widened.push_back(static_cast<int64_t>(extent));
  }
  return make_ddim(widened);
}

/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes
// errors
/// Visitor returning a mutable reference to the extent at a runtime index of
/// whichever Dim<D> the DDim holds. The index is forwarded unchecked to
/// Dim<D>::operator[].
struct DynamicMutableIndexer : public Vistor<int64_t &> {
 public:
  explicit DynamicMutableIndexer(int idx) : index_(idx) {}

  template <int D>
  int64_t &operator()(Dim<D> &dim) const {
    int64_t &slot = dim[index_];
    return slot;
  }

 private:
  int index_;
};

/// Visitor returning (by value) the extent at a runtime index of whichever
/// Dim<D> the DDim holds. The index is forwarded unchecked to
/// Dim<D>::operator[].
struct DynamicConstIndexer : public Vistor<int64_t> {
 public:
  explicit DynamicConstIndexer(int idx) : index_(idx) {}

  template <int D>
  int64_t operator()(const Dim<D> &dim) const {
    const int64_t extent = dim[index_];
    return extent;
  }

 private:
  int index_;
};

/// @endcond

/// Mutable access to the extent at position `idx`, dispatched to the concrete
/// Dim<D> held by this DDim.
int64_t &DDim::operator[](int idx) {
  DynamicMutableIndexer indexer(idx);
  return DDim::ApplyVistor(indexer, *this);
}

/// Read-only access to the extent at position `idx`, dispatched to the
/// concrete Dim<D> held by this DDim.
int64_t DDim::operator[](int idx) const {
  DynamicConstIndexer indexer(idx);
  return DDim::ApplyVistor(indexer, *this);
}

/// Number of dimensions (rank) of this DDim.
int DDim::size() const {
  const int rank = arity(*this);
  return rank;
}

/// Two DDims are equal when they have the same rank and identical extents.
///
/// The previous loop compared only the first this->size() extents. A shorter
/// DDim therefore compared equal to any longer DDim sharing its prefix, and a
/// longer left operand indexed past the end of the right operand's vector.
/// std::vector's operator== checks the sizes first, which fixes both.
bool DDim::operator==(DDim d) const {
  return vectorize(*this) == vectorize(d);
}

/// Logical negation of operator==.
bool DDim::operator!=(DDim d) const {
  const bool same = (*this == d);
  return !same;
}

/// Element-wise sum of two DDims. Equal rank is asserted, not otherwise
/// handled.
DDim DDim::operator+(DDim d) const {
  const std::vector<int64_t> lhs = vectorize(*this);
  const std::vector<int64_t> rhs = vectorize(d);
  assert(lhs.size() == rhs.size());

  std::vector<int64_t> sum(lhs.size());
  for (std::vector<int64_t>::size_type k = 0; k < lhs.size(); ++k) {
    sum[k] = lhs[k] + rhs[k];
  }
  return make_ddim(sum);
}

/// Element-wise product of two DDims. Equal rank is asserted, not otherwise
/// handled.
DDim DDim::operator*(DDim d) const {
  const std::vector<int64_t> lhs = vectorize(*this);
  const std::vector<int64_t> rhs = vectorize(d);
  assert(lhs.size() == rhs.size());

  std::vector<int64_t> prod(lhs.size());
  for (std::vector<int64_t>::size_type k = 0; k < lhs.size(); ++k) {
    prod[k] = lhs[k] * rhs[k];
  }
  return make_ddim(prod);
}

/// Free-function accessor: the extent of `ddim` at position `idx`.
int64_t get(const DDim &ddim, int idx) {
  const int64_t extent = ddim[idx];
  return extent;
}

/// Free-function mutator: stores `value` (widened to int64_t) as the extent
/// of `ddim` at position `idx`.
void set(DDim &ddim, int idx, int value) {
  int64_t &slot = ddim[idx];
  slot = static_cast<int64_t>(value);
}

/// @cond HIDDEN
/// Visitor that appends every extent of the visited Dim to `vector`, head
/// first, by walking the head/tail chain down to Dim<0>.
struct VectorizeVisitor : public Vistor<void> {
  std::vector<int64_t> &vector;

  explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {}

  /// Base case: a rank-0 Dim contributes nothing.
  void operator()(const Dim<0> &t) {}

  /// Recursive case: record the head, then continue with the tail.
  template <typename T>
  void operator()(const T &t) {
    vector.push_back(t.head);
    (*this)(t.tail);
  }
};
/// @endcond

std::vector<int64_t> vectorize(const DDim &ddim) {
204 205 206 207
  std::vector<int64_t> result;
  VectorizeVisitor visitor(result);
  DDim::ApplyVistor(visitor, ddim);
  return result;
朔-望's avatar
朔-望 已提交
208 209 210 211 212
}

// NOTE: framework::vectorize converts to type int64_t
//       which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim &ddim) {
213 214 215
  std::vector<int64_t> temp = vectorize(ddim);
  std::vector<int> result(temp.begin(), temp.end());
  return result;
朔-望's avatar
朔-望 已提交
216 217 218
}

/// Visitor returning the product of the extents of the held Dim<D>, by
/// delegating to the Dim-level product() overload.
struct ProductVisitor : public Vistor<int64_t> {
  template <int D>
  int64_t operator()(const Dim<D> &dim) {
    const int64_t total = product(dim);
    return total;
  }
};

int64_t product(const DDim &ddim) {
226 227
  ProductVisitor visitor;
  return DDim::ApplyVistor(visitor, ddim);
朔-望's avatar
朔-望 已提交
228 229 230
}

/// Visitor that appends the extents at positions [begin, end) of the visited
/// Dim to `vector`. `begin` and `end` are decremented while walking the
/// head/tail chain, so they act as countdowns relative to the current level.
struct SliceVectorizeVisitor : Vistor<void> {
  std::vector<int64_t> &vector;
  int begin;
  int end;

  SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
      : vector(v), begin(b), end(e) {
    // These checks are disabled. Note that flatten_to_2d can request empty
    // slices such as [rank, rank), which a strict `begin < end` would reject.
    //    PADDLE_ENFORCE(begin < end,
    //                   "Begin index must be less than end index in ddim
    //                   slice.");
    //    PADDLE_ENFORCE(begin >= 0,
    //                   "Begin index can't be less than zero in ddim slice.");
  }

  template <int S>
  void operator()(const Dim<S> &dim) {
    if (begin == 0) {
      // Emit only while still inside the half-open range. Without the
      // `end > 0` check, an empty slice starting at 0 (e.g. [0, 0)) wrongly
      // produced the first extent.
      if (end > 0) {
        vector.push_back(dim.head);
      }
    } else {
      --begin;
    }
    --end;
    if (end > 0) {
      this->operator()(dim.tail);
    }
  }

  void operator()(const Dim<0> &dim) {
    //    PADDLE_ENFORCE(end == 0, "End index in ddim slice is out
    //    of bound.");
  }
};

/// Returns a DDim made of the extents of `ddim` at positions [begin, end).
///
/// Capacity is reserved only for a non-empty range. For a reversed range,
/// `end - begin` is negative; converting it to size_t yields a value above
/// max_size(), and vector::reserve then throws std::length_error.
DDim slice_ddim(const DDim &ddim, int begin, int end) {
  std::vector<int64_t> vec;
  if (end > begin) {
    vec.reserve(static_cast<std::vector<int64_t>::size_type>(end - begin));
  }
  SliceVectorizeVisitor visitor(vec, begin, end);
  DDim::ApplyVistor(visitor, ddim);
  return make_ddim(vec);
}

/// \cond HIDDEN

/// Visitor returning the compile-time rank D of the held Dim<D>.
struct ArityVisitor : public Vistor<int> {
  template <int D>
  int operator()(Dim<D> /*unused*/) const {
    constexpr int rank = D;
    return rank;
  }
};

/// \endcond

int arity(const DDim &d) {
287 288 289 290
  ArityVisitor arityVisitor = ArityVisitor();
  return DDim::ApplyVistor(arityVisitor, d);
  //  return arityVisitor(d.var.Get<Dim<4>>());
  //  return boost::apply_visitor(ArityVisitor(), d); }
朔-望's avatar
朔-望 已提交
291 292 293 294 295 296
}
/// \cond HIDDEN

/// \endcond

/// Visitor that streams the held Dim<D> to the wrapped output stream and
/// returns that stream.
struct OSVistor : Vistor<std::ostream &> {
  // explicit: a std::ostream should not silently convert into a visitor.
  // This matches the explicit constructors of the other visitors in this file.
  explicit OSVistor(std::ostream &os) : os_(os) {}

  template <int D>
  std::ostream &operator()(Dim<D> dim) const {
    return os_ << dim;
  }

 private:
  std::ostream &os_;
};

std::ostream &operator<<(std::ostream &os, const DDim &ddim) {
309 310 311
  auto vistor = OSVistor(os);
  DDim::ApplyVistor(vistor, ddim);
  return os;
朔-望's avatar
朔-望 已提交
312 313 314
}

/// Constructs a DDim from a brace-enclosed list of extents, e.g. DDim{2, 3}.
DDim::DDim(std::initializer_list<int64_t> init_list) {
  DDim built = make_ddim(init_list);
  *this = built;
}

/// Collapses `src` into a rank-2 DDim whose entries are:
///   - first: the product of the extents at positions [0, num_col_dims)
///   - second: the product of the extents at [num_col_dims, rank)
DDim flatten_to_2d(const DDim &src, int num_col_dims) {
  const int rank = src.size();
  const int64_t leading = product(slice_ddim(src, 0, num_col_dims));
  const int64_t trailing = product(slice_ddim(src, num_col_dims, rank));
  return make_ddim({leading, trailing});
}

/// Collapses `src` into a rank-1 DDim holding the product of its extents.
DDim flatten_to_1d(const DDim &src) {
  const int64_t total = product(src);
  return make_ddim({total});
}

/// Row-major strides of `ddim`: strides[i] is the product of the extents
/// after position i, so the last stride is 1.
///
/// A rank-0 input is returned as-is, since it has no strides. Previously that
/// case wrote to strides[-1], out of bounds of an empty vector.
DDim stride(const DDim &ddim) {
  const int rank = ddim.size();
  if (rank <= 0) {
    return ddim;
  }
  std::vector<int64_t> strides(rank);
  strides[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return framework::make_ddim(strides);
}

/// For each position i, the product of the extents from i through the last
/// dimension. strides[rank - 1] is the last extent, and strides[0] is the
/// product of all extents.
///
/// A rank-0 input is returned as-is. Previously that case read ddim[-1] and
/// wrote strides[-1], out of bounds of an empty vector.
DDim stride_numel(const framework::DDim &ddim) {
  const int rank = ddim.size();
  if (rank <= 0) {
    return ddim;
  }
  std::vector<int64_t> strides(rank);
  strides[rank - 1] = ddim[rank - 1];
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return framework::make_ddim(strides);
}

朔-望's avatar
朔-望 已提交
344 345
}  // namespace framework
}  // namespace paddle_mobile