meta_tensor.cc 9.9 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/core/meta_tensor.h"
16

17 18
#include "glog/logging.h"

19
#include "paddle/phi/core/dense_tensor.h"
20
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
21 22
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
J
Jack Zhou 已提交
23 24
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
25
#include "paddle/phi/core/tensor_utils.h"
26

27
namespace phi {
28

29 30 31 32 33 34
// Enforces that `meta_tensor` is initialized, i.e. that it wraps a tensor
// (MetaTensor::initialized() in this file returns `tensor_ != nullptr`).
// Raises an InvalidArgument error otherwise. Every accessor/mutator in this
// file calls this before dereferencing the wrapped tensor.
static inline void ValidCheck(const MetaTensor& meta_tensor) {
  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
                    true,
                    phi::errors::InvalidArgument(
                        "The current MetaTensor is not initialized."));
}
35

36 37 38 39
// Returns the element count reported by the wrapped tensor.
// The MetaTensor must be initialized (enforced by ValidCheck).
int64_t MetaTensor::numel() const {
  ValidCheck(*this);
  auto* wrapped = tensor_;
  return wrapped->numel();
}
40

41 42
// Returns the dims of the wrapped tensor.
// For a SelectedRows the value comes from SelectedRows::GetCompleteDims();
// every other tensor type reports TensorBase::dims().
// The MetaTensor must be initialized (enforced by ValidCheck).
DDim MetaTensor::dims() const {
  ValidCheck(*this);
  if (phi::SelectedRows::classof(tensor_)) {
    return static_cast<SelectedRows*>(tensor_)->GetCompleteDims();
  }
  return tensor_->dims();
}
49

W
wanghuancoder 已提交
50 51 52 53 54 55 56 57
// Returns the strides of the wrapped tensor when it is (dynamically) a
// DenseTensor; any other tensor type yields a default-constructed DDim.
// The MetaTensor must be initialized (enforced by ValidCheck).
DDim MetaTensor::strides() const {
  ValidCheck(*this);
  auto* dense_tensor = dynamic_cast<DenseTensor*>(tensor_);
  return dense_tensor != nullptr ? dense_tensor->strides() : DDim();
}

58 59 60 61 62 63 64 65 66
// Returns the data type reported by the wrapped tensor.
// The MetaTensor must be initialized (enforced by ValidCheck).
DataType MetaTensor::dtype() const {
  ValidCheck(*this);
  auto* wrapped = tensor_;
  return wrapped->dtype();
}

// Returns the data layout reported by the wrapped tensor.
// The MetaTensor must be initialized (enforced by ValidCheck).
DataLayout MetaTensor::layout() const {
  ValidCheck(*this);
  auto* wrapped = tensor_;
  return wrapped->layout();
}
67 68

void MetaTensor::set_dims(const DDim& dims) {
69
  ValidCheck(*this);
70
  if (phi::DenseTensor::classof(tensor_)) {
W
wanghuancoder 已提交
71 72 73 74 75 76
    auto meta =
        DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_));
    meta->dims = dims;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(dims);
    }
J
Jack Zhou 已提交
77 78 79
  } else if (phi::StringTensor::classof(tensor_)) {
    StringTensorUtils::GetMutableMeta(static_cast<StringTensor*>(tensor_))
        ->dims = dims;
80
  } else if (phi::SelectedRows::classof(tensor_)) {
81
    static_cast<SelectedRows*>(tensor_)->set_height(dims[0]);
Z
zhangkaihuo 已提交
82 83 84 85 86 87
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dims = dims;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dims = dims;
88
  } else if (phi::distributed::DistTensor::classof(tensor_)) {
89
    static_cast<distributed::DistTensor*>(tensor_)->set_dims(dims);
90
  } else {
91
    PADDLE_THROW(phi::errors::Unimplemented(
92 93 94 95
        "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
  }
}

W
wanghuancoder 已提交
96 97 98 99 100 101 102 103
// Overwrites the meta strides of the wrapped tensor when it is (dynamically)
// a DenseTensor; other tensor types are left untouched.
// The MetaTensor must be initialized (enforced by ValidCheck).
void MetaTensor::set_strides(const DDim& strides) {
  ValidCheck(*this);
  if (auto* dense_tensor = dynamic_cast<DenseTensor*>(tensor_)) {
    auto meta = DenseTensorUtils::GetMutableMeta(dense_tensor);
    meta->strides = strides;
  }
}

104
// Sets the data type of the wrapped tensor, dispatching on its concrete type:
//  - DenseTensor, SparseCooTensor, SparseCsrTensor: updates meta dtype.
//  - SelectedRows: updates the meta dtype of its value tensor.
//  - StringTensor, DistTensor: intentionally left unchanged.
// Any other tensor type raises an Unimplemented error.
// The MetaTensor must be initialized (enforced by ValidCheck).
void MetaTensor::set_dtype(DataType dtype) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set dtype
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->dtype = dtype;
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::distributed::DistTensor::classof(tensor_)) {
    // skip, DistTensor no need to set dtype
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported settting dtype for `%s`.", tensor_->type_info().name()));
  }
}

