ddim.cc 4.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
L
liaogang 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15 16
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/platform/enforce.h"
F
fengjiayi 已提交
17

18 19
namespace paddle {
namespace framework {
F
fengjiayi 已提交
20

Q
qijun 已提交
21
// Builds a DDim from a brace-enclosed list of extents, e.g. make_ddim({2, 3}).
DDim make_ddim(std::initializer_list<int64_t> dims) {
  const int64_t* first = dims.begin();
  const auto rank = dims.size();
  return DDim(first, rank);
}

Q
qijun 已提交
25
// Builds a DDim whose extents are copied from an int64_t vector.
DDim make_ddim(const std::vector<int64_t>& dims) {
  const int64_t* extents = dims.data();
  const auto rank = dims.size();
  return DDim(extents, rank);
}

Y
Yu Yang 已提交
29
// Builds a DDim whose extents are copied from an int vector.
DDim make_ddim(const std::vector<int>& dims) {
  const int* extents = dims.data();
  const auto rank = dims.size();
  return DDim(extents, rank);
}

S
sneaxiy 已提交
33 34
struct DDimEqualityVisitor {
  explicit DDimEqualityVisitor(const int64_t* d) : d_(d) {}
F
fengjiayi 已提交
35 36

  template <int D>
S
sneaxiy 已提交
37
  inline bool operator()(const Dim<D>& self) const {
S
sneaxiy 已提交
38
    return UnrollCompare<D>::Run(self.Get(), d_);
F
fengjiayi 已提交
39 40
  }

S
sneaxiy 已提交
41
  const int64_t* d_;
F
fengjiayi 已提交
42 43
};

S
sneaxiy 已提交
44
// Two DDims are equal when they have the same rank and identical extents.
bool DDim::operator==(const DDim& d) const {
  // Ranks must agree before the element-wise comparison is meaningful.
  if (rank_ != d.rank_) {
    return false;
  }
  return this->apply_visitor(DDimEqualityVisitor(d.Get()));
}

S
sneaxiy 已提交
48
// Inequality is defined as the negation of operator==.
bool DDim::operator!=(const DDim& d) const {
  const bool same = (*this == d);
  return !same;
}
F
fengjiayi 已提交
49

S
sneaxiy 已提交
50 51 52
struct DDimPlusVisitor {
  explicit DDimPlusVisitor(const int64_t* d1, const int64_t* d2)
      : d1_(d1), d2_(d2) {}
F
fengjiayi 已提交
53

S
sneaxiy 已提交
54 55
  template <int D>
  inline void operator()(Dim<D>& self) const {
S
sneaxiy 已提交
56
    UnrollAdd<D>::Run(d1_, d2_, self.GetMutable());
F
fengjiayi 已提交
57 58
  }

S
sneaxiy 已提交
59 60 61
  const int64_t* d1_;
  const int64_t* d2_;
};
F
fengjiayi 已提交
62

S
sneaxiy 已提交
63 64
// Element-wise sum of two DDims; both operands must have the same rank.
DDim DDim::operator+(const DDim& d) const {
  PADDLE_ENFORCE(rank_ == d.rank_);
  DDim sum;
  sum.rank_ = rank_;
  sum.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
  return sum;
}

S
sneaxiy 已提交
71 72 73
struct DDimMulVisitor {
  explicit DDimMulVisitor(const int64_t* d1, const int64_t* d2)
      : d1_(d1), d2_(d2) {}
F
fengjiayi 已提交
74

S
sneaxiy 已提交
75 76
  template <int D>
  inline void operator()(Dim<D>& self) const {
S
sneaxiy 已提交
77
    UnrollMul<D>::Run(d1_, d2_, self.GetMutable());
F
fengjiayi 已提交
78 79
  }

S
sneaxiy 已提交
80 81 82 83 84 85
  const int64_t* d1_;
  const int64_t* d2_;
};

// Element-wise product of two DDims; both operands must have the same rank.
DDim DDim::operator*(const DDim& d) const {
  PADDLE_ENFORCE(rank_ == d.rank_);
  DDim prod;
  prod.rank_ = rank_;
  prod.apply_visitor(DDimMulVisitor(Get(), d.Get()));
  return prod;
}

Q
qijun 已提交
92
// Returns the extent stored at position `idx` (same as ddim[idx]).
int64_t get(const DDim& ddim, int idx) {
  const int64_t extent = ddim[idx];
  return extent;
}
F
fengjiayi 已提交
93

S
sneaxiy 已提交
94
// Overwrites the extent stored at position `idx` with `value`.
// NOTE(review): `value` is an int, so extents larger than INT_MAX cannot be
// written through this helper; widening it would change the declared
// signature.
void set(DDim& ddim, int idx, int value) {  // NOLINT
  ddim[idx] = value;
}
F
fengjiayi 已提交
95

Q
qijun 已提交
96
std::vector<int64_t> vectorize(const DDim& ddim) {
S
sneaxiy 已提交
97
  std::vector<int64_t> result(DDim::kMaxRank);
S
sneaxiy 已提交
98
  dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
S
sneaxiy 已提交
99
  result.resize(ddim.size());
F
fengjiayi 已提交
100 101 102
  return result;
}

C
chengduoZH 已提交
103 104 105
// NOTE: framework::vectorize converts to type int64_t
//       which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim& ddim) {
S
sneaxiy 已提交
106
  std::vector<int> result(DDim::kMaxRank);
S
sneaxiy 已提交
107
  dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
S
sneaxiy 已提交
108
  result.resize(ddim.size());
C
chengduoZH 已提交
109 110 111
  return result;
}

S
sneaxiy 已提交
112
struct ProductVisitor {
F
fengjiayi 已提交
113
  template <int D>
S
sneaxiy 已提交
114
  inline int64_t operator()(const Dim<D>& dim) {
F
fengjiayi 已提交
115
    return product(dim);
F
fengjiayi 已提交
116
  }
F
fengjiayi 已提交
117 118
};

Q
qijun 已提交
119
int64_t product(const DDim& ddim) {
S
sneaxiy 已提交
120
  return ddim.apply_visitor(ProductVisitor());
F
fengjiayi 已提交
121 122
}

123
// Returns a new DDim holding the extents of `dim` at indices [begin, end).
// An empty range (begin == end) yields a rank-0 DDim.
//
// Enforces 0 <= begin <= end <= dim.size(). Previously only `begin >= 0` was
// checked, so end < begin produced a negative rank and end > dim.size() copied
// values past the stored extents.
DDim slice_ddim(const DDim& dim, int begin, int end) {
  PADDLE_ENFORCE(begin >= 0,
                 "Begin index can't be less than zero in ddim slice.");
  PADDLE_ENFORCE(end >= begin,
                 "End index can't be less than begin index in ddim slice.");
  PADDLE_ENFORCE(end <= dim.size(),
                 "End index can't be greater than the rank in ddim slice.");
  DDim ret;
  ret.rank_ = end - begin;
  dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
  return ret;
}
F
fengjiayi 已提交
132

S
sneaxiy 已提交
133
// Returns the number of dimensions (rank) of `d`.
int arity(const DDim& d) {
  const int rank = d.size();
  return rank;
}
F
fengjiayi 已提交
134

L
liaogang 已提交
135
/// \cond HIDDEN

// Visitor that streams whatever value it is applied to into the referenced
// ostream; used by operator<<(std::ostream&, const DDim&).
struct DDimPrinter {
  explicit DDimPrinter(std::ostream& stream) : os(stream) {}

