ddim.cpp 7.4 KB
Newer Older
W
wangliu 已提交
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
朔-望's avatar
朔-望 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "ddim.h"
L
liuruilong 已提交
16
#include <algorithm>
朔-望's avatar
朔-望 已提交
17 18

namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
19 20 21 22
namespace framework {

/// @cond HIDDEN

朔-望's avatar
朔-望 已提交
23 24
template <int i>
Dim<i> make_dim(const int64_t *d) {
  // Peel off the leading extent and build the remaining i-1 extents from the
  // rest of the array; recursion ends at the Dim<0> specialization.
  const int64_t head = d[0];
  return Dim<i>(head, make_dim<i - 1>(d + 1));
}

朔-望's avatar
朔-望 已提交
28 29 30 31
template <>
Dim<0> make_dim<0>(const int64_t * /*d*/) {
  // Recursion terminator for make_dim<i>. Each level of make_dim<i> advances
  // the pointer by one, so when this is reached `d` points one past the last
  // extent (or is the begin/data pointer of an empty range when n == 0).
  // It must therefore not be dereferenced. A rank-0 Dim carries no extents
  // (see VectorizeVisitor / ArityVisitor), so a dummy value is passed.
  // NOTE(review): assumes Dim<0>'s int64_t constructor ignores its argument
  // — confirm against dim.h.
  return Dim<0>(int64_t{0});
}
朔-望's avatar
朔-望 已提交
32 33

/// Assigns to `ddim` a shape of rank `n` whose extents are dims[0..n-1].
/// @param ddim  destination shape, overwritten on success
/// @param dims  pointer to `n` extents (may be null/unused when n == 0)
/// @param n     rank; only 0 through 9 are supported
void make_ddim(DDim &ddim, const int64_t *dims, int n) {
  // Dispatch the runtime rank to the matching statically ranked Dim<N>.
  switch (n) {
    case 0:
      ddim = make_dim<0>(dims);
      break;
    case 1:
      ddim = make_dim<1>(dims);
      break;
    case 2:
      ddim = make_dim<2>(dims);
      break;
    case 3:
      ddim = make_dim<3>(dims);
      break;
    case 4:
      ddim = make_dim<4>(dims);
      break;
    case 5:
      ddim = make_dim<5>(dims);
      break;
    case 6:
      ddim = make_dim<6>(dims);
      break;
    case 7:
      ddim = make_dim<7>(dims);
      break;
    case 8:
      ddim = make_dim<8>(dims);
      break;
    case 9:
      ddim = make_dim<9>(dims);
      break;
    default:
      // No Dim<N> alternative exists here for this rank; report it instead of
      // silently leaving `ddim` unchanged.
      PADDLE_MOBILE_ENFORCE(false, "make_ddim only supports ranks from 0 to 9.");
      break;
  }
}

/// @endcond

/// Builds a DDim whose extents are the values of `dims`, in order.
DDim make_ddim(std::initializer_list<int64_t> dims) {
  const int rank = static_cast<int>(dims.size());
  DDim shape(make_dim(0));
  make_ddim(shape, dims.begin(), rank);
  return shape;
}

/// Builds a DDim whose extents are the elements of `dims`, in order.
DDim make_ddim(const std::vector<int64_t> &dims) {
  DDim result(make_dim(0));
  // data() is well-defined for an empty vector, whereas &dims[0] indexes an
  // element that does not exist when dims is empty.
  make_ddim(result, dims.data(), static_cast<int>(dims.size()));
  return result;
}

/// Builds a DDim from 32-bit extents by widening them to int64_t and
/// delegating to the int64_t overload.
DDim make_ddim(const std::vector<int> &dims) {
  const std::vector<int64_t> widened(dims.begin(), dims.end());
  return make_ddim(widened);
}

/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes
// errors
struct DynamicMutableIndexer : public Vistor<int64_t &> {
 public:
  /// @param idx  position of the extent to expose
  explicit DynamicMutableIndexer(int idx) : index_(idx) {}

  /// Returns a writable reference to extent `index_` of `dim`.
  template <int D>
  int64_t &operator()(Dim<D> &dim) const {
    int64_t &extent = dim[index_];
    return extent;
  }

 private:
  int index_;
};

struct DynamicConstIndexer : public Vistor<int64_t> {
 public:
  /// @param idx  position of the extent to read
  explicit DynamicConstIndexer(int idx) : index_(idx) {}

  /// Returns (by value) extent `index_` of `dim`.
  template <int D>
  int64_t operator()(const Dim<D> &dim) const {
    const int64_t extent = dim[index_];
    return extent;
  }

