tensor_utils.cc 4.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/api/lib/utils/tensor_utils.h"
16

17
#include <utility>
18 19
#include <vector>

20
#include "paddle/phi/core/tensor_utils.h"
21

22 23 24 25 26 27 28 29 30 31 32 33
namespace paddle {
namespace experimental {

// Replaces the contents of `*dst` with the elements of `src`, constructing
// each destination element in place from the corresponding source element.
// DstLoD only needs clear/reserve/emplace_back and SrcLoD only needs size()
// plus iteration, so the two may be different container types.
template <typename DstLoD, typename SrcLoD>
void SetLoD(DstLoD* dst, const SrcLoD& src) {
  // Clear before reserving: reserving first could reallocate and relocate the
  // stale elements only to destroy them immediately afterwards.
  dst->clear();
  dst->reserve(src.size());
  for (auto&& v : src) {
    dst->emplace_back(v);
  }
}

34
std::unique_ptr<phi::DenseTensor> MakePhiDenseTensor(
35
    const paddle::framework::Tensor& src) {
36
  return std::make_unique<phi::DenseTensor>(src);
37 38
}

39 40
// Builds a phi::Scalar from a framework::Variable that holds a LoDTensor.
// A tensor whose place differs from the CPU place is first copied
// synchronously to CPU, and the Scalar is built from that copy.
// Throws Unimplemented when the variable does not hold a LoDTensor.
phi::Scalar MakePhiScalarFromVar(const framework::Variable& variable) {
  auto cpu_place = phi::TransToPhiPlace(phi::Backend::CPU);
  if (!variable.IsType<framework::LoDTensor>()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupport casting input `%s` type to Scalar when call pt "
        "kernel.",
        framework::ToTypeName(variable.Type())));
  }

  const auto& src_tensor = variable.Get<framework::LoDTensor>();
  if (platform::is_same_place(src_tensor.place(), cpu_place)) {
    return {src_tensor};
  }

  // Non-CPU tensor: stage a CPU copy and build the Scalar from it.
  framework::LoDTensor cpu_tensor;
  framework::TensorCopySync(src_tensor, cpu_place, &cpu_tensor);
  return {cpu_tensor};
}

58
// Builds a phi::ScalarArray directly from `src`. No place conversion happens
// here; MakePhiScalarArrayFromVar copies non-CPU tensors before calling this.
phi::ScalarArray MakePhiScalarArray(const paddle::framework::Tensor& src) {
  return phi::ScalarArray(src);
}

62
// Builds a phi::ScalarArray from a framework::Variable that holds a LoDTensor,
// delegating the conversion to MakePhiScalarArray. A tensor whose place
// differs from the CPU place is first copied synchronously to CPU, and the
// copy is converted instead.
// Throws Unimplemented when the variable does not hold a LoDTensor.
phi::ScalarArray MakePhiScalarArrayFromVar(
    const framework::Variable& variable) {
  auto cpu_place = phi::TransToPhiPlace(phi::Backend::CPU);
  if (!variable.IsType<framework::LoDTensor>()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupport casting input `%s` type to ScalarArray when call pt "
        "kernel.",
        framework::ToTypeName(variable.Type())));
  }

  const auto& src_tensor = variable.Get<framework::LoDTensor>();
  if (platform::is_same_place(src_tensor.place(), cpu_place)) {
    return MakePhiScalarArray(src_tensor);
  }

  // Non-CPU tensor: stage a CPU copy and convert that.
  framework::LoDTensor cpu_tensor;
  framework::TensorCopySync(src_tensor, cpu_place, &cpu_tensor);
  return MakePhiScalarArray(cpu_tensor);
}

C
chentianyu03 已提交
82
// TODO(chentianyu03): Inplace with ScalarArray constructor
83
phi::ScalarArray MakePhiScalarArrayFromVarList(
84 85
    const std::vector<framework::Variable*>& variable_list) {
  if (variable_list.size() == 0) {
86
    return phi::ScalarArray();
87
  }
88
  auto expected_place = phi::TransToPhiPlace(phi::Backend::CPU);
89 90 91 92

  std::vector<int64_t> vector_data;
  vector_data.reserve(variable_list.size());

C
chentianyu03 已提交
93
  for (auto* var : variable_list) {
94
    paddle::experimental::DataType data_type;
C
chentianyu03 已提交
95 96
    if (var->IsType<framework::LoDTensor>()) {
      const auto& tensor = var->Get<framework::LoDTensor>();
97 98
      data_type = tensor.dtype();
      if (data_type == paddle::experimental::DataType::INT64) {
99
        const auto& tensor = var->Get<framework::LoDTensor>();
C
chentianyu03 已提交
100 101
        if (tensor.IsInitialized() &&
            !platform::is_same_place(tensor.place(), expected_place)) {
102 103 104 105 106 107
          framework::LoDTensor tmp_tensor;
          framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
          vector_data.push_back(*tmp_tensor.data<int64_t>());
        } else {
          vector_data.push_back(*tensor.data<int64_t>());
        }
108
      } else if (data_type == paddle::experimental::DataType::INT32) {
109
        const auto& tensor = var->Get<framework::LoDTensor>();
C
chentianyu03 已提交
110 111
        if (tensor.IsInitialized() &&
            !platform::is_same_place(tensor.place(), expected_place)) {
112 113 114 115 116 117 118
          framework::LoDTensor tmp_tensor;
          framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
          vector_data.push_back(*tmp_tensor.data<int32_t>());
        } else {
          vector_data.push_back(*tensor.data<int32_t>());
        }
      } else {
119
        PADDLE_THROW(phi::errors::InvalidArgument(
C
chentianyu03 已提交
120 121 122 123
            "Data type error. When cast a LoDTensor to VectorTensor, "
            "the data type of LoDTensor must be int32 or int64, "
            "but now data type is %s.",
            data_type));
124
      }
C
chentianyu03 已提交
125
    } else {
126
      PADDLE_THROW(phi::errors::Unimplemented(
C
chentianyu03 已提交
127 128 129
          "Unsupport casting input `%s` type to VectorTensor when call pt "
          "kernel.",
          framework::ToTypeName(var->Type())));
130 131 132
    }
  }

133
  phi::ScalarArray result{vector_data};
134
  result.SetFromTensor(true);
C
chentianyu03 已提交
135 136

  return result;
137 138
}

139 140
}  // namespace experimental
}  // namespace paddle