meta_tensor.cc 9.1 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/core/meta_tensor.h"
16

17 18
#include "glog/logging.h"

19 20 21
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
J
Jack Zhou 已提交
22 23
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
24
#include "paddle/phi/core/tensor_utils.h"
25

26
namespace phi {
27

28 29 30 31 32 33
// Raises InvalidArgument unless `meta_tensor` wraps a tensor
// (MetaTensor::initialized(), i.e. its tensor_ pointer is non-null).
// Every accessor/mutator below calls this first so that an empty MetaTensor
// fails with a readable error instead of dereferencing a null pointer.
static inline void ValidCheck(const MetaTensor& meta_tensor) {
  const bool has_tensor = meta_tensor.initialized();
  PADDLE_ENFORCE_EQ(has_tensor,
                    true,
                    phi::errors::InvalidArgument(
                        "The current MetaTensor is not initialized."));
}
34

35 36 37 38
// Returns the number of elements reported by the wrapped tensor.
int64_t MetaTensor::numel() const {
  ValidCheck(*this);
  const int64_t element_count = tensor_->numel();
  return element_count;
}
39

40 41
// Returns the shape of the wrapped tensor. A SelectedRows reports its
// "complete" shape via GetCompleteDims(); every other kind reports dims().
DDim MetaTensor::dims() const {
  ValidCheck(*this);
  if (!phi::SelectedRows::classof(tensor_)) {
    return tensor_->dims();
  }
  return static_cast<SelectedRows*>(tensor_)->GetCompleteDims();
}
48

W
wanghuancoder 已提交
49 50 51 52 53 54 55 56
// Returns the strides of the wrapped tensor when it is a DenseTensor (or any
// type that dynamic_cast accepts as one). Every other tensor kind reports an
// empty DDim.
DDim MetaTensor::strides() const {
  ValidCheck(*this);
  // Perform the RTTI-based cast once and reuse the result, instead of
  // casting a second time just to call strides().
  if (auto* dense = dynamic_cast<DenseTensor*>(tensor_)) {
    return dense->strides();
  }
  return DDim();
}

57 58 59 60 61 62 63 64 65
// Returns the element data type of the wrapped tensor.
DataType MetaTensor::dtype() const {
  ValidCheck(*this);
  const DataType element_type = tensor_->dtype();
  return element_type;
}

// Returns the memory layout of the wrapped tensor.
DataLayout MetaTensor::layout() const {
  ValidCheck(*this);
  const DataLayout current_layout = tensor_->layout();
  return current_layout;
}
66 67

void MetaTensor::set_dims(const DDim& dims) {
68
  ValidCheck(*this);
69
  if (phi::DenseTensor::classof(tensor_)) {
W
wanghuancoder 已提交
70 71 72 73 74 75
    auto meta =
        DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_));
    meta->dims = dims;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(dims);
    }
J
Jack Zhou 已提交
76 77 78
  } else if (phi::StringTensor::classof(tensor_)) {
    StringTensorUtils::GetMutableMeta(static_cast<StringTensor*>(tensor_))
        ->dims = dims;
79
  } else if (phi::SelectedRows::classof(tensor_)) {
80
    static_cast<SelectedRows*>(tensor_)->set_height(dims[0]);
Z
zhangkaihuo 已提交
81 82 83 84 85 86
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dims = dims;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dims = dims;
87
  } else {
88
    PADDLE_THROW(phi::errors::Unimplemented(
89 90 91 92
        "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
  }
}

W
wanghuancoder 已提交
93 94 95 96 97 98 99 100
// Overwrites the strides of the wrapped tensor when it is a DenseTensor (or
// any type that dynamic_cast accepts as one). Other tensor kinds are left
// untouched; no error is raised for them.
void MetaTensor::set_strides(const DDim& strides) {
  ValidCheck(*this);
  // Cast once and reuse the result rather than repeating the RTTI lookup.
  if (auto* dense = dynamic_cast<DenseTensor*>(tensor_)) {
    DenseTensorUtils::GetMutableMeta(dense)->strides = strides;
  }
}

101
// Sets the element data type of the wrapped tensor.
//  - DenseTensor / SparseCooTensor / SparseCsrTensor: written to the meta.
//  - SelectedRows: written to the meta of its value tensor (mutable_value()).
//  - StringTensor: intentionally a no-op.
// Any other tensor kind raises Unimplemented.
void MetaTensor::set_dtype(DataType dtype) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set dtype for StringTensor.
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->dtype = dtype;
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dtype = dtype;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting dtype for `%s`.", tensor_->type_info().name()));
  }
}

// Sets the memory layout of the wrapped tensor.
//  - DenseTensor, and SelectedRows (via its value tensor): the layout is
//    written and, unless strided_kernel_used_ is set, strides are recomputed
//    from the tensor's current dims via calc_strides().
//  - SparseCooTensor / SparseCsrTensor: the layout is written; strides are
//    not touched.
//  - StringTensor: intentionally a no-op.
// Any other tensor kind raises Unimplemented.
void MetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    auto meta =
        DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_));
    meta->layout = layout;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(meta->dims);
    }
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set layout for StringTensor.
  } else if (phi::SelectedRows::classof(tensor_)) {
    auto meta = DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value());
    meta->layout = layout;
    if (!strided_kernel_used_) {
      meta->strides = meta->calc_strides(meta->dims);
    }
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->layout = layout;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting layout for `%s`.", tensor_->type_info().name()));
  }
}

void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
156 157
  ValidCheck(*this);
  ValidCheck(meta_tensor);
