ddim.cc 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pten/core/ddim.h"
16

17
#include <set>
F
fengjiayi 已提交
18

19
namespace pten {
20
namespace framework {
F
fengjiayi 已提交
21

Q
qijun 已提交
22
// Builds a DDim from a brace-enclosed list of extents, e.g. make_ddim({2, 3}).
// The rank of the result equals the number of listed extents.
DDim make_ddim(std::initializer_list<int64_t> dims) {
  const int64_t* first_extent = dims.begin();
  const int rank = static_cast<int>(dims.size());
  return DDim(first_extent, rank);
}

Q
qijun 已提交
26
// Builds a DDim whose extents are copied from a vector of int64_t values.
DDim make_ddim(const std::vector<int64_t>& dims) {
  const int rank = static_cast<int>(dims.size());
  const int64_t* extents = dims.data();
  return DDim(extents, rank);
}

Y
Yu Yang 已提交
30
// Builds a DDim whose extents are copied from a vector of int values.
DDim make_ddim(const std::vector<int>& dims) {
  const int rank = static_cast<int>(dims.size());
  const int* extents = dims.data();
  return DDim(extents, rank);
}

S
sneaxiy 已提交
34 35
struct DDimEqualityVisitor {
  explicit DDimEqualityVisitor(const int64_t* d) : d_(d) {}
F
fengjiayi 已提交
36 37

