未验证 提交 daa6cb92 编写于 作者: R ronnywang 提交者: GitHub

[CustomKernel] phi capi add inference support (#44268)

上级 7dc7fc4b
...@@ -40,6 +40,10 @@ get_property(phi_modules GLOBAL PROPERTY PHI_MODULES) ...@@ -40,6 +40,10 @@ get_property(phi_modules GLOBAL PROPERTY PHI_MODULES)
get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS) get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
set(utils_modules stringpiece pretty_log string_helper benchmark) set(utils_modules stringpiece pretty_log string_helper benchmark)
if(WITH_CUSTOM_DEVICE)
set(fluid_modules ${fluid_modules} phi_capi)
endif()
add_subdirectory(api) add_subdirectory(api)
# Create static inference library if needed # Create static inference library if needed
......
...@@ -55,6 +55,9 @@ set(paddle_inference_api_deps ...@@ -55,6 +55,9 @@ set(paddle_inference_api_deps
if(WITH_CRYPTO) if(WITH_CRYPTO)
list(APPEND paddle_inference_api_deps paddle_crypto) list(APPEND paddle_inference_api_deps paddle_crypto)
endif() endif()
if(WITH_CUSTOM_DEVICE)
set(paddle_inference_api_deps ${paddle_inference_api_deps} phi_capi)
endif()
cc_library( cc_library(
paddle_inference_api paddle_inference_api
......
...@@ -156,3 +156,7 @@ std::shared_ptr<framework::Cipher> MakeCipher(const std::string &config_file) { ...@@ -156,3 +156,7 @@ std::shared_ptr<framework::Cipher> MakeCipher(const std::string &config_file) {
#endif #endif
} // namespace paddle } // namespace paddle
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/capi/capi.h"
#endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*profile*; *profile*;
*phi*; *phi*;
*FLAGS_*; *FLAGS_*;
PD_*;
local: local:
*; *;
}; };
...@@ -24,7 +24,7 @@ extern "C" { ...@@ -24,7 +24,7 @@ extern "C" {
typedef struct PD_Tensor PD_Tensor; typedef struct PD_Tensor PD_Tensor;
PD_DataType PD_TensorGetDataType(const PD_Tensor *tensor, PD_Status *status); PD_DataType PD_TensorGetPDDataType(const PD_Tensor *tensor, PD_Status *status);
PD_DataLayout PD_TensorGetDataLayout(const PD_Tensor *tensor, PD_DataLayout PD_TensorGetDataLayout(const PD_Tensor *tensor,
PD_Status *status); PD_Status *status);
......
...@@ -128,7 +128,7 @@ class DenseTensor : public WrapperBase<PD_Tensor> { ...@@ -128,7 +128,7 @@ class DenseTensor : public WrapperBase<PD_Tensor> {
PD_DataType dtype() const { PD_DataType dtype() const {
C_Status status; C_Status status;
auto data_type = PD_TensorGetDataType(raw_data(), &status); auto data_type = PD_TensorGetPDDataType(raw_data(), &status);
PD_CHECK_STATUS(status); PD_CHECK_STATUS(status);
return data_type; return data_type;
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
PD_DataType PD_TensorGetDataType(const PD_Tensor* tensor, PD_Status* status) { PD_DataType PD_TensorGetPDDataType(const PD_Tensor* tensor, PD_Status* status) {
if (status) { if (status) {
if (!tensor) { if (!tensor) {
*status = C_FAILED; *status = C_FAILED;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册