tensor_utils.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/utils/tensor_utils.h"

#include <memory>
#include <utility>
#include <vector>

#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace experimental {

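// Copies LoD (level-of-detail) information element by element, allowing the
// source and destination LoD container types to differ.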
template <typename DstLoD, typename SrcLoD>
void SetLoD(DstLoD* dst, const SrcLoD& src) {
  dst->reserve(src.size());
  dst->clear();
  for (auto&& v : src) {
    dst->emplace_back(v);
  }
}

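// Creates a new phi::DenseTensor by copy-constructing it from the given
// framework::Tensor.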
std::unique_ptr<phi::DenseTensor> MakePtenDenseTensor(
    const paddle::framework::Tensor& src) {
  return std::make_unique<phi::DenseTensor>(src);
}

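// Converts a Variable holding a LoDTensor into a phi::Scalar. If the tensor
// does not reside on the expected CPU place, it is synchronously copied there
// first.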
phi::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
  auto expected_place = phi::TransToPtenPlace(phi::Backend::CPU);
  if (variable.IsType<framework::LoDTensor>()) {
    const auto& tensor = variable.Get<framework::LoDTensor>();
    if (!platform::is_same_place(tensor.place(), expected_place)) {
      framework::LoDTensor tmp_tensor;
      framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
      return {tmp_tensor};
    } else {
      return {tensor};
    }
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupport casting input `%s` type to Scalar when call pt "
        "kernel.",
        framework::ToTypeName(variable.Type())));
  }
}

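// Wraps a framework::Tensor into a phi::ScalarArray.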
phi::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src) {
  return {src};
}

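// Converts a Variable holding a LoDTensor into a phi::ScalarArray, copying
// the tensor to the expected CPU place first when necessary.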
phi::ScalarArray MakePtenScalarArrayFromVar(
    const framework::Variable& variable) {
  auto expected_place = phi::TransToPtenPlace(phi::Backend::CPU);
  if (variable.IsType<framework::LoDTensor>()) {
    const auto& tensor = variable.Get<framework::LoDTensor>();
    if (!platform::is_same_place(tensor.place(), expected_place)) {
      framework::LoDTensor tmp_tensor;
      framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
      return MakePtenScalarArray(tmp_tensor);
    } else {
      return MakePtenScalarArray(tensor);
    }
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupport casting input `%s` type to ScalarArray when call pt "
        "kernel.",
        framework::ToTypeName(variable.Type())));
  }
}

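// Builds a phi::ScalarArray from a list of Variables, each of which is
// expected to hold a single-element int32 or int64 LoDTensor. Values are
// synchronously copied to CPU when they live on another place.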
// TODO(chentianyu03): Inplace with ScalarArray constructor
phi::ScalarArray MakePtenScalarArrayFromVarList(
    const std::vector<framework::Variable*>& variable_list) {
  if (variable_list.size() == 0) {
    return phi::ScalarArray();
  }
  auto expected_place = phi::TransToPtenPlace(phi::Backend::CPU);

  std::vector<int64_t> vector_data;
  vector_data.reserve(variable_list.size());

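  // Read the scalar value of each Variable, copying it to the expected CPU
  // place first if the tensor is initialized on another place.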
  for (auto* var : variable_list) {
    paddle::experimental::DataType data_type;
    if (var->IsType<framework::LoDTensor>()) {
      const auto& tensor = var->Get<framework::LoDTensor>();
      data_type = tensor.dtype();
      if (data_type == paddle::experimental::DataType::INT64) {
        const auto& tensor = var->Get<framework::LoDTensor>();
        if (tensor.IsInitialized() &&
            !platform::is_same_place(tensor.place(), expected_place)) {
          framework::LoDTensor tmp_tensor;
          framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
          vector_data.push_back(*tmp_tensor.data<int64_t>());
        } else {
          vector_data.push_back(*tensor.data<int64_t>());
        }
      } else if (data_type == paddle::experimental::DataType::INT32) {
        const auto& tensor = var->Get<framework::LoDTensor>();
        if (tensor.IsInitialized() &&
            !platform::is_same_place(tensor.place(), expected_place)) {
          framework::LoDTensor tmp_tensor;
          framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
          vector_data.push_back(*tmp_tensor.data<int32_t>());
        } else {
          vector_data.push_back(*tensor.data<int32_t>());
        }
      } else {
        PADDLE_THROW(phi::errors::InvalidArgument(
            "Data type error. When cast a LoDTensor to VectorTensor, "
            "the data type of LoDTensor must be int32 or int64, "
            "but now data type is %s.",
            data_type));
      }
    } else {
      PADDLE_THROW(phi::errors::Unimplemented(
          "Unsupport casting input `%s` type to VectorTensor when call pt "
          "kernel.",
          framework::ToTypeName(var->Type())));
    }
  }

  phi::ScalarArray result{vector_data};
  result.setInitByTensor(true);

  return result;
}

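// Resets the dtype and layout of a DenseTensor or SelectedRows output to the
// values declared in the kernel's TensorArgDef.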
void ResetTensorDtypeAndLayoutByArgDef(phi::TensorBase* dst,
                                       const phi::TensorArgDef& arg_def) {
  VLOG(5) << "ResetTensor by TensorArgDef.";
  if (phi::DenseTensor::classof(dst)) {
    auto* dense_t = static_cast<phi::DenseTensor*>(dst);
    auto* meta = phi::DenseTensorUtils::GetMutableMeta(dense_t);
    meta->dtype = arg_def.dtype;
    meta->layout = arg_def.layout;
  } else if (phi::SelectedRows::classof(dst)) {
    auto* selected_rows = static_cast<phi::SelectedRows*>(dst);
    auto* meta =
        phi::DenseTensorUtils::GetMutableMeta(selected_rows->mutable_value());
    meta->dtype = arg_def.dtype;
    meta->layout = arg_def.layout;
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Unsupported tensor type is received when reseting tensor dtype and "
        "layout by argument definition."));
  }
}

}  // namespace experimental
}  // namespace paddle