meta_tensor.cc 8.1 KB
Newer Older
1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/core/meta_tensor.h"
16

17 18
#include "glog/logging.h"

19 20 21
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
J
Jack Zhou 已提交
22 23
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
24
#include "paddle/phi/core/tensor_utils.h"
25

26
namespace phi {
27

28 29 30 31 32 33
static inline void ValidCheck(const MetaTensor& meta_tensor) {
  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
                    true,
                    phi::errors::InvalidArgument(
                        "The current MetaTensor is not initialized."));
}
34

35 36 37 38
int64_t MetaTensor::numel() const {
  // Element count of the wrapped tensor; fails if nothing is wrapped.
  ValidCheck(*this);
  const TensorBase* wrapped = tensor();
  return wrapped->numel();
}
39

40 41
DDim MetaTensor::dims() const {
  ValidCheck(*this);
  // Every type except SelectedRows answers through the generic
  // TensorBase::dims(); SelectedRows is asked via GetCompleteDims() instead.
  if (!phi::SelectedRows::classof(tensor_)) {
    return tensor_->dims();
  }
  auto* selected_rows = static_cast<SelectedRows*>(tensor_);
  return selected_rows->GetCompleteDims();
}
48

49 50 51 52 53 54 55 56 57
DataType MetaTensor::dtype() const {
  // Data type reported by the wrapped tensor; fails if nothing is wrapped.
  ValidCheck(*this);
  const TensorBase* wrapped = tensor();
  return wrapped->dtype();
}

DataLayout MetaTensor::layout() const {
  // Layout reported by the wrapped tensor; fails if nothing is wrapped.
  ValidCheck(*this);
  const TensorBase* wrapped = tensor();
  return wrapped->layout();
}
58 59

void MetaTensor::set_dims(const DDim& dims) {
60
  ValidCheck(*this);
61
  if (phi::DenseTensor::classof(tensor_)) {
62 63
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims =
        dims;
J
Jack Zhou 已提交
64 65 66
  } else if (phi::StringTensor::classof(tensor_)) {
    StringTensorUtils::GetMutableMeta(static_cast<StringTensor*>(tensor_))
        ->dims = dims;
67
  } else if (phi::SelectedRows::classof(tensor_)) {
68
    static_cast<SelectedRows*>(tensor_)->set_height(dims[0]);
69 70 71 72 73 74
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dims = dims;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dims = dims;
75
  } else {
76
    PADDLE_THROW(phi::errors::Unimplemented(
77 78 79 80 81
        "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
  }
}

// Overwrites the dtype stored in the wrapped tensor's meta.
// For SelectedRows, the meta of its value tensor is updated.
// StringTensor is accepted and left untouched.
// Any other unsupported tensor type raises errors::Unimplemented.
void MetaTensor::set_dtype(DataType dtype) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set dtype for StringTensor; intentionally a no-op.
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->dtype = dtype;
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    // Sparse CSR tensors DO carry a settable dtype (the previous
    // "No need to set dtype" comment here was misplaced).
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->dtype = dtype;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting dtype for `%s`.", tensor_->type_info().name()));
  }
}

// Overwrites the layout stored in the wrapped tensor's meta.
// For SelectedRows, the meta of its value tensor is updated.
// StringTensor is accepted and left untouched.
// Any other unsupported tensor type raises errors::Unimplemented.
void MetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set layout for StringTensor; intentionally a no-op.
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->layout = layout;
  } else if (phi::SparseCooTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCooTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::SparseCsrTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<SparseCsrTensor*>(tensor_))
        ->layout = layout;
  } else {
    // Message wording matches set_dims ("setting", not "settting").
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting layout for `%s`.", tensor_->type_info().name()));
  }
}

void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
129 130
  ValidCheck(*this);
  ValidCheck(meta_tensor);
131 132 133 134
  if (phi::SparseCooTensor::classof(tensor_) ||
      phi::SparseCsrTensor::classof(tensor_)) {
    return;
  }
H
hong 已提交
135 136 137 138
  if (meta_tensor.lod().size() == 0) {
    // no need share
    return;
  }
139
  if (phi::DenseTensor::classof(tensor_)) {
140 141
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
        meta_tensor.lod();
142
  } else if (phi::SelectedRows::classof(tensor_)) {
143 144 145
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->lod = meta_tensor.lod();
146
  } else {
147
    PADDLE_THROW(
148 149
        phi::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
                                   tensor_->type_info().name()));
