/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "ddim.h"
L
liuruilong 已提交
16
#include <algorithm>
朔-望's avatar
朔-望 已提交
17 18

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
19 20 21 22
namespace framework {

/// @cond HIDDEN

朔-望's avatar
朔-望 已提交
23 24
// Builds a Dim<i> from the array starting at `d`: the first element becomes
// the head, and the tail is produced by make_dim<i - 1> from the elements
// that follow it.
template <int i>
Dim<i> make_dim(const int64_t *d) {
  const int64_t head = d[0];
  return Dim<i>(head, make_dim<i - 1>(d + 1));
}

朔-望's avatar
朔-望 已提交
28 29 30 31
// Recursion terminator for make_dim.
//
// By the time the recursive make_dim<i> reaches this specialization, `d` has
// been advanced past every element it consumed, so it may point one past the
// end of the caller's array (or be the data pointer of an empty container).
// Reading through it is therefore an out-of-bounds access, so the pointer is
// deliberately left untouched and a fixed value is passed instead.
//
// NOTE(review): this assumes Dim<0> does not store its constructor argument
// (the visitors in this file never read a head from Dim<0>) -- confirm
// against dim.h.
template <>
Dim<0> make_dim<0>(const int64_t * /*d*/) {
  return Dim<0>(static_cast<int64_t>(0));
}
朔-望's avatar
朔-望 已提交
32 33

// Assigns to `ddim` a statically ranked Dim built from the values at `dims`.
//
// @param ddim  destination; overwritten when `n` is a supported rank.
// @param dims  pointer to the extents, read by make_dim<n>.
// @param n     requested rank; ranks 0 through 9 are supported.
//
// Any other rank cannot be represented. It used to be ignored silently,
// leaving `ddim` with whatever value the caller had put there; it is now
// reported through PADDLE_MOBILE_ENFORCE, and `ddim` is still left untouched.
void make_ddim(DDim &ddim, const int64_t *dims, int n) {
  switch (n) {
    case 0:
      ddim = make_dim<0>(dims);
      break;
    case 1:
      ddim = make_dim<1>(dims);
      break;
    case 2:
      ddim = make_dim<2>(dims);
      break;
    case 3:
      ddim = make_dim<3>(dims);
      break;
    case 4:
      ddim = make_dim<4>(dims);
      break;
    case 5:
      ddim = make_dim<5>(dims);
      break;
    case 6:
      ddim = make_dim<6>(dims);
      break;
    case 7:
      ddim = make_dim<7>(dims);
      break;
    case 8:
      ddim = make_dim<8>(dims);
      break;
    case 9:
      ddim = make_dim<9>(dims);
      break;
    default:
      PADDLE_MOBILE_ENFORCE(
          false, "Dynamic dimensions must have between [0, 9] dimensions.");
      break;
  }
}

/// @endcond

// Creates a DDim from a braced list of extents.
//
// A placeholder DDim is created first and then handed to the pointer-based
// make_ddim overload together with the list's storage and length.
DDim make_ddim(std::initializer_list<int64_t> dims) {
  DDim shape(make_dim(0));
  const int rank = static_cast<int>(dims.size());
  make_ddim(shape, dims.begin(), rank);
  return shape;
}

// Creates a DDim from a vector of extents.
//
// A placeholder DDim is created first and then handed to the pointer-based
// make_ddim overload together with the vector's storage and length.
// `dims.data()` is used rather than `&dims[0]` because indexing an empty
// vector is undefined behaviour, whereas data() is always valid to call.
DDim make_ddim(const std::vector<int64_t> &dims) {
  DDim result(make_dim(0));
  make_ddim(result, dims.data(), static_cast<int>(dims.size()));
  return result;
}

// Creates a DDim from a vector of int extents by widening each value to
// int64_t and delegating to the int64_t vector overload.
DDim make_ddim(const std::vector<int> &dims) {
  const std::vector<int64_t> widened(dims.begin(), dims.end());
  return make_ddim(widened);
}

/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes
// errors
// Visitor that yields a mutable reference to the element at a fixed index of
// whichever Dim<D> it is applied to.
struct DynamicMutableIndexer : Vistor<int64_t &> {
 public:
  // @param idx  index forwarded to Dim<D>::operator[] on every call.
  explicit DynamicMutableIndexer(int idx) : idx_{idx} {}

  // Returns dim[idx_] by reference so the caller can write through it.
  template <int D>
  int64_t &operator()(Dim<D> &dim) const {
    int64_t &element = dim[idx_];
    return element;
  }

 private:
  int idx_;  // index captured at construction
};

// Visitor that reads the element at a fixed index of whichever Dim<D> it is
// applied to and returns it by value.
struct DynamicConstIndexer : public Vistor<int64_t> {
 public:
  // @param idx  index forwarded to Dim<D>::operator[] on every call.
  explicit DynamicConstIndexer(int idx) : idx_{idx} {}

  // Returns a copy of dim[idx_].
  template <int D>
  int64_t operator()(const Dim<D> &dim) const {
    const int64_t element = dim[idx_];
    return element;
  }

