/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/meta_tensor.h"
#include "glog/logging.h"

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/string_tensor.h"
#include "paddle/phi/core/string_tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
// Enforces that `meta_tensor` wraps an underlying tensor, i.e. that
// MetaTensor::initialized() returns true; otherwise an InvalidArgument error
// is raised through PADDLE_ENFORCE_EQ. The public accessors and setters in
// this file call it before dereferencing `tensor_`.
static inline void ValidCheck(const MetaTensor& meta_tensor) {
  PADDLE_ENFORCE_EQ(meta_tensor.initialized(),
                    true,
                    phi::errors::InvalidArgument(
                        "The current MetaTensor is not initialized."));
}
// Number of elements reported by the wrapped tensor.
// Raises InvalidArgument (via ValidCheck) when no tensor is attached.
int64_t MetaTensor::numel() const {
  ValidCheck(*this);
  const auto element_count = tensor_->numel();
  return element_count;
}
// Returns a copy of the wrapped tensor's dims.
// Raises InvalidArgument (via ValidCheck) when no tensor is attached.
DDim MetaTensor::dims() const {
  ValidCheck(*this);
  const auto& shape = tensor_->dims();
  return shape;
}
// Data type recorded on the wrapped tensor.
// Raises InvalidArgument (via ValidCheck) when no tensor is attached.
DataType MetaTensor::dtype() const {
  ValidCheck(*this);
  const auto data_type = tensor_->dtype();
  return data_type;
}

// Data layout recorded on the wrapped tensor.
// Raises InvalidArgument (via ValidCheck) when no tensor is attached.
DataLayout MetaTensor::layout() const {
  ValidCheck(*this);
  const auto data_layout = tensor_->layout();
  return data_layout;
}

void MetaTensor::set_dims(const DDim& dims) {
56
  ValidCheck(*this);
57
  if (phi::DenseTensor::classof(tensor_)) {
58 59
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims =
        dims;
J
Jack Zhou 已提交
60 61 62
  } else if (phi::StringTensor::classof(tensor_)) {
    StringTensorUtils::GetMutableMeta(static_cast<StringTensor*>(tensor_))
        ->dims = dims;
63
  } else if (phi::SelectedRows::classof(tensor_)) {
64 65 66
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->dims = dims;
67
  } else {
68
    PADDLE_THROW(phi::errors::Unimplemented(
69 70 71 72 73
        "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
  }
}

// Overwrites the data type recorded in the wrapped tensor's meta.
//   - DenseTensor:  writes the DenseTensor meta's `dtype`.
//   - StringTensor: intentionally a no-op.
//   - SelectedRows: writes the `dtype` of the tensor returned by
//                   mutable_value().
// Any other tensor type raises an Unimplemented error. Raises InvalidArgument
// (via ValidCheck) when no tensor is attached.
void MetaTensor::set_dtype(DataType dtype) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->dtype = dtype;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set dtype
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->dtype = dtype;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting dtype for `%s`.", tensor_->type_info().name()));
  }
}

// Overwrites the data layout recorded in the wrapped tensor's meta.
//   - DenseTensor:  writes the DenseTensor meta's `layout`.
//   - StringTensor: intentionally a no-op.
//   - SelectedRows: writes the `layout` of the tensor returned by
//                   mutable_value().
// Any other tensor type raises an Unimplemented error. Raises InvalidArgument
// (via ValidCheck) when no tensor is attached.
void MetaTensor::set_layout(DataLayout layout) {
  ValidCheck(*this);
  if (phi::DenseTensor::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
        ->layout = layout;
  } else if (phi::StringTensor::classof(tensor_)) {
    // No need to set layout
  } else if (phi::SelectedRows::classof(tensor_)) {
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->layout = layout;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported setting layout for `%s`.", tensor_->type_info().name()));
  }
}

void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
108 109
  ValidCheck(*this);
  ValidCheck(meta_tensor);
H
hong 已提交
110 111 112 113
  if (meta_tensor.lod().size() == 0) {
    // no need share
    return;
  }
114
  if (phi::DenseTensor::classof(tensor_)) {
115 116
    DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
        meta_tensor.lod();
117
  } else if (phi::SelectedRows::classof(tensor_)) {
118 119 120
    DenseTensorUtils::GetMutableMeta(
        static_cast<SelectedRows*>(tensor_)->mutable_value())
        ->lod = meta_tensor.lod();
121
  } else {
122
    PADDLE_THROW(
123 124
        phi::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
                                   tensor_->type_info().name()));
125 126 127
  }
}

void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
129
  ValidCheck(*this);
Y
YuanRisheng 已提交
130 131 132
  if (phi::DenseTensor::classof(tensor_) ||
      phi::SelectedRows::classof(tensor_)) {
    share_dims(meta_tensor);
133 134 135
    set_dtype(meta_tensor.dtype());
    set_layout(meta_tensor.layout());
    share_lod(meta_tensor);
136
  } else {
137
    PADDLE_THROW(phi::errors::Unimplemented(
138
        "Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
139 140 141
  }
}

void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
143
  ValidCheck(*this);
Y
YuanRisheng 已提交
144 145 146 147 148
  bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
  bool is_selected_rows = phi::SelectedRows::classof(tensor_);
  if (is_dense_tensor || is_selected_rows) {
    set_dims(meta_tensor.dims());
    if (is_selected_rows) {
C
Chen Weihang 已提交
149
      const auto in_tensor_base = meta_tensor.tensor();
Y
YuanRisheng 已提交
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
      PADDLE_ENFORCE_EQ(
          phi::SelectedRows::classof(in_tensor_base),
          true,
          errors::InvalidArgument("The input MetaTensor is SelectedRows, but "
                                  "the output MetaTensor is not this type."));
      auto* selected_rows_out = static_cast<SelectedRows*>(tensor_);
      auto* selected_rows_in = static_cast<SelectedRows*>(in_tensor_base);
      selected_rows_out->set_rows(selected_rows_in->rows());
      selected_rows_out->set_height(selected_rows_in->height());
    }
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported sharing dims for `%s`.", tensor_->type_info().name()));
  }
}

// A MetaTensor counts as initialized exactly when it holds a non-null
// `tensor_` pointer.
bool MetaTensor::initialized() const {
  const bool has_tensor = (tensor_ != nullptr);
  return has_tensor;
}

// Private Member Methods

// Returns a reference to the LoD stored on the wrapped tensor.
// For a DenseTensor this is its own LoD. For a SelectedRows it is the LoD of
// its value() tensor. Any other tensor type raises an Unimplemented error.
// NOTE(review): unlike the public accessors, this does not call ValidCheck;
// callers in this file (share_lod) validate first.
const LoD& MetaTensor::lod() const {
  if (phi::DenseTensor::classof(tensor_)) {
    auto* dense = static_cast<DenseTensor*>(tensor_);
    return dense->lod();
  }
  if (phi::SelectedRows::classof(tensor_)) {
    auto* rows = static_cast<SelectedRows*>(tensor_);
    return rows->value().lod();
  }
  PADDLE_THROW(phi::errors::Unimplemented("Unsupported getting lod of `%s`.",
                                          tensor_->type_info().name()));
}

// Exposes the wrapped tensor pointer unchanged. No ValidCheck is performed,
// so the result may be nullptr.
TensorBase* MetaTensor::tensor() const {
  TensorBase* wrapped = tensor_;
  return wrapped;
}

}  // namespace phi