ddim.cc 6.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/core/ddim.h"
16

17
#include <set>
F
fengjiayi 已提交
18

19
namespace phi {
F
fengjiayi 已提交
20

Q
qijun 已提交
21
// Builds a DDim from a brace-enclosed list of int64_t extents by handing the
// list's first-element pointer and element count to the DDim constructor.
DDim make_ddim(std::initializer_list<int64_t> dims) {
  const int64_t* first = dims.begin();
  return DDim(first, dims.size());
}

Q
qijun 已提交
25
// Builds a DDim from a vector of int64_t extents by handing the vector's data
// pointer and element count to the DDim constructor.
DDim make_ddim(const std::vector<int64_t>& dims) {
  const int64_t* values = dims.data();
  return DDim(values, dims.size());
}

Y
Yu Yang 已提交
29
// Builds a DDim from a vector of int extents by handing the vector's data
// pointer and element count to the DDim constructor.
DDim make_ddim(const std::vector<int>& dims) {
  const int* values = dims.data();
  return DDim(values, dims.size());
}

S
sneaxiy 已提交
33 34
struct DDimEqualityVisitor {
  explicit DDimEqualityVisitor(const int64_t* d) : d_(d) {}
F
fengjiayi 已提交
35 36

  template <int D>
S
sneaxiy 已提交
37
  inline bool operator()(const Dim<D>& self) const {
S
sneaxiy 已提交
38
    return UnrollCompare<D>::Run(self.Get(), d_);
F
fengjiayi 已提交
39 40
  }

S
sneaxiy 已提交
41
  const int64_t* d_;
F
fengjiayi 已提交
42 43
};

S
sneaxiy 已提交
44
// Two DDims are equal when they have the same size() and
// DDimEqualityVisitor reports their extents as matching.
bool DDim::operator==(const DDim& d) const {
  // Different ranks can never match; skip the element comparison entirely.
  if (size() != d.size()) {
    return false;
  }
  return this->apply_visitor(DDimEqualityVisitor(d.Get()));
}

S
sneaxiy 已提交
49
// Negation of operator==.
bool DDim::operator!=(const DDim& d) const {
  const bool equal = (*this == d);
  return !equal;
}
F
fengjiayi 已提交
50

L
liuwei1031 已提交
51 52 53 54 55 56 57 58 59 60
// Formats the extents as "[d0, d1, ...]"; a DDim with rank_ == 0 yields "[]".
std::string DDim::to_str() const {
  std::stringstream out;
  out << '[';
  for (int idx = 0; idx < rank_; ++idx) {
    // Separator goes before every element except the first.
    if (idx > 0) {
      out << ", ";
    }
    out << dim_[idx];
  }
  out << ']';
  return out.str();
}

S
sneaxiy 已提交
61
struct ProductVisitor {
F
fengjiayi 已提交
62
  template <int D>
S
sneaxiy 已提交
63
  inline int64_t operator()(const Dim<D>& dim) {
F
fengjiayi 已提交
64
    return product(dim);
F
fengjiayi 已提交
65
  }
F
fengjiayi 已提交
66 67
};

Q
qijun 已提交
68
int64_t product(const DDim& ddim) {
S
sneaxiy 已提交
69
  return ddim.apply_visitor(ProductVisitor());
F
fengjiayi 已提交
70 71
}

H
Hongyu Liu 已提交
72 73 74 75 76 77 78 79 80 81
// Returns true when at least one extent of `ddim` is negative, false
// otherwise (including when `ddim` has no extents).
bool contain_unknown_dim(const DDim& ddim) {
  int idx = 0;
  // Advance past every non-negative extent; stop at the first negative one.
  while (idx < ddim.size() && !(ddim[idx] < 0)) {
    ++idx;
  }
  return idx < ddim.size();
}

82
// Returns a DDim made of the extents of `dim` in the half-open index range
// [begin, end). The enforce below rejects begin < 0 and end > dim.size().
DDim slice_ddim(const DDim& dim, int begin, int end) {
  PADDLE_ENFORCE_EQ(
      (begin >= 0 && end <= dim.size()),
      true,
      phi::errors::InvalidArgument(
          "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
          begin,
          end,
          dim.size()));
  // The validity of the slice length (end - begin) is left to the DDim
  // constructor.
  const int length = end - begin;
  const int64_t* first = dim.Get() + begin;
  return DDim(first, length);
}
F
fengjiayi 已提交
94

S
sneaxiy 已提交
95
// Returns d.size().
int arity(const DDim& d) {
  const int rank = d.size();
  return rank;
}
F
fengjiayi 已提交
96

S
sneaxiy 已提交
97
struct DDimPrinter {
F
fengjiayi 已提交
98
  std::ostream& os;
L
liaogang 已提交
99
  explicit DDimPrinter(std::ostream& os_) : os(os_) {}
F
fengjiayi 已提交
100

S
sneaxiy 已提交
101 102
  template <int D>
  void operator()(const Dim<D>& t) {
F
fengjiayi 已提交
103 104 105 106
    os << t;
  }
};

107
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
S
sneaxiy 已提交
108
  ddim.apply_visitor(DDimPrinter(os));
F
fengjiayi 已提交
109 110 111
  return os;
}