Z
zhangkaihuo 已提交
158 159 160 161
  if (phi::SparseCooTensor::classof(tensor_) ||
      phi::SparseCsrTensor::classof(tensor_)) {
    return;
  }
162
  if (meta_tensor.lod().empty()) {
H
hong 已提交
163 164 165
    // no need share
    return;
  }
166
  if (phi::DenseTensor::classof(tensor_)) {
167 168
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
        meta_tensor.lod();
169
  } else if (phi::SelectedRows::classof(tensor_)) {
170 171 172
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->lod = meta_tensor.lod();
173
  } else {
174
    PADDLE_THROW(
175 176
        phi::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
                                   tensor_->type_info().name()));
177 178 179
  }
}

180
void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
181
  ValidCheck(*this);
Y
YuanRisheng 已提交
182
  if (phi::DenseTensor::classof(tensor_) ||
Z
zhangkaihuo 已提交
183 184 185
      phi::SelectedRows::classof(tensor_) ||
      phi::SparseCooTensor::classof(tensor_) ||
      phi::SparseCsrTensor::classof(tensor_)) {
Y
YuanRisheng 已提交
186
    share_dims(meta_tensor);
187 188 189
    set_dtype(meta_tensor.dtype());
    set_layout(meta_tensor.layout());
    share_lod(meta_tensor);
190
  } else {
191
    PADDLE_THROW(phi::errors::Unimplemented(
192
        "Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
193 194 195
  }
}

Y
YuanRisheng 已提交
196 197 198 199 200 201 202 203
// Raw access to the wrapped tensor; may be nullptr when the MetaTensor is
// not initialized.
TensorBase* MetaTensor::tensor() const {
  return tensor_;
}

// True when the wrapped tensor is exactly a DenseTensor (per classof).
bool MetaTensor::is_dense() const {
  return DenseTensor::classof(tensor_);
}

// True when the wrapped tensor is exactly a SelectedRows (per classof).
bool MetaTensor::is_selected_rows() const {
  return SelectedRows::classof(tensor_);
}

// This implementation always reports false.
bool MetaTensor::is_tensor_array() const {
  return false;
}

Y
YuanRisheng 已提交
204
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
205
  ValidCheck(*this);
Y
YuanRisheng 已提交
206 207
  bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
  bool is_selected_rows = phi::SelectedRows::classof(tensor_);
Z
zhangkaihuo 已提交
208 209 210
  bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_);
  bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_);
  if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr) {
Y
YuanRisheng 已提交
211
    if (is_selected_rows) {
C
Chen Weihang 已提交
212
      const auto in_tensor_base = meta_tensor.tensor();
Y
YuanRisheng 已提交
213 214 215 216 217 218 219 220 221
      PADDLE_ENFORCE_EQ(
          phi::SelectedRows::classof(in_tensor_base),
          true,
          errors::InvalidArgument("The input MetaTensor is SelectedRows, but "
                                  "the output MetaTensor is not this type."));
      auto* selected_rows_out = static_cast<SelectedRows*>(tensor_);
      auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
      selected_rows_out->set_rows(selected_rows_in->rows());
      selected_rows_out->set_height(selected_rows_in->height());
W
wanghuancoder 已提交
222 223 224 225 226 227
      auto meta = DenseTensorUtils::GetMutableMeta(
          static_cast<SelectedRows*>(tensor_)->mutable_value());
      meta->dims = selected_rows_in->mutable_value()->dims();
      if (!strided_kernel_used_) {
        meta->strides = meta->calc_strides(meta->dims);
      }
228 229
    } else {
      set_dims(meta_tensor.dims());
Y
YuanRisheng 已提交
230 231 232 233 234 235 236
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported sharing dims for `%s`.", tensor_->type_info().name()));
  }
}

W
wanghuancoder 已提交
237 238 239 240 241 242 243
// Copies the strides of `meta_tensor` into the wrapped tensor. Only a
// DenseTensor output receives them; other output kinds are left unchanged.
void MetaTensor::share_strides(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  if (!phi::DenseTensor::classof(tensor_)) {
    return;
  }
  set_strides(meta_tensor.strides());
}

244 245
// True when this MetaTensor wraps a tensor, i.e. tensor_ is non-null.
bool MetaTensor::initialized() const {
  return nullptr != tensor_;
}

246 247 248 249 250 251 252
// Private Member Methods

// Returns the LoD associated with the wrapped tensor:
//  - DenseTensor: its own lod();
//  - SelectedRows: the lod() of value();
//  - SparseCooTensor / SparseCsrTensor: the lod() of non_zero_elements().
// Any other tensor kind raises Unimplemented.
// NOTE(review): no ValidCheck here; the caller visible in this file
// (share_lod) validates before calling — confirm for any other callers.
const LoD& MetaTensor::lod() const {
  if (phi::DenseTensor::classof(tensor_)) {
    return static_cast<DenseTensor*>(tensor_)->lod();
  }
  if (phi::SelectedRows::classof(tensor_)) {
    return static_cast<SelectedRows*>(tensor_)->value().lod();
  }
  if (phi::SparseCooTensor::classof(tensor_)) {
    return static_cast<SparseCooTensor*>(tensor_)->non_zero_elements().lod();
  }
  if (phi::SparseCsrTensor::classof(tensor_)) {
    return static_cast<SparseCsrTensor*>(tensor_)->non_zero_elements().lod();
  }
  PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.",
                                          tensor_->type_info().name()));
}

263
}  // namespace phi