  template <typename T>
  void operator()(const T& t) {
    os << t;
  }

  // Destination stream; not owned, must outlive the printer.
  std::ostream& os;
};

/// \endcond
F
fengjiayi 已提交
148

149
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
S
sneaxiy 已提交
150
  ddim.apply_visitor(DDimPrinter(os));
F
fengjiayi 已提交
151 152 153
  return os;
}

F
fengjiayi 已提交
154
// Collapses `src` into a rank-2 DDim. The first extent is the product of the
// first `num_col_dims` extents and the second is the product of the rest.
//
// `num_col_dims` must lie in [0, src.size()]. Without this check a negative
// value produced a negative-rank slice before any other validation ran.
DDim flatten_to_2d(const DDim& src, int num_col_dims) {
  const int rank = src.size();
  PADDLE_ENFORCE(num_col_dims >= 0 && num_col_dims <= rank,
                 "num_col_dims must be in the range [0, rank] in "
                 "flatten_to_2d.");
  return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
                    product(slice_ddim(src, num_col_dims, rank))});
}

F
Fix bug  
fengjiayi 已提交
160
// Collapses `src` into a rank-1 DDim whose single extent is product(src).
DDim flatten_to_1d(const DDim& src) {
  const int64_t numel = product(src);
  return make_ddim({numel});
}
F
fengjiayi 已提交
161

W
wanghaoshuang 已提交
162
// Returns the contiguous (innermost-last) strides of `ddim`: the last stride
// is 1, and strides[i] = strides[i + 1] * ddim[i + 1], i.e. the product of all
// extents after position i.
//
// A rank-0 input yields a rank-0 result. Previously it wrote strides[-1],
// which is out of bounds.
DDim stride(const DDim& ddim) {
  DDim strides;
  strides.rank_ = ddim.size();
  if (ddim.size() == 0) {
    return strides;
  }
  strides[ddim.size() - 1] = 1;
  for (int i = ddim.size() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return strides;
}
Y
Yancey1989 已提交
171

S
sneaxiy 已提交
172
// Returns, for each axis i, the product of extents ddim[i] .. ddim[rank - 1]
// (the number of elements spanned starting at axis i). The last entry equals
// the last extent.
//
// A rank-0 input yields a rank-0 result. Previously it read ddim[-1] and wrote
// strides[-1], both out of bounds.
DDim stride_numel(const DDim& ddim) {
  DDim strides;
  strides.rank_ = ddim.size();
  if (ddim.size() == 0) {
    return strides;
  }
  strides[ddim.size() - 1] = ddim[ddim.size() - 1];
  for (int i = ddim.size() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return strides;
}

182 183
}  // namespace framework
}  // namespace paddle