112
// Collapses `src` into a three-extent DDim:
//   { product(src[0, num_row_dims)),
//     product(src[num_row_dims, num_col_dims)),
//     product(src[num_col_dims, src.size())) }.
// The enforces below require src.size() >= 3, 1 <= num_row_dims < src.size(),
// 2 <= num_col_dims <= src.size(), and num_row_dims <= num_col_dims.
DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims) {
  PADDLE_ENFORCE_GE(
      src.size(),
      3,
      phi::errors::InvalidArgument("The rank of src dim should be at least 3 "
                                   "in flatten_to_3d, but received %d.",
                                   src.size()));
  PADDLE_ENFORCE_EQ(
      (num_row_dims >= 1 && num_row_dims < src.size()),
      true,
      phi::errors::InvalidArgument("The num_row_dims should be inside [1, %d] "
                                   "in flatten_to_3d, but received %d.",
                                   src.size() - 1,
                                   num_row_dims));
  PADDLE_ENFORCE_EQ(
      (num_col_dims >= 2 && num_col_dims <= src.size()),
      true,
      phi::errors::InvalidArgument("The num_col_dims should be inside [2, %d] "
                                   "in flatten_to_3d, but received %d.",
                                   src.size(),
                                   num_col_dims));
  // The check permits equality, so the message says "not greater than"; the
  // two literal pieces are joined with a space so the text reads correctly.
  PADDLE_ENFORCE_GE(
      num_col_dims,
      num_row_dims,
      phi::errors::InvalidArgument(
          "The num_row_dims should not be greater than num_col_dims in "
          "flatten_to_3d, but received num_row_dims = %d, num_col_dims = %d.",
          num_row_dims,
          num_col_dims));

  // Evaluated in the same left-to-right order as the extents they produce.
  const int64_t head = product(slice_ddim(src, 0, num_row_dims));
  const int64_t middle = product(slice_ddim(src, num_row_dims, num_col_dims));
  const int64_t tail = product(slice_ddim(src, num_col_dims, src.size()));
  return DDim({head, middle, tail});
}

F
fengjiayi 已提交
147
// Collapses `src` into a two-extent DDim:
//   { product(src[0, num_col_dims)), product(src[num_col_dims, src.size())) }.
// Range validation of num_col_dims is whatever slice_ddim performs.
DDim flatten_to_2d(const DDim& src, int num_col_dims) {
  const int64_t rows = product(slice_ddim(src, 0, num_col_dims));
  const int64_t cols = product(slice_ddim(src, num_col_dims, src.size()));
  return DDim({rows, cols});
}

S
sneaxiy 已提交
152
// Collapses `src` into a single-extent DDim holding product(src).
DDim flatten_to_1d(const DDim& src) {
  const int64_t numel = product(src);
  return DDim({numel});
}
F
fengjiayi 已提交
153