// Sets the data layout of the wrapped tensor, dispatching on its concrete
// type:
//  - DenseTensor: updates meta layout and, unless strided kernels are in use,
//    recomputes meta strides from the current dims via calc_strides().
//  - SelectedRows: same as DenseTensor, applied to its value tensor.
//  - SparseCooTensor, SparseCsrTensor: updates meta layout only.
//  - StringTensor, DistTensor: intentionally left unchanged.
// Any other tensor type raises an Unimplemented error.
// The MetaTensor must be initialized (enforced by ValidCheck).
void MetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    auto meta =
        DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_));
    meta->layout = layout;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(meta->dims);
    }
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set layout
  } else if (phi::SelectedRows::classof(tensor_)) {
    auto meta = DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value());
    meta->layout = layout;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(meta->dims);
    }
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::distributed::DistTensor::classof(tensor_)) {
    // skip, DistTensor no need to set layout
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported settting layout for `%s`.", tensor_->type_info().name()));
  }
}

void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
162 163
  ValidCheck(*this);
  ValidCheck(meta_tensor);
Z
zhangkaihuo 已提交
164
  if (phi::SparseCooTensor::classof(tensor_) ||
165 166
      phi::SparseCsrTensor::classof(tensor_) ||
      phi::distributed::DistTensor::classof(tensor_)) {
Z
zhangkaihuo 已提交
167 168
    return;
  }
169
  if (meta_tensor.lod().empty()) {
H
hong 已提交
170 171 172
    // no need share
    return;
  }
173
  if (phi::DenseTensor::classof(tensor_)) {
174 175
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
        meta_tensor.lod();
176
  } else if (phi::SelectedRows::classof(tensor_)) {
177 178 179
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->lod = meta_tensor.lod();
180
  } else {
181
    PADDLE_THROW(
182 183
        phi::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
                                   tensor_->type_info().name()));
184 185 186
  }
}

187
void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
188
  ValidCheck(*this);
Y
YuanRisheng 已提交
189
  if (phi::DenseTensor::classof(tensor_) ||
Z
zhangkaihuo 已提交
190 191
      phi::SelectedRows::classof(tensor_) ||
      phi::SparseCooTensor::classof(tensor_) ||
192 193
      phi::SparseCsrTensor::classof(tensor_) ||
      phi::distributed::DistTensor::classof(tensor_)) {
Y
YuanRisheng 已提交
194
    share_dims(meta_tensor);
195 196 197
    set_dtype(meta_tensor.dtype());
    set_layout(meta_tensor.layout());
    share_lod(meta_tensor);
198
  } else {
199
    PADDLE_THROW(phi::errors::Unimplemented(
200
        "Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
201 202 203
  }
}

Y
YuanRisheng 已提交
204 205 206 207 208 209
TensorBase* MetaTensor::tensor() const { return tensor_; }

bool MetaTensor::is_dense() const { return DenseTensor::classof(tensor_); }
bool MetaTensor::is_selected_rows() const {
  return SelectedRows::classof(tensor_);
}
210 211 212
bool MetaTensor::is_dist() const {
  return distributed::DistTensor::classof(tensor_);
}
Y
YuanRisheng 已提交
213 214
// This implementation always reports that no tensor array is wrapped.
bool MetaTensor::is_tensor_array() const {
  constexpr bool kIsTensorArray = false;
  return kIsTensorArray;
}

215 216 217 218
// True when this MetaTensor wraps a non-null tensor and `meta_tensor` wraps
// the very same tensor object (pointer identity).
bool MetaTensor::is_same_tensor(const MetaTensor& meta_tensor) const {
  if (tensor_ == nullptr) {
    return false;
  }
  return meta_tensor.tensor() == tensor_;
}

Y
YuanRisheng 已提交
219
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
220
  ValidCheck(*this);
Y
YuanRisheng 已提交
221 222
  bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
  bool is_selected_rows = phi::SelectedRows::classof(tensor_);
Z
zhangkaihuo 已提交
223 224
  bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_);
  bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_);