 private:
  int index_;
};

/// @endcond

/// Mutable access to extent `idx`, dispatched to the underlying Dim<N>.
int64_t &DDim::operator[](int idx) {
  DynamicMutableIndexer indexer(idx);
  return DDim::ApplyVistor(indexer, *this);
}

/// Read-only access to extent `idx`, dispatched to the underlying Dim<N>.
int64_t DDim::operator[](int idx) const {
  DynamicConstIndexer indexer(idx);
  return DDim::ApplyVistor(indexer, *this);
}

/// Number of dimensions (rank) of this shape.
int DDim::size() const {
  const int rank = arity(*this);
  return rank;
}

/// Two shapes are equal when they have the same rank and identical extents.
bool DDim::operator==(DDim d) const {
  // std::vector equality checks the lengths before the elements. Shapes of
  // different rank therefore compare unequal, and neither vector is indexed
  // past its end.
  return vectorize(*this) == vectorize(d);
}

/// Logical negation of operator==.
bool DDim::operator!=(DDim d) const {
  const bool same = (*this == d);
  return !same;
}

/// Element-wise sum of two shapes of the same rank.
DDim DDim::operator+(DDim d) const {
  const std::vector<int64_t> lhs = vectorize(*this);
  const std::vector<int64_t> rhs = vectorize(d);

  PADDLE_MOBILE_ENFORCE(lhs.size() == rhs.size(), "v1.size() != v2.size()");

  std::vector<int64_t> sum(lhs.size());
  std::transform(lhs.begin(), lhs.end(), rhs.begin(), sum.begin(),
                 [](int64_t a, int64_t b) { return a + b; });
  return make_ddim(sum);
}

/// Element-wise product of two shapes of the same rank.
DDim DDim::operator*(DDim d) const {
  const std::vector<int64_t> lhs = vectorize(*this);
  const std::vector<int64_t> rhs = vectorize(d);

  // The message describes the failure, i.e. the sizes differ; this matches
  // operator+.
  PADDLE_MOBILE_ENFORCE(lhs.size() == rhs.size(), "v1.size() != v2.size()");

  std::vector<int64_t> prod(lhs.size());
  std::transform(lhs.begin(), lhs.end(), rhs.begin(), prod.begin(),
                 [](int64_t a, int64_t b) { return a * b; });
  return make_ddim(prod);
}

/// Returns extent `idx` of `ddim`.
int64_t get(const DDim &ddim, int idx) {
  const int64_t extent = ddim[idx];
  return extent;
}

W
wangliu 已提交
180
/// Overwrites extent `idx` of `*ddim` with `value`.
/// `value` is an int, so this helper cannot store extents wider than int.
void set(DDim *ddim, int idx, int value) {
  DDim &target = *ddim;
  target[idx] = value;
}
朔-望's avatar
朔-望 已提交
181 182 183

/// @cond HIDDEN
struct VectorizeVisitor : Vistor<void> {
  /// Destination that receives the extents, outermost first.
  std::vector<int64_t> &vector;

  explicit VectorizeVisitor(std::vector<int64_t> &out) : vector(out) {}

  /// Appends this level's extent, then recurses into the remaining levels.
  template <typename T>
  void operator()(const T &dim) {
    vector.push_back(dim.head);
    (*this)(dim.tail);
  }

  /// Rank-0 terminator: nothing left to append.
  void operator()(const Dim<0> & /*dim*/) {}
};
/// @endcond

std::vector<int64_t> vectorize(const DDim &ddim) {
199 200 201 202
  std::vector<int64_t> result;
  VectorizeVisitor visitor(result);
  DDim::ApplyVistor(visitor, ddim);
  return result;
朔-望's avatar
朔-望 已提交
203 204 205 206 207
}

// NOTE: framework::vectorize converts to type int64_t
//       which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim &ddim) {
208 209 210
  std::vector<int64_t> temp = vectorize(ddim);
  std::vector<int> result(temp.begin(), temp.end());
  return result;
朔-望's avatar
朔-望 已提交
211 212 213
}

struct ProductVisitor : Vistor<int64_t> {
  /// Forwards to the statically ranked product() overload for Dim<D>.
  template <int D>
  int64_t operator()(const Dim<D> &dim) {
    const int64_t total = product(dim);
    return total;
  }
};

int64_t product(const DDim &ddim) {
221 222
  ProductVisitor visitor;
  return DDim::ApplyVistor(visitor, ddim);
朔-望's avatar
朔-望 已提交
223 224 225
}