  template <int D>
S
sneaxiy 已提交
38
  inline bool operator()(const Dim<D>& self) const {
S
sneaxiy 已提交
39
    return UnrollCompare<D>::Run(self.Get(), d_);
F
fengjiayi 已提交
40 41
  }

S
sneaxiy 已提交
42
  const int64_t* d_;
F
fengjiayi 已提交
43 44
};

S
sneaxiy 已提交
45
// Two DDims are equal when they have the same rank and identical extents.
bool DDim::operator==(const DDim& d) const {
  // Differing ranks short-circuit before any per-element comparison.
  if (size() != d.size()) {
    return false;
  }
  return this->apply_visitor(DDimEqualityVisitor(d.Get()));
}

S
sneaxiy 已提交
50
// Logical negation of operator==.
bool DDim::operator!=(const DDim& d) const {
  const bool equal = (*this == d);
  return !equal;
}
F
fengjiayi 已提交
51

L
liuwei1031 已提交
52 53 54 55 56 57 58 59 60 61
// Formats the extents as "[d0, d1, ..., dn]". A rank-0 DDim yields "[]".
std::string DDim::to_str() const {
  std::stringstream out;
  out << '[';
  for (int axis = 0; axis < rank_; ++axis) {
    // Emit the separator before every element except the first.
    if (axis != 0) {
      out << ", ";
    }
    out << dim_[axis];
  }
  out << ']';
  return out.str();
}

S
sneaxiy 已提交
62
struct ProductVisitor {
F
fengjiayi 已提交
63
  template <int D>
S
sneaxiy 已提交
64
  inline int64_t operator()(const Dim<D>& dim) {
F
fengjiayi 已提交
65
    return product(dim);
F
fengjiayi 已提交
66
  }
F
fengjiayi 已提交
67 68
};

Q
qijun 已提交
69
int64_t product(const DDim& ddim) {
S
sneaxiy 已提交
70
  return ddim.apply_visitor(ProductVisitor());
F
fengjiayi 已提交
71 72
}

H
Hongyu Liu 已提交
73 74 75 76 77 78 79 80 81 82
// Returns true when any extent of `ddim` is negative.
// NOTE(review): negative extents presumably mark unknown / to-be-inferred
// sizes (e.g. -1); confirm against callers.
bool contain_unknown_dim(const DDim& ddim) {
  const int rank = ddim.size();
  bool found_negative = false;
  for (int axis = 0; axis < rank && !found_negative; ++axis) {
    found_negative = ddim[axis] < 0;
  }
  return found_negative;
}

83
// Returns a new DDim holding extents dim[begin], ..., dim[end - 1].
// Throws InvalidArgument when begin < 0 or end > dim.size().
DDim slice_ddim(const DDim& dim, int begin, int end) {
  const bool bounds_ok = (begin >= 0 && end <= dim.size());
  PADDLE_ENFORCE_EQ(
      bounds_ok,
      true,
      pten::errors::InvalidArgument(
          "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
          begin,
          end,
          dim.size()));
  // The DDim constructor itself validates the resulting length (end - begin).
  const int64_t* first_extent = dim.Get() + begin;
  const int length = end - begin;
  return DDim(first_extent, length);
}
F
fengjiayi 已提交
95

S
sneaxiy 已提交
96
// Returns the rank (number of extents) of `d`; equivalent to d.size().
int arity(const DDim& d) {
  const int rank = d.size();
  return rank;
}
F
fengjiayi 已提交
97

S
sneaxiy 已提交
98
struct DDimPrinter {
F
fengjiayi 已提交
99
  std::ostream& os;
L
liaogang 已提交
100
  explicit DDimPrinter(std::ostream& os_) : os(os_) {}
F
fengjiayi 已提交
101

S
sneaxiy 已提交
102 103
  template <int D>
  void operator()(const Dim<D>& t) {
F
fengjiayi 已提交
104 105 106 107
    os << t;
  }
};

108
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
S
sneaxiy 已提交
109
  ddim.apply_visitor(DDimPrinter(os));
F
fengjiayi 已提交
110 111 112
  return os;
}

113
// Collapses `src` into a rank-3 shape:
//   [prod(src[0, num_row_dims)),
//    prod(src[num_row_dims, num_col_dims)),
//    prod(src[num_col_dims, rank))]
// Requirements (each enforced with InvalidArgument):
//   - rank(src) >= 3
//   - 1 <= num_row_dims <= rank - 1
//   - 2 <= num_col_dims <= rank
//   - num_row_dims <= num_col_dims (equality yields a middle extent of 1,
//     the product of an empty slice)
DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims) {
  PADDLE_ENFORCE_GE(
      src.size(),
      3,
      pten::errors::InvalidArgument("The rank of src dim should be at least 3 "
                                    "in flatten_to_3d, but received %d.",
                                    src.size()));
  PADDLE_ENFORCE_EQ(
      (num_row_dims >= 1 && num_row_dims < src.size()),
      true,
      pten::errors::InvalidArgument("The num_row_dims should be inside [1, %d] "
                                    "in flatten_to_3d, but received %d.",
                                    src.size() - 1,
                                    num_row_dims));
  PADDLE_ENFORCE_EQ(
      (num_col_dims >= 2 && num_col_dims <= src.size()),
      true,
      pten::errors::InvalidArgument("The num_col_dims should be inside [2, %d] "
                                    "in flatten_to_3d, but received %d.",
                                    src.size(),
                                    num_col_dims));
  // The check is GE, so equality is allowed; the message says so, and the two
  // literal halves are now separated by a space.
  PADDLE_ENFORCE_GE(
      num_col_dims,
      num_row_dims,
      pten::errors::InvalidArgument(
          "The num_row_dims should be less than or equal to num_col_dims in "
          "flatten_to_3d, but received num_row_dims = %d, num_col_dims = %d.",
          num_row_dims,
          num_col_dims));

  return DDim({product(slice_ddim(src, 0, num_row_dims)),
               product(slice_ddim(src, num_row_dims, num_col_dims)),
               product(slice_ddim(src, num_col_dims, src.size()))});
}

F
fengjiayi 已提交
148
// Collapses `src` into a rank-2 shape
//   [prod(src[0, num_col_dims)), prod(src[num_col_dims, rank))].
// Bounds checking of num_col_dims is delegated to slice_ddim.
DDim flatten_to_2d(const DDim& src, int num_col_dims) {
  const int64_t leading = product(slice_ddim(src, 0, num_col_dims));
  const int64_t trailing = product(slice_ddim(src, num_col_dims, src.size()));
  return DDim({leading, trailing});
}

S
sneaxiy 已提交
153
// Collapses `src` into a rank-1 shape whose single extent is prod(src).
DDim flatten_to_1d(const DDim& src) {
  const int64_t total = product(src);
  return DDim({total});
}
F
fengjiayi 已提交
154