150 151 152
  }
}

153
void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
154
  ValidCheck(*this);
Y
YuanRisheng 已提交
155
  if (phi::DenseTensor::classof(tensor_) ||
156 157 158
      phi::SelectedRows::classof(tensor_) ||
      phi::SparseCooTensor::classof(tensor_) ||
      phi::SparseCsrTensor::classof(tensor_)) {
Y
YuanRisheng 已提交
159
    share_dims(meta_tensor);
160 161 162
    set_dtype(meta_tensor.dtype());
    set_layout(meta_tensor.layout());
    share_lod(meta_tensor);
163
  } else {
164
    PADDLE_THROW(phi::errors::Unimplemented(
165
        "Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
166 167 168
  }
}

169 170
TensorBase* MetaTensor::tensor() const {
  // The wrapped tensor pointer; nullptr exactly when initialized() is false.
  return tensor_;
}

171
bool MetaTensor::is_dense() const { return DenseTensor::classof(tensor_); }
172 173 174 175
bool MetaTensor::is_selected_rows() const {
  return SelectedRows::classof(tensor_);
}

176 177
bool MetaTensor::is_tensor_array() const {
  // This implementation never reports a tensor array.
  return false;
}

Y
YuanRisheng 已提交
178
void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
179
  ValidCheck(*this);
Y
YuanRisheng 已提交
180 181
  bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
  bool is_selected_rows = phi::SelectedRows::classof(tensor_);
182 183 184
  bool is_sparse_coo = phi::SparseCooTensor::classof(tensor_);
  bool is_sparse_csr = phi::SparseCsrTensor::classof(tensor_);
  if (is_dense_tensor || is_selected_rows || is_sparse_coo || is_sparse_csr) {
Y
YuanRisheng 已提交
185
    if (is_selected_rows) {
C
Chen Weihang 已提交
186
      const auto in_tensor_base = meta_tensor.tensor();
Y
YuanRisheng 已提交
187 188 189 190 191 192 193 194 195
      PADDLE_ENFORCE_EQ(
          phi::SelectedRows::classof(in_tensor_base),
          true,
          errors::InvalidArgument("The input MetaTensor is SelectedRows, but "
                                  "the output MetaTensor is not this type."));
      auto* selected_rows_out = static_cast<SelectedRows*>(tensor_);
      auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
      selected_rows_out->set_rows(selected_rows_in->rows());
      selected_rows_out->set_height(selected_rows_in->height());
196 197 198 199 200
      DenseTensorUtils::GetMutableMeta(
          static_cast<SelectedRows*>(tensor_)->mutable_value())
          ->dims = selected_rows_in->mutable_value()->dims();
    } else {
      set_dims(meta_tensor.dims());
Y
YuanRisheng 已提交
201 202 203 204 205 206 207
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported sharing dims for `%s`.", tensor_->type_info().name()));
  }
}

208 209
bool MetaTensor::initialized() const {
  // A MetaTensor counts as initialized once it wraps a non-null tensor.
  return static_cast<bool>(tensor_);
}

210 211 212 213 214 215 216
// Private Member Methods

const LoD& MetaTensor::lod() const {
  // Returns a reference to the LoD of the wrapped tensor.
  // - DenseTensor: its own LoD.
  // - SelectedRows: the LoD of value().
  // - Sparse COO/CSR: the LoD of non_zero_elements().
  // - Anything else raises errors::Unimplemented.
  // No ValidCheck here; callers in this file validate first.
  if (phi::DenseTensor::classof(tensor_)) {
    return static_cast<DenseTensor*>(tensor_)->lod();
  }
  if (phi::SelectedRows::classof(tensor_)) {
    return static_cast<SelectedRows*>(tensor_)->value().lod();
  }
  if (phi::SparseCooTensor::classof(tensor_)) {
    return static_cast<SparseCooTensor*>(tensor_)->non_zero_elements().lod();
  }
  if (phi::SparseCsrTensor::classof(tensor_)) {
    return static_cast<SparseCsrTensor*>(tensor_)->non_zero_elements().lod();
  }
  PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.",
                                          tensor_->type_info().name()));
}

227
}  // namespace phi