225
  bool is_dist_tensor = phi::distributed::DistTensor::classof(tensor_);
226 227
  if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr ||
      is_dist_tensor) {
Y
YuanRisheng 已提交
228
    if (is_selected_rows) {
C
Chen Weihang 已提交
229
      const auto in_tensor_base = meta_tensor.tensor();
Y
YuanRisheng 已提交
230 231 232 233 234 235 236 237 238
      PADDLE_ENFORCE_EQ(
          phi::SelectedRows::classof(in_tensor_base),
          true,
          errors::InvalidArgument("The input MetaTensor is SelectedRows, but "
                                  "the output MetaTensor is not this type."));
      auto* selected_rows_out = static_cast<SelectedRows*>(tensor_);
      auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
      selected_rows_out->set_rows(selected_rows_in->rows());
      selected_rows_out->set_height(selected_rows_in->height());
W
wanghuancoder 已提交
239 240 241 242 243 244
      auto meta = DenseTensorUtils::GetMutableMeta(
          static_cast<SelectedRows*>(tensor_)->mutable_value());
      meta->dims = selected_rows_in->mutable_value()->dims();
      if (!strided_kernel_used_) {
        meta->strides = meta->calc_strides(meta->dims);
      }
245 246
    } else {
      set_dims(meta_tensor.dims());
Y
YuanRisheng 已提交
247 248 249 250 251 252 253
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported sharing dims for `%s`.", tensor_->type_info().name()));
  }
}

W
wanghuancoder 已提交
254 255 256 257 258 259 260
// Copies the strides of `meta_tensor` into this tensor. This only happens
// when this tensor is a DenseTensor; otherwise the call is a no-op after
// validation. This MetaTensor must be initialized (enforced by ValidCheck).
void MetaTensor::share_strides(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  if (!phi::DenseTensor::classof(tensor_)) {
    return;
  }
  set_strides(meta_tensor.strides());
}

261 262
// A MetaTensor counts as initialized once it wraps a non-null tensor.
bool MetaTensor::initialized() const {
  return nullptr != tensor_;
}

263 264 265 266 267 268 269
// Private Member Methods

// Returns the LoD of the wrapped tensor:
//  - DenseTensor: its own LoD.
//  - SelectedRows: the LoD of its value tensor.
//  - SparseCooTensor / SparseCsrTensor: the LoD of non_zero_elements().
// Any other tensor type raises an Unimplemented error.
// NOTE(review): no ValidCheck is done here. tensor_ is passed to classof()
// as-is. share_lod in this file validates both MetaTensors before calling
// this; confirm any other callers do the same.
const LoD& MetaTensor::lod() const {
  if (phi::DenseTensor::classof(tensor_)) {
    return static_cast<DenseTensor*>(tensor_)->lod();
  } else if (phi::SelectedRows::classof(tensor_)) {
    return static_cast<SelectedRows*>(tensor_)->value().lod();
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    return static_cast<SparseCooTensor*>(tensor_)->non_zero_elements().lod();
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    return static_cast<SparseCsrTensor*>(tensor_)->non_zero_elements().lod();
  } else {
    PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.",
                                            tensor_->type_info().name()));
  }
}

280
}  // namespace phi