tensor_utils.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/api/lib/utils/tensor_utils.h"

#include <memory>
#include <utility>
#include <vector>

#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace experimental {

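// Copies level-of-detail (LoD) information from `src` into `dst`,
// clearing any existing levels in `dst` first.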
template <typename DstLoD, typename SrcLoD>
void SetLoD(DstLoD* dst, const SrcLoD& src) {
  dst->clear();
  dst->reserve(src.size());
  for (auto&& v : src) {
    dst->emplace_back(v);
  }
}

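// Wraps `src` in a std::unique_ptr by copy-constructing a new phi::DenseTensor.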
std::unique_ptr<phi::DenseTensor> MakePhiDenseTensor(
    const phi::DenseTensor& src) {
  return std::make_unique<phi::DenseTensor>(src);
}

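// Builds a phi::Scalar from a Variable that holds a single-element LoDTensor,
// synchronously copying the tensor to CPU first if it lives elsewhere.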
phi::Scalar MakePhiScalarFromVar(const framework::Variable& variable) {
  auto expected_place = phi::TransToPhiPlace(phi::Backend::CPU);
  if (variable.IsType<framework::LoDTensor>()) {
    const auto& tensor = variable.Get<framework::LoDTensor>();
    PADDLE_ENFORCE_EQ(
        tensor.numel(),
        1UL,
        platform::errors::InvalidArgument("The DenseTensor used to construct "
                                          "the Scalar contains more than 1 "
                                          "value, it contains `%d` values.",
                                          tensor.numel()));
    if (!platform::is_same_place(tensor.place(), expected_place)) {
      framework::LoDTensor tmp_tensor;
      framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
      return {tmp_tensor};
    } else {
      return {tensor};
    }
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupport casting input `%s` type to Scalar when call pt "
        "kernel.",
        framework::ToTypeName(variable.Type())));
  }
}

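// Wraps a DenseTensor as a phi::IntArray.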
phi::IntArray MakePhiIntArray(const phi::DenseTensor& src) { return {src}; }

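// Builds a phi::IntArray from a Variable that holds a LoDTensor.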
phi::IntArray MakePhiIntArrayFromVar(const framework::Variable& variable) {
  if (variable.IsType<framework::LoDTensor>()) {
    const auto& tensor = variable.Get<framework::LoDTensor>();
    return MakePhiIntArray(tensor);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported casting of input `%s` type to IntArray when calling pt "
        "kernel.",
        framework::ToTypeName(variable.Type())));
  }
}

// TODO(chentianyu03): Inplace with IntArray constructor
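// Builds a phi::IntArray by reading the first value of each variable's
// int32/int64 LoDTensor.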
phi::IntArray MakePhiIntArrayFromVarList(
    const std::vector<framework::Variable*>& variable_list) {
  if (variable_list.size() == 0) {
    return phi::IntArray();
  }
  auto expected_place = phi::TransToPhiPlace(phi::Backend::CPU);

  std::vector<int64_t> vector_data;
  vector_data.reserve(variable_list.size());

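  // Gather one value per variable; tensors on non-CPU places are first
  // copied to CPU synchronously so their data can be read on the host.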
  for (auto* var : variable_list) {
    paddle::experimental::DataType data_type;
    if (var->IsType<framework::LoDTensor>()) {
      const auto& tensor = var->Get<framework::LoDTensor>();
      data_type = tensor.dtype();
      if (data_type == paddle::experimental::DataType::INT64) {
        const auto& tensor = var->Get<framework::LoDTensor>();
        if (tensor.IsInitialized() &&
            !platform::is_same_place(tensor.place(), expected_place)) {
          framework::LoDTensor tmp_tensor;
          framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
          vector_data.push_back(*tmp_tensor.data<int64_t>());
        } else {
          vector_data.push_back(*tensor.data<int64_t>());
        }
      } else if (data_type == paddle::experimental::DataType::INT32) {
        const auto& tensor = var->Get<framework::LoDTensor>();
        if (tensor.IsInitialized() &&
            !platform::is_same_place(tensor.place(), expected_place)) {
          framework::LoDTensor tmp_tensor;
          framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
          vector_data.push_back(*tmp_tensor.data<int32_t>());
        } else {
          vector_data.push_back(*tensor.data<int32_t>());
        }
      } else {
        PADDLE_THROW(phi::errors::InvalidArgument(
            "Data type error. When cast a LoDTensor to VectorTensor, "
            "the data type of LoDTensor must be int32 or int64, "
            "but now data type is %s.",
            data_type));
      }
    } else {
      PADDLE_THROW(phi::errors::Unimplemented(
          "Unsupported casting of input `%s` type to VectorTensor when "
          "calling pt kernel.",
          framework::ToTypeName(var->Type())));
    }
  }

  phi::IntArray result{vector_data};
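  // Record that this IntArray was constructed from tensor values.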
  result.SetFromTensor(true);

  return result;
}

}  // namespace experimental
}  // namespace paddle