 private:
  int idx_;  // index captured at construction
};

/// @endcond

// Mutable element access: applies a DynamicMutableIndexer for `idx` to this
// DDim and returns the reference it produces.
int64_t &DDim::operator[](int idx) {
  DynamicMutableIndexer indexer(idx);
  return DDim::ApplyVistor(indexer, *this);
}

// Read-only element access: applies a DynamicConstIndexer for `idx` to this
// DDim and returns the value it produces.
int64_t DDim::operator[](int idx) const {
  DynamicConstIndexer indexer(idx);
  return DDim::ApplyVistor(indexer, *this);
}

int DDim::size() const { return arity(*this); }

// Two DDims compare equal when their vectorized forms have the same length
// and the same value at every position.
bool DDim::operator==(DDim d) const {
  const std::vector<int64_t> lhs = vectorize(*this);
  const std::vector<int64_t> rhs = vectorize(d);
  // std::vector equality checks the sizes first, then compares element-wise.
  return lhs == rhs;
}

// Negation of operator==.
bool DDim::operator!=(DDim d) const {
  const bool same = (*this == d);
  return !same;
}

// Element-wise sum of two DDims. Both operands are expected to have the same
// number of dimensions; this is checked with PADDLE_MOBILE_ENFORCE.
DDim DDim::operator+(DDim d) const {
  const std::vector<int64_t> v1 = vectorize(*this);
  const std::vector<int64_t> v2 = vectorize(d);

  PADDLE_MOBILE_ENFORCE(v1.size() == v2.size(), "v1.size() != v2.size()");

  // One output slot per element of the left operand.
  std::vector<int64_t> sums(v1.size());
  std::transform(v1.begin(), v1.end(), v2.begin(), sums.begin(),
                 [](int64_t a, int64_t b) { return a + b; });
  return make_ddim(sums);
}

// Element-wise product of two DDims. Both operands are expected to have the
// same number of dimensions; this is checked with PADDLE_MOBILE_ENFORCE.
//
// The failure message now describes the violated state (sizes differ), in
// line with operator+; it previously printed the condition that should have
// held, which read as the opposite of the actual problem.
DDim DDim::operator*(DDim d) const {
  const std::vector<int64_t> v1 = vectorize(*this);
  const std::vector<int64_t> v2 = vectorize(d);

  PADDLE_MOBILE_ENFORCE(v1.size() == v2.size(), "v1.size() != v2.size()");

  std::vector<int64_t> v3;
  v3.reserve(v1.size());
  for (unsigned int i = 0; i < v1.size(); i++) {
    v3.push_back(v1[i] * v2[i]);
  }

  return make_ddim(v3);
}

// Free-function form of read-only indexing: returns ddim[idx].
int64_t get(const DDim &ddim, int idx) {
  const int64_t extent = ddim[idx];
  return extent;
}

W
wangliu 已提交
184
// Free-function form of mutable indexing: stores `value` at position `idx`
// of the DDim that `ddim` points to.
void set(DDim *ddim, int idx, int value) {
  int64_t &slot = (*ddim)[idx];
  slot = value;
}
朔-望's avatar
朔-望 已提交
185 186 187

/// @cond HIDDEN
// Visitor that appends every head of a Dim chain to an external vector,
// walking head -> tail until it reaches Dim<0>.
struct VectorizeVisitor : Vistor<void> {
  // Destination owned by the caller; the visitor only appends to it.
  std::vector<int64_t> &vector;

  explicit VectorizeVisitor(std::vector<int64_t> &v) : vector(v) {}

  // Records the current head, then continues with the tail.
  template <typename T>
  void operator()(const T &t) {
    vector.push_back(t.head);
    (*this)(t.tail);
  }

  // End of the chain: nothing is appended for Dim<0>.
  void operator()(const Dim<0> &) {}
};
/// @endcond

// Returns the extents of `ddim` as a vector, collected by a
// VectorizeVisitor applied to the DDim.
std::vector<int64_t> vectorize(const DDim &ddim) {
  std::vector<int64_t> extents;
  DDim::ApplyVistor(VectorizeVisitor(extents), ddim);
  return extents;
}

// NOTE: framework::vectorize converts to type int64_t
//       which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim &ddim) {
212 213 214
  std::vector<int64_t> temp = vectorize(ddim);
  std::vector<int> result(temp.begin(), temp.end());
  return result;
朔-望's avatar
朔-望 已提交
215 216 217
}

// Visitor that forwards a Dim<D> to the Dim-level product() and returns its
// result.
struct ProductVisitor : Vistor<int64_t> {
  template <int D>
  int64_t operator()(const Dim<D> &dim) {
    const int64_t total = product(dim);
    return total;
  }
};

int64_t product(const DDim &ddim) {
225 226
  ProductVisitor visitor;
  return DDim::ApplyVistor(visitor, ddim);
朔-望's avatar
朔-望 已提交
227 228 229
}