W
wanghaoshuang 已提交
155
// Returns the row-major element strides of `ddim`: the result has the same
// rank, the last stride is 1, and stride[i] = stride[i + 1] * ddim[i + 1]
// (i.e. the product of all extents after axis i).
// A rank-0 input yields a rank-0 result.
DDim stride(const DDim& ddim) {
  const int rank = ddim.size();
  DDim strides;
  strides.rank_ = rank;
  if (rank == 0) {
    // Nothing to fill. Without this guard, strides[rank - 1] would index
    // position -1 (out of bounds).
    return strides;
  }
  strides[rank - 1] = 1;
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i + 1];
  }
  return strides;
}
Y
Yancey1989 已提交
164

S
sneaxiy 已提交
165
// Returns the suffix products of `ddim`: result[i] = ddim[i] * ... * ddim[last]
// (so result[0] is the total element count). The result has the same rank.
// A rank-0 input yields a rank-0 result.
DDim stride_numel(const DDim& ddim) {
  const int rank = ddim.size();
  DDim strides;
  strides.rank_ = rank;
  if (rank == 0) {
    // Nothing to fill. Without this guard, ddim[rank - 1] / strides[rank - 1]
    // would index position -1 (out of bounds).
    return strides;
  }
  strides[rank - 1] = ddim[rank - 1];
  for (int i = rank - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * ddim[i];
  }
  return strides;
}

175 176 177 178 179 180 181
// Returns a DDim with rank shape.size(), built element-wise from `shape`:
//   - shape[i] == 0 copies the current extent at index i (requires
//     i < this->size(), otherwise InvalidArgument is thrown);
//   - any other value is used verbatim (including negative values, which are
//     not resolved here).
DDim DDim::reshape(const std::vector<int>& shape) const {
  const int64_t copy_dim_val = 0;
  const DDim& in_dims = *this;
  DDim out_dims;
  out_dims.rank_ = shape.size();
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == copy_dim_val) {
      // The message has four conversion specifiers, so the index is passed
      // twice. Previously only three arguments were supplied, which shifted
      // every value into the wrong slot.
      PADDLE_ENFORCE_LT(static_cast<int>(i),
                        in_dims.size(),
                        pten::errors::InvalidArgument(
                            "Index %d of shape under which the value of 0 "
                            "is stored, must be lower than the number of "
                            "old dimensions. But received shape[%d] = 0, "
                            "dimensions = %d, shape = [%s].",
                            static_cast<int>(i),
                            static_cast<int>(i),
                            in_dims.size(),
                            in_dims));
      out_dims[i] = in_dims[i];
    } else {
      out_dims[i] = shape[i];
    }
  }
  return out_dims;
}

// Returns the DDim permuted by `axis`: out[i] = (*this)[axis[i]].
// `axis` must be a permutation of {0, ..., rank - 1}. This is enforced as:
// unique entries, size equal to the rank, and every entry in [0, rank).
// Violations throw InvalidArgument. A rank-0 DDim with an empty `axis`
// returns a copy of itself.
DDim DDim::transpose(const std::vector<int>& axis) const {
  const DDim& in_dims = *this;
  size_t in_rank = in_dims.size();
  size_t axis_size = axis.size();

  auto axis_set = std::set<int>(axis.begin(), axis.end());
  PADDLE_ENFORCE_EQ(axis_set.size(),
                    axis_size,
                    pten::errors::InvalidArgument(
                        "In an axis array, elements must be unique."));

  PADDLE_ENFORCE_EQ(
      in_rank,
      axis_size,
      pten::errors::InvalidArgument("The input dimension's size "
                                    "should be equal to the axis's size. "
                                    "But received dimension is %d, "
                                    "axis's size is %d",
                                    in_rank,
                                    axis_size));

  // Check every entry rather than only the maximum. The old max-only check
  // dereferenced max_element's end() on an empty axis (UB), and it let
  // negative entries through, which then indexed in_dims out of bounds.
  for (int a : axis) {
    PADDLE_ENFORCE_EQ(
        (a >= 0 && static_cast<size_t>(a) < axis_size),
        true,
        pten::errors::InvalidArgument(
            "Axis values must be ranging from 0 to (dims - 1)."));
  }

  DDim out_dims(in_dims);
  for (size_t i = 0; i < axis_size; i++) {
    out_dims[i] = in_dims[axis[i]];
  }
  return out_dims;
}

233
}  // namespace framework
234
}  // namespace pten