未验证 提交 08b43cce 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Support kps backend and kernel registry (#39941)

* support kps backend and compile

* resolve conflict

* fix kps backend trans

* test in xpu2 device

* remove dummy kernel
上级 d4911594
...@@ -667,6 +667,7 @@ function(xpu_library TARGET_NAME) ...@@ -667,6 +667,7 @@ function(xpu_library TARGET_NAME)
else() else()
xpu_add_library(${TARGET_NAME} STATIC ${xpu_library_SRCS} DEPENDS ${xpu_library_DEPS}) xpu_add_library(${TARGET_NAME} STATIC ${xpu_library_SRCS} DEPENDS ${xpu_library_DEPS})
find_fluid_modules(${TARGET_NAME}) find_fluid_modules(${TARGET_NAME})
find_phi_modules(${TARGET_NAME})
endif() endif()
if (xpu_library_DEPS) if (xpu_library_DEPS)
add_dependencies(${TARGET_NAME} ${xpu_library_DEPS}) add_dependencies(${TARGET_NAME} ${xpu_library_DEPS})
......
...@@ -83,6 +83,8 @@ function(kernel_declare TARGET_LIST) ...@@ -83,6 +83,8 @@ function(kernel_declare TARGET_LIST)
file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n") file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n")
elseif (${kernel_path} MATCHES "./gpudnn\/") elseif (${kernel_path} MATCHES "./gpudnn\/")
file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPUDNN, ALL_LAYOUT);\n") file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, GPUDNN, ALL_LAYOUT);\n")
elseif (${kernel_path} MATCHES "./kps\/")
file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, KPS, ALL_LAYOUT);\n")
else () else ()
# deal with device independent kernel, now we use CPU temporaary # deal with device independent kernel, now we use CPU temporaary
file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n") file(APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n")
...@@ -97,6 +99,7 @@ function(kernel_library TARGET) ...@@ -97,6 +99,7 @@ function(kernel_library TARGET)
set(gpu_srcs) set(gpu_srcs)
set(xpu_srcs) set(xpu_srcs)
set(gpudnn_srcs) set(gpudnn_srcs)
set(kps_srcs)
set(selected_rows_srcs) set(selected_rows_srcs)
# parse and save the deps kerenl targets # parse and save the deps kerenl targets
set(all_srcs) set(all_srcs)
...@@ -128,6 +131,9 @@ function(kernel_library TARGET) ...@@ -128,6 +131,9 @@ function(kernel_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc)
list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc) list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu.cc)
endif() endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/kps/${TARGET}.cu)
list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/kps/${TARGET}.cu)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu)
list(APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu) list(APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpudnn/${TARGET}_gpudnn.cu)
endif() endif()
...@@ -137,6 +143,15 @@ function(kernel_library TARGET) ...@@ -137,6 +143,15 @@ function(kernel_library TARGET)
list(APPEND xpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/xpu/${TARGET}.cc) list(APPEND xpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/xpu/${TARGET}.cc)
endif() endif()
endif() endif()
if (WITH_XPU_KP)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/kps/${TARGET}.cu)
# Change XPU2 file suffix
# NOTE(chenweihang): If we can be sure that the *.kps suffix is no longer used, it can be copied directly to *.xpu
file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/kps/${TARGET}.cu DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/kps)
file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/kps/${TARGET}.cu ${CMAKE_CURRENT_BINARY_DIR}/kps/${TARGET}.kps)
list(APPEND kps_srcs ${CMAKE_CURRENT_BINARY_DIR}/kps/${TARGET}.kps)
endif()
endif()
else() else()
# TODO(chenweihang): impl compile by source later # TODO(chenweihang): impl compile by source later
endif() endif()
...@@ -150,6 +165,7 @@ function(kernel_library TARGET) ...@@ -150,6 +165,7 @@ function(kernel_library TARGET)
list(APPEND all_srcs ${gpu_srcs}) list(APPEND all_srcs ${gpu_srcs})
list(APPEND all_srcs ${xpu_srcs}) list(APPEND all_srcs ${xpu_srcs})
list(APPEND all_srcs ${gpudnn_srcs}) list(APPEND all_srcs ${gpudnn_srcs})
list(APPEND all_srcs ${kps_srcs})
foreach(src ${all_srcs}) foreach(src ${all_srcs})
file(READ ${src} target_content) file(READ ${src} target_content)
string(REGEX MATCHALL "#include \"paddle\/phi\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content}) string(REGEX MATCHALL "#include \"paddle\/phi\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content})
...@@ -159,11 +175,11 @@ function(kernel_library TARGET) ...@@ -159,11 +175,11 @@ function(kernel_library TARGET)
string(REGEX MATCHALL "#include \"paddle\/phi\/kernels\/${kernel_library_SUB_DIR}\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content}) string(REGEX MATCHALL "#include \"paddle\/phi\/kernels\/${kernel_library_SUB_DIR}\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content})
endif() endif()
foreach(include_kernel ${include_kernels}) foreach(include_kernel ${include_kernels})
if ("${kernel_library_SUB_DIR}" STREQUAL "") if ("${kernel_library_SUB_DIR}" STREQUAL "")
string(REGEX REPLACE "#include \"paddle\/phi\/kernels\/" "" kernel_name ${include_kernel}) string(REGEX REPLACE "#include \"paddle\/phi\/kernels\/" "" kernel_name ${include_kernel})
else() else()
string(REGEX REPLACE "#include \"paddle\/phi\/kernels\/${kernel_library_SUB_DIR}\/" "" kernel_name ${include_kernel}) string(REGEX REPLACE "#include \"paddle\/phi\/kernels\/${kernel_library_SUB_DIR}\/" "" kernel_name ${include_kernel})
endif() endif()
string(REGEX REPLACE ".h\"" "" kernel_name ${kernel_name}) string(REGEX REPLACE ".h\"" "" kernel_name ${kernel_name})
list(APPEND kernel_deps ${kernel_name}) list(APPEND kernel_deps ${kernel_name})
endforeach() endforeach()
...@@ -176,11 +192,20 @@ function(kernel_library TARGET) ...@@ -176,11 +192,20 @@ function(kernel_library TARGET)
list(LENGTH gpu_srcs gpu_srcs_len) list(LENGTH gpu_srcs gpu_srcs_len)
list(LENGTH xpu_srcs xpu_srcs_len) list(LENGTH xpu_srcs xpu_srcs_len)
list(LENGTH gpudnn_srcs gpudnn_srcs_len) list(LENGTH gpudnn_srcs gpudnn_srcs_len)
list(LENGTH kps_srcs kps_srcs_len)
list(LENGTH selected_rows_srcs selected_rows_srcs_len) list(LENGTH selected_rows_srcs selected_rows_srcs_len)
# kernel source file level
# level 1: base device kernel
# - cpu_srcs / gpu_srcs / xpu_srcs / kps_srcs
# level 2: device-independent kernel
# - common_srcs
# level 3: Kernel implemented by reusing device-independent kernel
# - selected_rows_srcs
# Build Target according different src organization # Build Target according different src organization
if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) AND ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0) AND
(${common_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0)) (${common_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0))
# If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule. # If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule.
if (WITH_GPU) if (WITH_GPU)
...@@ -193,6 +218,11 @@ function(kernel_library TARGET) ...@@ -193,6 +218,11 @@ function(kernel_library TARGET)
hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part) hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif() endif()
elseif (WITH_XPU_KP)
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0)
xpu_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
xpu_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET}_part)
endif()
else() else()
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
...@@ -200,7 +230,7 @@ function(kernel_library TARGET) ...@@ -200,7 +230,7 @@ function(kernel_library TARGET)
endif() endif()
endif() endif()
# If there are only specific device srcs, build target using this rule. # If there are only specific device srcs, build target using this rule.
elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0)
if (WITH_GPU) if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
...@@ -209,6 +239,10 @@ function(kernel_library TARGET) ...@@ -209,6 +239,10 @@ function(kernel_library TARGET)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif() endif()
elseif (WITH_XPU_KP)
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0)
xpu_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
else() else()
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
...@@ -222,6 +256,9 @@ function(kernel_library TARGET) ...@@ -222,6 +256,9 @@ function(kernel_library TARGET)
elseif (WITH_ROCM) elseif (WITH_ROCM)
hip_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) hip_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part) hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
elseif (WITH_XPU_KP)
xpu_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
xpu_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
else() else()
cc_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) cc_library(${TARGET}_part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part) cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET}_part)
...@@ -232,6 +269,8 @@ function(kernel_library TARGET) ...@@ -232,6 +269,8 @@ function(kernel_library TARGET)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM) elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_XPU_KP)
xpu_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else() else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif() endif()
...@@ -240,6 +279,8 @@ function(kernel_library TARGET) ...@@ -240,6 +279,8 @@ function(kernel_library TARGET)
nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM) elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_XPU_KP)
xpu_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else() else()
cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif() endif()
...@@ -249,7 +290,7 @@ function(kernel_library TARGET) ...@@ -249,7 +290,7 @@ function(kernel_library TARGET)
if (${target_build_flag} EQUAL 1) if (${target_build_flag} EQUAL 1)
if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0 OR
${gpudnn_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0) ${gpudnn_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0)
# append target into PHI_KERNELS property # append target into PHI_KERNELS property
get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS) get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
...@@ -275,6 +316,9 @@ function(kernel_library TARGET) ...@@ -275,6 +316,9 @@ function(kernel_library TARGET)
if (${gpudnn_srcs_len} GREATER 0) if (${gpudnn_srcs_len} GREATER 0)
kernel_declare(${gpudnn_srcs}) kernel_declare(${gpudnn_srcs})
endif() endif()
if (${kps_srcs_len} GREATER 0)
kernel_declare(${kps_srcs})
endif()
if (${selected_rows_srcs_len} GREATER 0) if (${selected_rows_srcs_len} GREATER 0)
kernel_declare(${selected_rows_srcs}) kernel_declare(${selected_rows_srcs})
endif() endif()
......
...@@ -68,6 +68,8 @@ OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) { ...@@ -68,6 +68,8 @@ OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
library_type = LibraryType::kMKLDNN; library_type = LibraryType::kMKLDNN;
} else if (kernel_key.backend() == phi::Backend::GPUDNN) { } else if (kernel_key.backend() == phi::Backend::GPUDNN) {
library_type = LibraryType::kCUDNN; library_type = LibraryType::kCUDNN;
} else if (kernel_key.backend() == phi::Backend::KPS) {
library_type = LibraryType::kKP;
} else { } else {
// do nothing // do nothing
} }
...@@ -82,6 +84,8 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey( ...@@ -82,6 +84,8 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey(
backend = phi::Backend::MKLDNN; backend = phi::Backend::MKLDNN;
} else if (kernel_type.library_type_ == LibraryType::kCUDNN) { } else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
backend = phi::Backend::GPUDNN; backend = phi::Backend::GPUDNN;
} else if (kernel_type.library_type_ == LibraryType::kKP) {
backend = phi::Backend::KPS;
} else { } else {
// do // do
} }
......
...@@ -227,4 +227,12 @@ class GPUContext : public DeviceContext { ...@@ -227,4 +227,12 @@ class GPUContext : public DeviceContext {
// must use different function name for cudnn kernel // must use different function name for cudnn kernel
using GPUDNNContext = GPUContext; using GPUDNNContext = GPUContext;
// KPS (Kernel PrimitiveS API) needs to exist as a kind of backend,
// because we want to implement a KPS-based kernel and make it run
// on GPU and XPU at the same time, so we need KPSContext when registering
// KPS Kernel. Note: XPU and GPU cannot be compiled at the same time!
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
using KPSContext = GPUContext;
#endif
} // namespace phi } // namespace phi
...@@ -66,4 +66,12 @@ class XPUContext : public DeviceContext { ...@@ -66,4 +66,12 @@ class XPUContext : public DeviceContext {
std::unique_ptr<Impl> impl_; std::unique_ptr<Impl> impl_;
}; };
// KPS (Kernel PrimitiveS API) needs to exist as a kind of backend,
// because we want to implement a KPS-based kernel and make it run
// on GPU and XPU at the same time, so we need KPSContext when registering
// KPS Kernel. Note: XPU and GPU cannot be compiled at the same time!
#if PADDLE_WITH_XPU_KP
using KPSContext = XPUContext;
#endif
} // namespace phi } // namespace phi
...@@ -52,6 +52,9 @@ enum class Backend : uint8_t { ...@@ -52,6 +52,9 @@ enum class Backend : uint8_t {
MKLDNN, MKLDNN,
GPUDNN, // cuDNN and hipDNN GPUDNN, // cuDNN and hipDNN
// paddle kernel primitives backend
KPS,
// end of backend types // end of backend types
NUM_BACKENDS, NUM_BACKENDS,
...@@ -115,6 +118,9 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { ...@@ -115,6 +118,9 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
case Backend::GPUDNN: case Backend::GPUDNN:
os << "GPUDNN"; os << "GPUDNN";
break; break;
case Backend::KPS:
os << "KPS";
break;
default: { default: {
size_t device_type_id_ = static_cast<size_t>(backend) - size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(Backend::NUM_BACKENDS); static_cast<size_t>(Backend::NUM_BACKENDS);
...@@ -147,6 +153,8 @@ inline Backend StringToBackend(const char* backend_cstr) { ...@@ -147,6 +153,8 @@ inline Backend StringToBackend(const char* backend_cstr) {
return Backend::MKLDNN; return Backend::MKLDNN;
} else if (s == std::string("GPUDNN")) { } else if (s == std::string("GPUDNN")) {
return Backend::GPUDNN; return Backend::GPUDNN;
} else if (s == std::string("KPS")) {
return Backend::KPS;
} else { } else {
return static_cast<Backend>(static_cast<size_t>(Backend::NUM_BACKENDS) + return static_cast<Backend>(static_cast<size_t>(Backend::NUM_BACKENDS) +
phi::GetOrRegisterGlobalDeviceTypeId(s)); phi::GetOrRegisterGlobalDeviceTypeId(s));
......
...@@ -66,6 +66,14 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) { ...@@ -66,6 +66,14 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
case phi::Backend::XPU: case phi::Backend::XPU:
return phi::XPUPlace( return phi::XPUPlace(
set_device_id ? phi::backends::xpu::GetXPUCurrentDeviceId() : 0); set_device_id ? phi::backends::xpu::GetXPUCurrentDeviceId() : 0);
#endif
case phi::Backend::KPS:
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return phi::GPUPlace(
set_device_id ? phi::backends::gpu::GetCurrentDeviceId() : 0);
#elif defined(PADDLE_WITH_XPU_KP)
return phi::XPUPlace(
set_device_id ? phi::backends::xpu::GetXPUCurrentDeviceId() : 0);
#endif #endif
default: { default: {
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
......
...@@ -44,6 +44,9 @@ TEST(Backend, OStream) { ...@@ -44,6 +44,9 @@ TEST(Backend, OStream) {
oss << phi::Backend::GPUDNN; oss << phi::Backend::GPUDNN;
EXPECT_EQ(oss.str(), "GPUDNN"); EXPECT_EQ(oss.str(), "GPUDNN");
oss.str(""); oss.str("");
oss << phi::Backend::KPS;
EXPECT_EQ(oss.str(), "KPS");
oss.str("");
try { try {
oss << phi::Backend::NUM_BACKENDS; oss << phi::Backend::NUM_BACKENDS;
} catch (const std::exception& exception) { } catch (const std::exception& exception) {
...@@ -61,6 +64,7 @@ TEST(Backend, StringToBackend) { ...@@ -61,6 +64,7 @@ TEST(Backend, StringToBackend) {
EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU")); EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU"));
EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN")); EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN"));
EXPECT_EQ(phi::Backend::GPUDNN, pexp::StringToBackend("GPUDNN")); EXPECT_EQ(phi::Backend::GPUDNN, pexp::StringToBackend("GPUDNN"));
EXPECT_EQ(phi::Backend::KPS, pexp::StringToBackend("KPS"));
EXPECT_EQ(static_cast<phi::Backend>( EXPECT_EQ(static_cast<phi::Backend>(
static_cast<size_t>(phi::Backend::NUM_BACKENDS) + 1), static_cast<size_t>(phi::Backend::NUM_BACKENDS) + 1),
pexp::StringToBackend("CustomBackend")); pexp::StringToBackend("CustomBackend"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册