/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/meta_tensor.h"

#include "glog/logging.h"

#include "paddle/phi/core/dense_tensor.h"

#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif

namespace phi {

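// Guard used by the accessors and mutators below: the MetaTensor must wrap a
// non-null TensorBase, otherwise an InvalidArgument error is raised.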
static inline void ValidCheck(const MetaTensor& meta_tensor) {
  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
                    true,
                    phi::errors::InvalidArgument(
                        "The current MetaTensor is not initialized."));
}

int64_t MetaTensor::numel() const {
  ValidCheck(*this);
  return tensor_->numel();
}

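// For SelectedRows, report the complete (height-based) dims rather than the
// dims of the in-memory value tensor.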
DDim MetaTensor::dims() const {
  ValidCheck(*this);
  if (phi::SelectedRows::classof(tensor_)) {
    return static_cast<SelectedRows*>(tensor_)->GetCompleteDims();
  } else {
    return tensor_->dims();
  }
}

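// Only DenseTensor carries stride information; other tensor types return an
// empty DDim.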
DDim MetaTensor::strides() const {
  ValidCheck(*this);
  if (dynamic_cast<DenseTensor*>(tensor_)) {
    return dynamic_cast<DenseTensor*>(tensor_)->strides();
  }
  return DDim();
}

DataType MetaTensor::dtype() const {
  ValidCheck(*this);
  return tensor_->dtype();
}

DataLayout MetaTensor::layout() const {
  ValidCheck(*this);
  return tensor_->layout();
}

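// Dispatches on the concrete tensor type. DenseTensor recomputes its strides
// from the new dims unless a strided kernel is in use; SelectedRows only
// tracks the height (dims[0]).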
void MetaTensor::set_dims(const DDim& dims) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    auto meta =
        DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_));
    meta->dims = dims;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(dims);
    }
  } else if (phi::StringTensor::classof(tensor_)) {
    StringTensorUtils::GetMutableMeta(static_cast<StringTensor*>(tensor_))
        ->dims = dims;
  } else if (phi::SelectedRows::classof(tensor_)) {
    static_cast<SelectedRows*>(tensor_)->set_height(dims[0]);
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dims = dims;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dims = dims;
#ifdef PADDLE_WITH_DISTRIBUTE
  } else if (phi::distributed::DistTensor::classof(tensor_)) {
    static_cast<distributed::DistTensor*>(tensor_)->set_dims(dims);
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
  }
}

void MetaTensor::set_strides(const DDim& strides) {
  ValidCheck(*this);
  if (dynamic_cast<DenseTensor*>(tensor_)) {
    DenseTensorUtils::GetMutableMeta(dynamic_cast<DenseTensor*>(tensor_))
        ->strides = strides;
  }
}

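// StringTensor and DistTensor do not store a dtype here; all other supported
// types write it into their meta.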
void MetaTensor::set_dtype(DataType dtype) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set dtype
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->dtype = dtype;
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dtype = dtype;
#ifdef PADDLE_WITH_DISTRIBUTE
  } else if (phi::distributed::DistTensor::classof(tensor_)) {
    // skip, DistTensor does not need to set dtype
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting dtype for `%s`.", tensor_->type_info().name()));
  }
}

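// Mirrors set_dtype. For DenseTensor and SelectedRows, changing the layout
// also recomputes strides when strided kernels are not in use.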
void MetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    auto meta =
        DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_));
    meta->layout = layout;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(meta->dims);
    }
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set layout
  } else if (phi::SelectedRows::classof(tensor_)) {
    auto meta = DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value());
    meta->layout = layout;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(meta->dims);
    }
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->layout = layout;
#ifdef PADDLE_WITH_DISTRIBUTE
  } else if (phi::distributed::DistTensor::classof(tensor_)) {
    // skip, DistTensor does not need to set layout
#endif
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting layout for `%s`.", tensor_->type_info().name()));
  }
}

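// LoD is only shared for DenseTensor and SelectedRows; sparse and distributed
// tensors return early, and an empty source LoD is not propagated.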
void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  ValidCheck(meta_tensor);
  if (phi::SparseCooTensor::classof(tensor_) ||
      phi::SparseCsrTensor::classof(tensor_)
#ifdef PADDLE_WITH_DISTRIBUTE
      || phi::distributed::DistTensor::classof(tensor_)
#endif
  ) {
    return;
  }
  if (meta_tensor.lod().empty()) {
    // no need to share
    return;
  }
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
        meta_tensor.lod();
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->lod = meta_tensor.lod();
  } else {
    PADDLE_THROW(
        phi::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
                                   tensor_->type_info().name()));
  }
}

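// Shares dims, dtype, layout and LoD from meta_tensor in one call.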
void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_) ||
      phi::SelectedRows::classof(tensor_) ||
      phi::SparseCooTensor::classof(tensor_) ||
      phi::SparseCsrTensor::classof(tensor_)
#ifdef PADDLE_WITH_DISTRIBUTE
      || phi::distributed::DistTensor::classof(tensor_)
#endif
  ) {
    share_dims(meta_tensor);
    set_dtype(meta_tensor.dtype());
    set_layout(meta_tensor.layout());
    share_lod(meta_tensor);
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
  }
}

TensorBase* MetaTensor::tensor() const { return tensor_; }

bool MetaTensor::is_dense() const { return DenseTensor::classof(tensor_); }
bool MetaTensor::is_selected_rows() const {
  return SelectedRows::classof(tensor_);
}
bool MetaTensor::is_tensor_array() const { return false; }

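// When the destination is SelectedRows, the input must also be SelectedRows;
// rows, height and the value tensor dims are copied as well. All other
// supported types simply delegate to set_dims().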
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
  bool is_selected_rows = phi::SelectedRows::classof(tensor_);
  bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_);
  bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_);
  bool is_dist_tensor = false;
#ifdef PADDLE_WITH_DISTRIBUTE
  is_dist_tensor = phi::distributed::DistTensor::classof(tensor_);
#endif
  if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr ||
      is_dist_tensor) {
    if (is_selected_rows) {
      const auto in_tensor_base = meta_tensor.tensor();
      PADDLE_ENFORCE_EQ(
          phi::SelectedRows::classof(in_tensor_base),
          true,
          errors::InvalidArgument("The output MetaTensor is SelectedRows, but "
                                  "the input MetaTensor is not this type."));
      auto* selected_rows_out = static_cast<SelectedRows*>(tensor_);
      auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
      selected_rows_out->set_rows(selected_rows_in->rows());
      selected_rows_out->set_height(selected_rows_in->height());
      auto meta = DenseTensorUtils::GetMutableMeta(
          static_cast<SelectedRows*>(tensor_)->mutable_value());
      meta->dims = selected_rows_in->mutable_value()->dims();
      if (!strided_kernel_used_) {
        meta->strides = meta->calc_strides(meta->dims);
      }
    } else {
      set_dims(meta_tensor.dims());
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported sharing dims for `%s`.", tensor_->type_info().name()));
  }
}

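// Strides are only shared when the destination is a DenseTensor; otherwise
// this is a no-op.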
void MetaTensor::share_strides(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    set_strides(meta_tensor.strides());
  }
}

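// A MetaTensor counts as initialized as soon as it wraps a non-null
// TensorBase, regardless of whether that tensor already holds data.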
bool MetaTensor::initialized() const { return tensor_ != nullptr; }

// Private Member Methods

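// Sparse tensors expose the LoD of their non_zero_elements(); unsupported
// types throw Unimplemented.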
const LoD& MetaTensor::lod() const {
  if (phi::DenseTensor::classof(tensor_)) {
    return static_cast<DenseTensor*>(tensor_)->lod();
  } else if (phi::SelectedRows::classof(tensor_)) {
    return static_cast<SelectedRows*>(tensor_)->value().lod();
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    return static_cast<SparseCooTensor*>(tensor_)->non_zero_elements().lod();
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    return static_cast<SparseCsrTensor*>(tensor_)->non_zero_elements().lod();
  } else {
    PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.",
                                            tensor_->type_info().name()));
  }
}

}  // namespace phi