struct SliceVectorizeVisitor : Vistor<void> {
  /// Destination for the extents in [begin, end).
  std::vector<int64_t> &vector;
  /// Number of leading extents still to skip.
  int begin;
  /// Number of extents left to visit before the slice is closed.
  int end;

  SliceVectorizeVisitor(std::vector<int64_t> &v, int b, int e)
      : vector(v), begin(b), end(e) {
    PADDLE_MOBILE_ENFORCE(
        begin < end, "Begin index must be less than end index in ddim slice.");
    PADDLE_MOBILE_ENFORCE(begin >= 0,
                          "Begin index can't be less than zero in ddim slice.");
  }

  /// Visits one level: either skips it (while `begin` is non-zero) or
  /// records its extent, then continues while `end` is still positive.
  template <int S>
  void operator()(const Dim<S> &dim) {
    if (begin != 0) {
      --begin;
    } else {
      vector.push_back(dim.head);
    }
    if (--end > 0) {
      (*this)(dim.tail);
    }
  }

  /// Rank-0 terminator. NOTE: an `end` that exceeds the rank is not
  /// reported here; the corresponding check was left disabled.
  void operator()(const Dim<0> & /*dim*/) {}
};

/// Returns the sub-shape made of extents [begin, end) of `ddim`.
DDim slice_ddim(const DDim &ddim, int begin, int end) {
  std::vector<int64_t> vec;
  // Construct the visitor first so that its range checks on begin/end run
  // before `end - begin` is used as a capacity. A negative difference would
  // convert to a huge size_t and make reserve() throw std::length_error
  // instead of reporting the invalid slice.
  SliceVectorizeVisitor visitor(vec, begin, end);
  vec.reserve(end - begin);
  DDim::ApplyVistor(visitor, ddim);
  return make_ddim(vec);
}

/// \cond HIDDEN

struct ArityVisitor : Vistor<int> {
  /// The rank of a Dim<D> is its template argument D.
  template <int D>
  int operator()(Dim<D> /*dim*/) const {
    return D;
  }
};

/// \endcond

int arity(const DDim &d) {
277 278
  ArityVisitor arityVisitor = ArityVisitor();
  return DDim::ApplyVistor(arityVisitor, d);
朔-望's avatar
朔-望 已提交
279
}
W
wangliu 已提交
280 281

#ifdef PADDLE_MOBILE_DEBUG
W
wangliu 已提交
282
/// Debug printing: writes every extent followed by a single space.
Print &operator<<(Print &printer, const DDim &ddim) {
  const int rank = ddim.size();
  for (int axis = 0; axis < rank; ++axis) {
    printer << ddim[axis] << " ";
  }
  return printer;
}

W
wangliu 已提交
290 291
#endif

朔-望's avatar
朔-望 已提交
292
/// Constructs a shape from a brace list of extents via make_ddim().
DDim::DDim(std::initializer_list<int64_t> init_list) {
  DDim built = make_ddim(init_list);
  *this = built;
}

/// Collapses `src` into a rank-2 shape.
/// The first extent is the product of src[0, num_col_dims); the second is
/// the product of src[num_col_dims, rank).
DDim flatten_to_2d(const DDim &src, int num_col_dims) {
  const int rank = src.size();
  const int64_t leading = product(slice_ddim(src, 0, num_col_dims));
  const int64_t trailing = product(slice_ddim(src, num_col_dims, rank));
  return make_ddim({leading, trailing});
}

/// Collapses `src` into a rank-1 shape holding its total element count.
DDim flatten_to_1d(const DDim &src) {
  const int64_t numel = product(src);
  return make_ddim({numel});
}

/// Row-major strides of `ddim`: strides[i] is the product of the extents
/// after axis i, so the last stride is 1.
DDim stride(const DDim &ddim) {
  const int rank = ddim.size();
  // A rank-0 shape has no strides; without this guard the code below would
  // write strides[-1].
  if (rank == 0) {
    return ddim;
  }
  std::vector<int64_t> strides(rank);
  strides[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return framework::make_ddim(strides);
}

/// Suffix element counts of `ddim`: strides[i] is the product of the extents
/// from axis i through the last axis, so strides[0] is the total count.
DDim stride_numel(const framework::DDim &ddim) {
  const int rank = ddim.size();
  // A rank-0 shape has nothing to accumulate; without this guard the code
  // below would read ddim[-1] and write strides[-1].
  if (rank == 0) {
    return ddim;
  }
  std::vector<int64_t> strides(rank);
  strides[rank - 1] = ddim[rank - 1];
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return framework::make_ddim(strides);
}

朔-望's avatar
朔-望 已提交
322 323
}  // namespace framework
}  // namespace paddle_mobile