未验证 提交 a548e70c 编写于 作者: E engineer1109 提交者: GitHub

fix custom plugin include headers error (#51013)

上级 ed511175
...@@ -912,4 +912,40 @@ template phi::dtype::complex<float> GetValue(const phi::DenseTensor* x); ...@@ -912,4 +912,40 @@ template phi::dtype::complex<float> GetValue(const phi::DenseTensor* x);
template phi::dtype::complex<double> GetValue(const phi::DenseTensor* x); template phi::dtype::complex<double> GetValue(const phi::DenseTensor* x);
template <typename T>
std::vector<T> GetVectorFromTensor(const phi::DenseTensor* x) {
std::vector<T> vec_new_data;
if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT32) {
auto* data = x->data<int>();
phi::DenseTensor cpu_attr_tensor;
if (x->place().GetType() != phi::AllocationType::CPU) {
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(x->place());
phi::Copy(*dev_ctx, *x, CPUPlace(), true, &cpu_attr_tensor);
data = cpu_attr_tensor.data<int>();
}
vec_new_data = std::vector<T>(data, data + x->numel());
} else if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT64) {
auto* data = x->data<int64_t>();
phi::DenseTensor cpu_attr_tensor;
if (x->place().GetType() != phi::AllocationType::CPU) {
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(x->place());
phi::Copy(*dev_ctx, *x, CPUPlace(), true, &cpu_attr_tensor);
data = cpu_attr_tensor.data<int64_t>();
}
// NOTE: Converting int64 to int32 may cause data overflow.
vec_new_data = std::vector<T>(data, data + x->numel());
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The dtype of Tensor must be int32 or int64, but received: %s",
phi::TransToProtoVarType(x->dtype())));
}
return vec_new_data;
}
template std::vector<int32_t> GetVectorFromTensor(const phi::DenseTensor* x);
template std::vector<int64_t> GetVectorFromTensor(const phi::DenseTensor* x);
} // namespace phi } // namespace phi
...@@ -14,17 +14,12 @@ limitations under the License. */ ...@@ -14,17 +14,12 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi { namespace phi {
class DenseTensorUtils { class DenseTensorUtils {
...@@ -149,35 +144,6 @@ inline T GetValue(const Context& dev_ctx, const DenseTensor& x) { ...@@ -149,35 +144,6 @@ inline T GetValue(const Context& dev_ctx, const DenseTensor& x) {
} }
template <typename T = int32_t> template <typename T = int32_t>
inline std::vector<T> GetVectorFromTensor(const phi::DenseTensor* x) { std::vector<T> GetVectorFromTensor(const phi::DenseTensor* x);
std::vector<T> vec_new_data;
if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT32) {
auto* data = x->data<int>();
phi::DenseTensor cpu_attr_tensor;
if (!paddle::platform::is_cpu_place(x->place())) {
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(x->place());
phi::Copy(*dev_ctx, *x, CPUPlace(), true, &cpu_attr_tensor);
data = cpu_attr_tensor.data<int>();
}
vec_new_data = std::vector<T>(data, data + x->numel());
} else if (phi::TransToProtoVarType(x->dtype()) == ProtoDataType::INT64) {
auto* data = x->data<int64_t>();
phi::DenseTensor cpu_attr_tensor;
if (!paddle::platform::is_cpu_place(x->place())) {
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(x->place());
phi::Copy(*dev_ctx, *x, CPUPlace(), true, &cpu_attr_tensor);
data = cpu_attr_tensor.data<int64_t>();
}
// NOTE: Converting int64 to int32 may cause data overflow.
vec_new_data = std::vector<T>(data, data + x->numel());
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The dtype of Tensor must be int32 or int64, but received: %s",
phi::TransToProtoVarType(x->dtype())));
}
return vec_new_data;
}
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册