// Visitor that copies the heads in positions [begin, end) of a Dim chain
// into an external vector.
//
// `begin` and `end` are counted down while the chain is walked: heads are
// skipped until `begin` reaches zero, and the walk stops once `end` reaches
// zero.
struct SliceVectorizeVisitor : Vistor<void> {
  // Destination owned by the caller; the visitor only appends to it.
  std::vector<int64_t> &vector;
  // Number of leading heads still to skip.
  int begin;
  // Number of heads still to visit (skipped ones included).
  int end;

  SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
      : vector(v), begin(b), end(e) {
    PADDLE_MOBILE_ENFORCE(
        begin < end, "Begin index must be less than end index in ddim slice.");
    PADDLE_MOBILE_ENFORCE(begin >= 0,
                          "Begin index can't be less than zero in ddim slice.");
  }

  template <int S>
  void operator()(const Dim<S> &dim) {
    if (begin != 0) {
      // Still in front of the slice: consume one skip.
      --begin;
    } else {
      vector.push_back(dim.head);
    }
    // Continue only while positions remain before `end`.
    if (--end > 0) {
      (*this)(dim.tail);
    }
  }

  // Reached when the chain runs out; no bounds check is performed here.
  void operator()(const Dim<0> &) {}
};

// Returns a DDim made of the extents of `ddim` in positions [begin, end).
//
// @param ddim   source shape.
// @param begin  first position to keep.
// @param end    one past the last position to keep.
//
// The visitor is constructed before any capacity is reserved so that its
// argument checks run first. The reserve itself is guarded because a
// negative `end - begin` would otherwise be converted to an enormous
// unsigned size.
DDim slice_ddim(const DDim &ddim, int begin, int end) {
  std::vector<int64_t> vec;
  SliceVectorizeVisitor visitor(vec, begin, end);
  if (end > begin) {
    vec.reserve(static_cast<std::vector<int64_t>::size_type>(end - begin));
  }
  DDim::ApplyVistor(visitor, ddim);
  return make_ddim(vec);
}

/// \cond HIDDEN

// Visitor that reports the template rank D of the Dim<D> it is applied to.
struct ArityVisitor : Vistor<int> {
  template <int D>
  int operator()(Dim<D> /*dim*/) const {
    constexpr int rank = D;
    return rank;
  }
};

/// \endcond

int arity(const DDim &d) {
281 282
  ArityVisitor arityVisitor = ArityVisitor();
  return DDim::ApplyVistor(arityVisitor, d);
朔-望's avatar
朔-望 已提交
283
}
W
wangliu 已提交
284 285

#ifdef PADDLE_MOBILE_DEBUG
W
wangliu 已提交
286
// Writes every extent of `ddim` to `printer`, each followed by one space,
// and returns the printer so calls can be chained.
Print &operator<<(Print &printer, const DDim &ddim) {
  const int rank = ddim.size();
  for (int axis = 0; axis < rank; ++axis) {
    printer << ddim[axis] << " ";
  }
  return printer;
}

W
wangliu 已提交
294 295
#endif

朔-望's avatar
朔-望 已提交
296
// Constructs a DDim from a braced list of extents by building it with
// make_ddim and assigning the result to this object.
DDim::DDim(std::initializer_list<int64_t> init_list) {
  this->operator=(make_ddim(init_list));
}

// Collapses `src` into a two-element DDim: the product of the extents in
// [0, num_col_dims) followed by the product of the extents in
// [num_col_dims, rank).
DDim flatten_to_2d(const DDim &src, int num_col_dims) {
  const int rank = src.size();
  const int64_t leading = product(slice_ddim(src, 0, num_col_dims));
  const int64_t trailing = product(slice_ddim(src, num_col_dims, rank));
  return make_ddim({leading, trailing});
}

// Collapses `src` into a one-element DDim holding product(src).
DDim flatten_to_1d(const DDim &src) {
  const int64_t numel = product(src);
  return make_ddim({numel});
}

// Computes per-dimension strides for `ddim`: the last stride is 1 and each
// earlier stride is the next stride multiplied by the next extent.
//
// A rank-0 input has no strides. It previously led to a write at index -1 of
// an empty vector; the input is now returned as-is, since it is already the
// rank-0 result.
DDim stride(const DDim &ddim) {
  const int rank = ddim.size();
  if (rank <= 0) {
    return ddim;
  }
  std::vector<int64_t> strides(rank);
  strides[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return framework::make_ddim(strides);
}

// Computes suffix products of the extents of `ddim`: the last entry is the
// last extent and each earlier entry is the next entry multiplied by the
// extent at its own position.
//
// A rank-0 input has no entries. It previously led to accesses at index -1;
// the input is now returned as-is, since it is already the rank-0 result.
DDim stride_numel(const framework::DDim &ddim) {
  const int rank = ddim.size();
  if (rank <= 0) {
    return ddim;
  }
  std::vector<int64_t> strides(rank);
  strides[rank - 1] = ddim[rank - 1];
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return framework::make_ddim(strides);
}

朔-望's avatar
朔-望 已提交
326 327
}  // namespace framework
}  // namespace paddle_mobile