/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/meta_tensor.h"

#include "glog/logging.h"

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"

namespace phi {

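// Guard used by the accessors below: a MetaTensor that wraps a null
// TensorBase fails with a readable error instead of dereferencing nullptr.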
static inline void ValidCheck(const MetaTensor& meta_tensor) {
  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
                    true,
                    phi::errors::InvalidArgument(
                        "The current MetaTensor is not initialized."));
}

int64_t MetaTensor::numel() const {
  ValidCheck(*this);
  return tensor_->numel();
}

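// SelectedRows keeps its logical row count in height_, so report
// GetCompleteDims() rather than the raw dims of the underlying value tensor.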
DDim MetaTensor::dims() const {
  ValidCheck(*this);
  if (phi::SelectedRows::classof(tensor_)) {
    return static_cast<SelectedRows*>(tensor_)->GetCompleteDims();
  } else {
    return tensor_->dims();
  }
}

DataType MetaTensor::dtype() const {
  ValidCheck(*this);
  return tensor_->dtype();
}

DataLayout MetaTensor::layout() const {
  ValidCheck(*this);
  return tensor_->layout();
}

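// Setting dims dispatches on the concrete tensor type; SelectedRows only
// records the logical height (dims[0]), the other kinds update their meta.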
void MetaTensor::set_dims(const DDim& dims) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims =
        dims;
  } else if (phi::StringTensor::classof(tensor_)) {
    StringTensorUtils::GetMutableMeta(static_cast<StringTensor*>(tensor_))
        ->dims = dims;
  } else if (phi::SelectedRows::classof(tensor_)) {
    static_cast<SelectedRows*>(tensor_)->set_height(dims[0]);
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dims = dims;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dims = dims;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
  }
}

void MetaTensor::set_dtype(DataType dtype) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set dtype
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->dtype = dtype;
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dtype = dtype;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting dtype for `%s`.", tensor_->type_info().name()));
  }
}

void MetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set layout
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->layout = layout;
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->layout = layout;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting layout for `%s`.", tensor_->type_info().name()));
  }
}

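// Share the LoD (variable-length sequence information) of meta_tensor.
// Sparse tensors do not carry LoD, so they are skipped.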
void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  ValidCheck(meta_tensor);
  if (phi::SparseCooTensor::classof(tensor_) ||
      phi::SparseCsrTensor::classof(tensor_)) {
    return;
  }
  if (meta_tensor.lod().empty()) {
    // no LoD to share
    return;
  }
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
        meta_tensor.lod();
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->lod = meta_tensor.lod();
  } else {
    PADDLE_THROW(
        phi::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
                                   tensor_->type_info().name()));
  }
}

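// Share dims, dtype, layout and LoD from meta_tensor in one call, for the
// tensor types that support all four.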
void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_) ||
      phi::SelectedRows::classof(tensor_) ||
      phi::SparseCooTensor::classof(tensor_) ||
      phi::SparseCsrTensor::classof(tensor_)) {
    share_dims(meta_tensor);
    set_dtype(meta_tensor.dtype());
    set_layout(meta_tensor.layout());
    share_lod(meta_tensor);
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
  }
}

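// Cheap type queries on the wrapped TensorBase. MetaTensor never wraps a
// TensorArray, so is_tensor_array() is unconditionally false.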
TensorBase* MetaTensor::tensor() const { return tensor_; }

bool MetaTensor::is_dense() const { return DenseTensor::classof(tensor_); }
bool MetaTensor::is_selected_rows() const {
  return SelectedRows::classof(tensor_);
}
bool MetaTensor::is_tensor_array() const { return false; }

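// Share shape information. For SelectedRows, dims alone are not enough: the
// rows index and height are copied as well, and the value tensor's dims are
// taken directly from the input.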
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
  ValidCheck(*this);
  bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
  bool is_selected_rows = phi::SelectedRows::classof(tensor_);
  bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_);
  bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_);
  if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr) {
    if (is_selected_rows) {
      const auto in_tensor_base = meta_tensor.tensor();
      PADDLE_ENFORCE_EQ(
          phi::SelectedRows::classof(in_tensor_base),
          true,
          errors::InvalidArgument("The output MetaTensor is SelectedRows, but "
                                  "the input MetaTensor is not SelectedRows."));
      auto* selected_rows_out = static_cast<SelectedRows*>(tensor_);
      auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
      selected_rows_out->set_rows(selected_rows_in->rows());
      selected_rows_out->set_height(selected_rows_in->height());
      DenseTensorUtils::GetMutableMeta(
          static_cast<SelectedRows*>(tensor_)->mutable_value())
          ->dims = selected_rows_in->mutable_value()->dims();
    } else {
      set_dims(meta_tensor.dims());
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported sharing dims for `%s`.", tensor_->type_info().name()));
  }
}

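// "Initialized" only means a TensorBase is attached; it does not imply that
// the tensor's storage has been allocated.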
bool MetaTensor::initialized() const { return tensor_ != nullptr; }

// Private Member Methods

const LoD& MetaTensor::lod() const {
  if (phi::DenseTensor::classof(tensor_)) {
    return static_cast<DenseTensor*>(tensor_)->lod();
  } else if (phi::SelectedRows::classof(tensor_)) {
    return static_cast<SelectedRows*>(tensor_)->value().lod();
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    return static_cast<SparseCooTensor*>(tensor_)->non_zero_elements().lod();
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    return static_cast<SparseCsrTensor*>(tensor_)->non_zero_elements().lod();
  } else {
    PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.",
                                            tensor_->type_info().name()));
  }
}

}  // namespace phi