ddim.h 7.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

F
fengjiayi 已提交
15 16 17 18
#pragma once

#include <cstdint>
#include <initializer_list>
#include <iosfwd>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "paddle/fluid/framework/dim.h"
F
fengjiayi 已提交
23

24 25
namespace paddle {
namespace framework {
F
fengjiayi 已提交
26

S
sneaxiy 已提交
27 28 29 30 31 32
// One `case` of the rank-dispatch switch produced by PADDLE_VISIT_DDIM.
// Inside `callback`, `kRank` is a constexpr copy of `rank`, so the callback
// may instantiate rank-templated code such as Dim<kRank>. The case body
// `return`s the callback's value from the enclosing function.
// (Comments cannot go inside the macro body: a `//` comment would swallow the
// trailing line-continuation backslash.)
#define PADDLE_VISIT_DDIM_BASE(rank, callback) \
  case (rank): {                               \
    constexpr auto kRank = (rank);             \
    return (callback);                         \
  }

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
// Turns a runtime rank into a compile-time one.
//
// Expands to a `switch` on `rank` with one case for each rank 0 through 9.
// In each case, `callback` is evaluated with `kRank` bound (as a constexpr
// value) to that rank, and its result is returned from the enclosing
// function. Any other rank throws an Unimplemented error via PADDLE_THROW.
//
// `rank` is expanded more than once, so pass a side-effect-free expression.
// `callback` is a single macro argument: an expression with top-level commas
// (e.g. `f<kRank, T>(x)`) must be wrapped in parentheses.
#define PADDLE_VISIT_DDIM(rank, callback)                                  \
  switch (rank) {                                                          \
    PADDLE_VISIT_DDIM_BASE(0, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(1, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(2, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(3, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(4, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(5, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(6, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(7, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(8, callback);                                   \
    PADDLE_VISIT_DDIM_BASE(9, callback);                                   \
    default:                                                               \
      PADDLE_THROW(platform::errors::Unimplemented(                        \
          "Invalid dimension to be accessed. Now only supports access to " \
          "dimension 0 to 9, but received dimension is %d.",               \
          rank));                                                          \
  }

S
sneaxiy 已提交
52 53
template <typename T1, typename T2>
inline void dynamic_dim_assign(const T1* in, T2* out, int n) {
S
sneaxiy 已提交
54
  PADDLE_VISIT_DDIM(n, (static_dim_assign<kRank, T1, T2>(in, out)));
S
sneaxiy 已提交
55 56
}

F
fengjiayi 已提交
57 58 59 60 61
/**
 * \brief A dynamically sized dimension.
 *
 * The number of dimensions must be between [1, 9].
 */
S
sneaxiy 已提交
62 63 64
class DDim {
 public:
  constexpr static int kMaxRank = 9;
F
fengjiayi 已提交
65

S
sneaxiy 已提交
66 67
  DDim() : rank_(1) { dim_[0] = 0; }

S
sneaxiy 已提交
68
  DDim(const DDim& ddim) : dim_() { CopyFrom(ddim); }
S
sneaxiy 已提交
69

S
sneaxiy 已提交
70 71 72 73 74 75 76
  DDim(const int* d, int n) : rank_(n) {
    dynamic_dim_assign(d, dim_.GetMutable(), n);
  }

  DDim(const int64_t* d, int n) : rank_(n) {
    dynamic_dim_assign(d, dim_.GetMutable(), n);
  }
F
fengjiayi 已提交
77 78

  template <int D>
S
sneaxiy 已提交
79 80 81
  /*implicit*/ DDim(const Dim<D>& in) : rank_(D) {  // NOLINT
    UnsafeCast<D>() = in;
  }
F
fengjiayi 已提交
82

S
sneaxiy 已提交
83 84
  /*implicit*/ DDim(std::initializer_list<int64_t> init_list)
      : DDim(init_list.begin(), init_list.size()) {}
F
fengjiayi 已提交
85

S
sneaxiy 已提交
86 87
  inline DDim& operator=(const DDim& ddim) { return CopyFrom(ddim); }

F
fengjiayi 已提交
88
  template <int D>
S
sneaxiy 已提交
89
  inline DDim& operator=(const Dim<D>& dim) {
S
sneaxiy 已提交
90
    rank_ = D;
S
sneaxiy 已提交
91
    UnsafeCast<D>() = dim;
F
fengjiayi 已提交
92 93 94
    return *this;
  }

S
sneaxiy 已提交
95
  inline int64_t& operator[](int idx) { return dim_[idx]; }
F
fengjiayi 已提交
96

S
sneaxiy 已提交
97 98
  inline int64_t operator[](int idx) const { return dim_[idx]; }

99 100 101 102 103 104 105 106 107 108 109
  int64_t& at(int idx) {
    PADDLE_ENFORCE_GE(idx, 0,
                      platform::errors::InvalidArgument(
                          "Invalid DDim index to be accessed. The valid index "
                          "is between 0 and %d, but received index is %d.",
                          rank_, idx));
    PADDLE_ENFORCE_LT(idx, rank_,
                      platform::errors::InvalidArgument(
                          "Invalid DDim index to be accessed. The valid index "
                          "is between 0 and %d, but received index is %d.",
                          rank_, idx));
S
sneaxiy 已提交
110
    return dim_[idx];
F
fengjiayi 已提交
111 112
  }

113 114 115 116 117 118 119 120 121 122 123
  int64_t at(int idx) const {
    PADDLE_ENFORCE_GE(idx, 0,
                      platform::errors::InvalidArgument(
                          "Invalid DDim index to be accessed. The valid index "
                          "is between 0 and %d, but received index is %d.",
                          rank_, idx));
    PADDLE_ENFORCE_LT(idx, rank_,
                      platform::errors::InvalidArgument(
                          "Invalid DDim index to be accessed. The valid index "
                          "is between 0 and %d, but received index is %d.",
                          rank_, idx));
S
sneaxiy 已提交
124
    return dim_[idx];
F
fengjiayi 已提交
125 126
  }

S
sneaxiy 已提交
127 128
  template <typename Visitor>
  typename std::result_of<Visitor(Dim<0>&)>::type apply_visitor(
S
sneaxiy 已提交
129 130 131
      Visitor&& visitor) {
    PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
  }
S
sneaxiy 已提交
132 133 134

  template <typename Visitor>
  typename std::result_of<Visitor(const Dim<0>&)>::type apply_visitor(
S
sneaxiy 已提交
135 136 137
      Visitor&& visitor) const {
    PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
  }
S
sneaxiy 已提交
138 139 140 141 142

  bool operator==(const DDim& d) const;

  bool operator!=(const DDim& d) const;

S
sneaxiy 已提交
143
  inline const int64_t* Get() const { return dim_.Get(); }
F
fengjiayi 已提交
144

S
sneaxiy 已提交
145
  inline int64_t* GetMutable() { return dim_.GetMutable(); }
S
sneaxiy 已提交
146

S
sneaxiy 已提交
147
  inline int size() const { return rank_; }
S
sneaxiy 已提交
148

L
liuwei1031 已提交
149 150
  std::string to_str() const;

151 152 153 154
  DDim reshape(const std::vector<int>& shape) const;

  DDim transpose(const std::vector<int>& axis) const;

S
sneaxiy 已提交
155
 private:
S
sneaxiy 已提交
156 157
  template <int D>
  inline Dim<D>& UnsafeCast() {
S
sneaxiy 已提交
158 159 160
    static_assert(D >= 0 && D <= kMaxRank, "Invalid rank");
    auto* p = static_cast<void*>(&dim_);
    return *reinterpret_cast<Dim<D>*>(p);
S
sneaxiy 已提交
161
  }
162

S
sneaxiy 已提交
163 164 165
  template <int D>
  inline const Dim<D>& UnsafeCast() const {
    static_assert(D >= 0 && D <= kMaxRank, "Invalid rank");
S
sneaxiy 已提交
166
    auto* p = static_cast<const void*>(&dim_);
S
sneaxiy 已提交
167
    return *reinterpret_cast<const Dim<D>*>(p);
S
sneaxiy 已提交
168 169
  }

S
sneaxiy 已提交
170 171
  inline DDim& CopyFrom(const DDim& ddim) {
    PADDLE_VISIT_DDIM(ddim.rank_, (*this = ddim.UnsafeCast<kRank>()));
S
sneaxiy 已提交
172 173
  }

S
sneaxiy 已提交
174 175 176
  friend DDim stride(const DDim& ddim);
  friend DDim stride_numel(const DDim& ddim);

S
sneaxiy 已提交
177
 private:
S
sneaxiy 已提交
178 179
  Dim<kMaxRank> dim_;
  int rank_;
F
fengjiayi 已提交
180 181
};

S
sneaxiy 已提交
182
// The rank-dispatch macros are only needed inside this header. Undefine them
// so they do not leak into files that include it.
#undef PADDLE_VISIT_DDIM_BASE
#undef PADDLE_VISIT_DDIM

F
fengjiayi 已提交
185
/**
Q
qijun 已提交
186
 * \brief Make a DDim from std::vector<int64_t>
F
fengjiayi 已提交
187 188 189
 *
 * \param dims An vector of ints. Must be sized between [1, 9]
 */
Q
qijun 已提交
190
DDim make_ddim(const std::vector<int64_t>& dims);
F
fengjiayi 已提交
191

Y
Yu Yang 已提交
192 193
/**
 * \brief Make a DDim from std::vector<int>
 *
 * \param dims A vector of int extents (defined out of line; size limits
 *             presumably match the int64_t overload — TODO confirm).
 */
DDim make_ddim(const std::vector<int>& dims);

F
fengjiayi 已提交
194 195 196 197 198 199
/**
 * \brief Make a DDim from an initializer list
 *
 * \param dims An initializer list of ints. Must be sized between [1, 9]
 *
 */
Q
qijun 已提交
200
DDim make_ddim(std::initializer_list<int64_t> dims);
F
fengjiayi 已提交
201

202 203 204 205 206 207 208
/**
 * \brief Copy the extents of a DDim into a std::vector.
 *
 * \tparam T   Element type of the returned vector (int64_t by default).
 * \param ddim The DDim to read.
 * \return A vector with ddim.size() elements, one per dimension.
 */
template <typename T = int64_t>
std::vector<T> vectorize(const DDim& ddim) {
  const int rank = ddim.size();
  // Fill a buffer large enough for any supported rank, then trim it down to
  // the actual rank.
  std::vector<T> dims(DDim::kMaxRank);
  dynamic_dim_assign(ddim.Get(), dims.data(), rank);
  dims.resize(rank);
  return dims;
}
F
fengjiayi 已提交
209

Q
qijun 已提交
210
// Defined out of line. Presumably the product of all extents of `ddim`, i.e.
// its element count — TODO confirm against the definition.
int64_t product(const DDim& ddim);
F
fengjiayi 已提交
211

H
Hongyu Liu 已提交
212 213
// Defined out of line. Presumably reports whether any extent of `ddim` is an
// "unknown" placeholder (e.g. a negative value) — TODO confirm which values
// count as unknown.
bool contain_unknown_dim(const DDim& ddim);

F
fengjiayi 已提交
214 215 216 217 218 219 220
/**
 * \brief Slice a ddim
 *
 * Slice dim with [begin, end).
 * e.g.  DDim d = make_ddim({1,2,3,4,5});
 *       slice_ddim(d, 1, 3); ====> {2,3}
 */
221 222
DDim slice_ddim(const DDim& dim, int begin, int end);

F
fengjiayi 已提交
223 224 225 226 227 228 229 230
/**
 * \brief What is the length of this dimension?
 *
 * Defined out of line; presumably returns the rank (the same value as
 * ddim.size()) — TODO confirm.
 *
 * \param ddim Dynamic dimension to inspect
 */
int arity(const DDim& ddim);

231
// Streams a textual form of the DDim (defined out of line). std::ostream is
// declared by <iosfwd>, which this header now includes explicitly instead of
// relying on a transitive include.
std::ostream& operator<<(std::ostream&, const DDim&);
F
fengjiayi 已提交
232

233 234 235 236 237 238 239
/**
 * \brief Flatten dim to 3d
 *
 * As the example shows, dims [0, num_row_dims) collapse into the first
 * output extent, dims [num_row_dims, num_col_dims) into the second, and the
 * remaining dims into the third.
 * e.g., DDim d = make_ddim({1, 2, 3, 4, 5, 6})
 *       flatten_to_3d(d, 2, 4); ===> {1*2, 3*4, 5*6} ===> {2, 12, 30}
 */
DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims);

F
fengjiayi 已提交
240
// Reshape a tensor to a matrix. The matrix's first dimension(column length)
F
test  
fengjiayi 已提交
241
// will be the product of tensor's first `num_col_dims` dimensions.
F
fengjiayi 已提交
242
DDim flatten_to_2d(const DDim& src, int num_col_dims);
243

F
fengjiayi 已提交
244 245
// Defined out of line. Presumably collapses `src` into a single dimension
// holding product(src) — TODO confirm.
DDim flatten_to_1d(const DDim& src);

W
wanghaoshuang 已提交
246
// Defined out of line; a friend of DDim. Presumably returns the per-dimension
// strides of `ddim` — TODO confirm the memory order it assumes.
DDim stride(const DDim& ddim);

// Defined out of line; also a friend of DDim. Exact semantics are not visible
// in this header — TODO document against the definition.
DDim stride_numel(const DDim& ddim);
249 250
}  // namespace framework
}  // namespace paddle