diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc
index 7ba4cd55183fafb6e25816c5a28ccc4db78b2d9f..65b90c88085f816030641fdcf2b8216f1503e462 100644
--- a/paddle/phi/core/tensor_utils.cc
+++ b/paddle/phi/core/tensor_utils.cc
@@ -912,4 +912,40 @@ template phi::dtype::complex<float> GetValue(const phi::DenseTensor* x);
 template phi::dtype::complex<double> GetValue(const phi::DenseTensor* x);
 
+// Reads an int32/int64 DenseTensor into a std::vector<T>, copying the tensor
+// to CPU first when it lives on another device (blocking copy).
+template <typename T>
+std::vector<T> GetVectorFromTensor(const phi::DenseTensor* x) {
+  std::vector<T> vec_new_data;
+  if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT32) {
+    auto* data = x->data<int>();
+    phi::DenseTensor cpu_attr_tensor;
+    if (x->place().GetType() != phi::AllocationType::CPU) {
+      phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+      auto dev_ctx = pool.Get(x->place());
+      phi::Copy(*dev_ctx, *x, CPUPlace(), true, &cpu_attr_tensor);
+      data = cpu_attr_tensor.data<int>();
+    }
+    vec_new_data = std::vector<T>(data, data + x->numel());
+  } else if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT64) {
+    auto* data = x->data<int64_t>();
+    phi::DenseTensor cpu_attr_tensor;
+    if (x->place().GetType() != phi::AllocationType::CPU) {
+      phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+      auto dev_ctx = pool.Get(x->place());
+      phi::Copy(*dev_ctx, *x, CPUPlace(), true, &cpu_attr_tensor);
+      data = cpu_attr_tensor.data<int64_t>();
+    }
+    // NOTE: Converting int64 to int32 may cause data overflow.
+    vec_new_data = std::vector<T>(data, data + x->numel());
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "The dtype of Tensor must be int32 or int64, but received: %s",
+        phi::TransToProtoVarType(x->dtype())));
+  }
+  return vec_new_data;
+}
+
+template std::vector<int32_t> GetVectorFromTensor(const phi::DenseTensor* x);
+
+template std::vector<int64_t> GetVectorFromTensor(const phi::DenseTensor* x);
+
 
 } // namespace phi
diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h
index 9203621931963a8aa60a5a674efa0f3faf391c01..e122cdb954640440c655f57ac5ebdcd7e5e21ca9 100644
--- a/paddle/phi/core/tensor_utils.h
+++ b/paddle/phi/core/tensor_utils.h
@@ -14,17 +14,12 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/phi/backends/all_context.h"
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/selected_rows.h"
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
 #include "paddle/phi/core/tensor_meta.h"
-#include "paddle/phi/core/utils/data_type.h"
-
 namespace phi {
 
 class DenseTensorUtils {
@@ -149,35 +144,6 @@ inline T GetValue(const Context& dev_ctx, const DenseTensor& x) {
 }
 
 template <typename T>
-inline std::vector<T> GetVectorFromTensor(const phi::DenseTensor* x) {
-  std::vector<T> vec_new_data;
-  if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT32) {
-    auto* data = x->data<int>();
-    phi::DenseTensor cpu_attr_tensor;
-    if (!paddle::platform::is_cpu_place(x->place())) {
-      phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
-      auto dev_ctx = pool.Get(x->place());
-      phi::Copy(*dev_ctx, *x, CPUPlace(), true, &cpu_attr_tensor);
-      data = cpu_attr_tensor.data<int>();
-    }
-    vec_new_data = std::vector<T>(data, data + x->numel());
-  } else if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT64) {
-    auto* data = x->data<int64_t>();
-    phi::DenseTensor cpu_attr_tensor;
-    if (!paddle::platform::is_cpu_place(x->place())) {
-      phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
-      auto dev_ctx = pool.Get(x->place());
-      phi::Copy(*dev_ctx, *x, CPUPlace(), true, &cpu_attr_tensor);
-      data = cpu_attr_tensor.data<int64_t>();
-    }
-    // NOTE: Converting int64 to int32 may cause data overflow.
-    vec_new_data = std::vector<T>(data, data + x->numel());
-  } else {
-    PADDLE_THROW(phi::errors::InvalidArgument(
-        "The dtype of Tensor must be int32 or int64, but received: %s",
-        phi::TransToProtoVarType(x->dtype())));
-  }
-  return vec_new_data;
-}
+std::vector<T> GetVectorFromTensor(const phi::DenseTensor* x);
 
 } // namespace phi