W
wanghaoshuang 已提交
154
// Returns a DDim of the same rank as `ddim` where the last entry is 1 and
// entry i is the product of ddim[i + 1] through the last extent of `ddim`.
// A `ddim` with no extents yields a result with no extents.
DDim stride(const DDim& ddim) {
  const int rank = ddim.size();
  DDim strides;
  strides.rank_ = rank;
  // With no axes there is no "last" entry to seed; indexing rank - 1 would
  // write to strides[-1], so return the empty result directly.
  if (rank <= 0) {
    return strides;
  }
  strides[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return strides;
}
Y
Yancey1989 已提交
163

S
sneaxiy 已提交
164
// Returns a DDim of the same rank as `ddim` where the last entry equals the
// last extent of `ddim` and entry i is the product of ddim[i] through the
// last extent. A `ddim` with no extents yields a result with no extents.
DDim stride_numel(const DDim& ddim) {
  const int rank = ddim.size();
  DDim strides;
  strides.rank_ = rank;
  // With no axes, indexing rank - 1 would read ddim[-1] and write
  // strides[-1]; return the empty result directly instead.
  if (rank <= 0) {
    return strides;
  }
  strides[rank - 1] = ddim[rank - 1];
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return strides;
}

174
// Resolves `shape` against this DDim and returns the result as a DDim.
//   * Every entry equal to 0 is overwritten with this->at(i).
//   * The first entry equal to -1 is overwritten with product(*this) divided
//     by the product of the remaining entries.
// `shape` is modified in place by both steps.
DDim DDim::reshape(std::vector<int>& shape) const {
  const DDim& in_dims = *this;

  for (uint64_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) {
      shape[i] = in_dims.at(i);
    }
  }

  // Dim marked as "-1" must be inferred
  auto it = std::find(shape.begin(), shape.end(), -1);
  if (it != shape.end()) {
    // Seeding with -1 cancels the sign contributed by the -1 placeholder.
    // The accumulator is int64_t so that the product of the known extents
    // cannot overflow `int` even when the inferred extent itself fits.
    // NOTE(review): a 0 among the remaining entries still makes the division
    // below a division by zero — confirm whether callers can pass that.
    const int64_t known_product = std::accumulate(
        shape.begin(), shape.end(), int64_t{-1}, std::multiplies<int64_t>());
    *it = static_cast<int>(product(in_dims) / known_product);
  }

  return phi::make_ddim(shape);
}

// Returns a copy of this DDim in which entry i (for i < axis.size()) is
// replaced by this DDim's extent at index axis[i]; entries at or beyond
// axis.size() keep their original values.
// `axis` must not be longer than the rank and every axis[i] must lie in
// [0, rank); both are enforced before any element is accessed, because
// operator[] is used for the reads and writes below.
DDim DDim::transpose(const std::vector<int>& axis) const {
  const DDim& in_dims = *this;
  const int rank = in_dims.size();

  PADDLE_ENFORCE_EQ(
      (axis.size() <= static_cast<size_t>(rank)),
      true,
      phi::errors::InvalidArgument(
          "The size of axis (%d) must not exceed the rank of ddim (%d) "
          "in transpose.",
          axis.size(),
          rank));

  DDim out_dims(in_dims);
  for (size_t i = 0; i < axis.size(); i++) {
    PADDLE_ENFORCE_EQ(
        (axis[i] >= 0 && axis[i] < rank),
        true,
        phi::errors::InvalidArgument(
            "axis[%d] = %d must be inside [0, %d) in transpose.",
            i,
            axis[i],
            rank));
    out_dims[i] = in_dims[axis[i]];
  }
  return out_dims;
}

205
}  // namespace phi
206 207 208 209 210 211 212 213 214 215 216 217 218

namespace std {

// Hashes a DDim: the seed starts at the rank and each extent is folded in
// with the xor/shift/0x9e3779b9 mixing step (the same formula as
// boost::hash_combine).
std::size_t hash<phi::DDim>::operator()(phi::DDim const& ddim) const {
  const int rank = ddim.size();
  const int64_t* extents = ddim.Get();
  std::size_t result = rank;
  for (int axis = 0; axis < rank; ++axis) {
    result ^= extents[axis] + 0x9e3779b9 + (result << 6) + (result >> 2);
  }
  return result;
}

}  // namespace std