未验证 提交 c39aa18e 编写于 作者: A Aganlengzi 提交者: GitHub

[custom kernel]Upgrade support for multiple libs (#40223)

* [custom kernel]Upgrade support for multi libs

* upgrade phi_custom_kernel deps
上级 c722ee69
...@@ -443,7 +443,7 @@ cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framewo ...@@ -443,7 +443,7 @@ cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framewo
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) #cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) #cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator phi_custom_kernel) set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator)
cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES}) cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})
......
...@@ -31,7 +31,7 @@ cc_library(paddle_infer_contrib SRCS paddle_infer_contrib.cc DEPS zero_copy_tens ...@@ -31,7 +31,7 @@ cc_library(paddle_infer_contrib SRCS paddle_infer_contrib.cc DEPS zero_copy_tens
cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc) cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
set(paddle_inference_api_deps lod_tensor scope reset_tensor_array set(paddle_inference_api_deps lod_tensor scope reset_tensor_array
analysis_config paddle_infer_contrib zero_copy_tensor trainer_desc_proto custom_operator phi_custom_kernel) analysis_config paddle_infer_contrib zero_copy_tensor trainer_desc_proto custom_operator)
if(WITH_CRYPTO) if(WITH_CRYPTO)
list(APPEND paddle_inference_api_deps paddle_crypto) list(APPEND paddle_inference_api_deps paddle_crypto)
......
...@@ -117,7 +117,7 @@ endif() ...@@ -117,7 +117,7 @@ endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost) cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# separate init from device_context to avoid cycle dependencies # separate init from device_context to avoid cycle dependencies
cc_library(init SRCS init.cc DEPS device_context phi_custom_kernel) cc_library(init SRCS init.cc DEPS device_context custom_kernel)
# memcpy depends on device_context, here add deps individually for # memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies # avoiding cycle dependencies
......
...@@ -154,8 +154,8 @@ void LoadCustomDevice(const std::string &library_dir) { ...@@ -154,8 +154,8 @@ void LoadCustomDevice(const std::string &library_dir) {
"Fail to open library: %s with error: %s", lib_path, dlerror())); "Fail to open library: %s with error: %s", lib_path, dlerror()));
phi::LoadCustomRuntimeLib(lib_path, dso_handle); phi::LoadCustomRuntimeLib(lib_path, dso_handle);
phi::LoadCustomKernelLib(lib_path, dso_handle);
} }
phi::CustomKernelMap::Instance().RegisterCustomKernels();
LOG(INFO) << "Finished in LoadCustomDevice with libs_path: [" << library_dir LOG(INFO) << "Finished in LoadCustomDevice with libs_path: [" << library_dir
<< "]"; << "]";
} }
......
...@@ -25,7 +25,7 @@ cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) ...@@ -25,7 +25,7 @@ cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor phi_enforce ddim memcpy) cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor phi_enforce ddim memcpy)
cc_library(phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows) cc_library(phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows)
cc_library(phi_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils op_registry phi_tensor_raw) cc_library(custom_kernel SRCS custom_kernel.cc DEPS kernel_factory)
# Will remove once we implemented MKLDNN_Tensor # Will remove once we implemented MKLDNN_Tensor
if(WITH_MKLDNN) if(WITH_MKLDNN)
......
...@@ -12,21 +12,29 @@ ...@@ -12,21 +12,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/phi/core/custom_kernel.h" #include "paddle/phi/core/custom_kernel.h"
namespace phi { namespace phi {
void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) { void CustomKernelMap::RegisterCustomKernel(const std::string& name,
auto& kernel_info_map = custom_kernel_map.GetMap(); const KernelKey& key,
VLOG(3) << "Size of custom_kernel_map: " << kernel_info_map.size(); const Kernel& kernel) {
PADDLE_ENFORCE_EQ(kernels_[name].find(key),
kernels_[name].end(),
phi::errors::AlreadyExists(
"The custom kernel [%s:%s] has been already existed in "
"CustomKernelMap, please check if any duplicate kernel "
"info in your lib(s) before load again.",
name,
key));
kernels_[name][key] = kernel;
}
void CustomKernelMap::RegisterCustomKernels() {
VLOG(3) << "Size of custom_kernel_map: " << kernels_.size();
auto& kernels = KernelFactory::Instance().kernels(); auto& kernels = KernelFactory::Instance().kernels();
for (auto& pair : kernel_info_map) { for (auto& pair : kernels_) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
kernels.find(pair.first), kernels.find(pair.first),
kernels.end(), kernels.end(),
...@@ -38,8 +46,8 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) { ...@@ -38,8 +46,8 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
kernels[pair.first].find(info_pair.first), kernels[pair.first].find(info_pair.first),
kernels[pair.first].end(), kernels[pair.first].end(),
phi::errors::InvalidArgument( phi::errors::AlreadyExists(
"The operator <%s>'s kernel: %s has been already existed " "The kernel [%s:%s] has been already existed "
"in Paddle, please contribute PR if it is necessary " "in Paddle, please contribute PR if it is necessary "
"to optimize the kernel code. Custom kernel does NOT support " "to optimize the kernel code. Custom kernel does NOT support "
"to replace existing kernel in Paddle.", "to replace existing kernel in Paddle.",
...@@ -48,43 +56,14 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) { ...@@ -48,43 +56,14 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
kernels[pair.first][info_pair.first] = info_pair.second; kernels[pair.first][info_pair.first] = info_pair.second;
VLOG(3) << "Successed in registering operator <" << pair.first VLOG(3) << "Successed in registering kernel [" << pair.first << ":"
<< ">'s kernel: " << info_pair.first << info_pair.first
<< " to Paddle. It will be used like native ones."; << "] to Paddle. It will be used like native ones.";
} }
kernels_[pair.first].clear();
} }
LOG(INFO) << "Successed in loading custom kernels.";
kernels_.clear();
} }
// Looks up the "PD_GetCustomKernelMap" symbol in an already-opened shared
// library and registers the kernels held by the map it returns.
//
// dso_lib_path: path of the library; used here only in log messages.
// dso_handle:   handle passed to dlsym() for the symbol lookup.
//
// When _LINUX is not defined the function only emits a VLOG(3) message and
// does nothing else.
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) {
#ifdef _LINUX
// Signature of the accessor expected to be exported by the library.
typedef phi::CustomKernelMap& get_custom_kernel_map_t();
auto* func = reinterpret_cast<get_custom_kernel_map_t*>(
dlsym(dso_handle, "PD_GetCustomKernelMap"));
// A missing symbol is not treated as an error: the library is skipped with a
// warning and no kernels are registered from it.
if (func == nullptr) {
LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "PD_GetCustomKernelMap symbol in this lib.";
return;
}
// Fetch the map through the library's accessor and hand it to
// phi::RegisterCustomKernels.
auto& custom_kernel_map = func();
phi::RegisterCustomKernels(custom_kernel_map);
LOG(INFO) << "Successed in loading custom kernels in lib: " << dso_lib_path;
#else
VLOG(3) << "Unsupported: Custom kernel is only implemented on Linux.";
#endif
return;
}
} // namespace phi } // namespace phi
// The accessor below is wrapped in extern "C" (when compiled as C++) so that
// it has C linkage and can be located by its plain name,
// "PD_GetCustomKernelMap", e.g. via dlsym().
#ifdef __cplusplus
extern "C" {
#endif
// C-API to get global CustomKernelMap.
// Returns a reference to the object produced by
// phi::CustomKernelMap::Instance().
phi::CustomKernelMap& PD_GetCustomKernelMap() {
return phi::CustomKernelMap::Instance();
}
#ifdef __cplusplus
} // end extern "C"
#endif
...@@ -29,6 +29,12 @@ class CustomKernelMap { ...@@ -29,6 +29,12 @@ class CustomKernelMap {
return g_custom_kernel_info_map; return g_custom_kernel_info_map;
} }
void RegisterCustomKernel(const std::string& kernel_name,
const KernelKey& kernel_key,
const Kernel& kernel);
void RegisterCustomKernels();
KernelNameMap& Kernels() { return kernels_; } KernelNameMap& Kernels() { return kernels_; }
const KernelNameMap& GetMap() const { return kernels_; } const KernelNameMap& GetMap() const { return kernels_; }
...@@ -40,12 +46,4 @@ class CustomKernelMap { ...@@ -40,12 +46,4 @@ class CustomKernelMap {
KernelNameMap kernels_; KernelNameMap kernels_;
}; };
/**
* Note:
* Used to register custom kernels to KernelFactory.
*/
void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map);
// Load custom kernel lib and register
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle);
} // namespace phi } // namespace phi
...@@ -210,7 +210,8 @@ struct KernelRegistrar { ...@@ -210,7 +210,8 @@ struct KernelRegistrar {
if (reg_type == RegType::INNER) { if (reg_type == RegType::INNER) {
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
} else { } else {
CustomKernelMap::Instance().Kernels()[kernel_name][kernel_key] = kernel; CustomKernelMap::Instance().RegisterCustomKernel(
kernel_name, kernel_key, kernel);
} }
} }
}; };
......
...@@ -10,7 +10,7 @@ add_subdirectory(funcs) ...@@ -10,7 +10,7 @@ add_subdirectory(funcs)
set_property(GLOBAL PROPERTY PHI_KERNELS "") set_property(GLOBAL PROPERTY PHI_KERNELS "")
# [ 1. Common kernel compilation dependencies ] # [ 1. Common kernel compilation dependencies ]
set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils) set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor softmax) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor softmax)
# remove this dep after removing fluid deps on tensor creation # remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils)
......
set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils math_function) set(SPARSE_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils math_function custom_kernel)
register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse_kernel") register_kernels(DEPS ${SPARSE_KERNEL_DEPS} SUB_DIR "sparse_kernel")
cc_test(test_custom_kernel SRCS test_custom_kernel.cc DEPS phi_custom_kernel) cc_test(test_custom_kernel SRCS test_custom_kernel.cc DEPS custom_kernel)
cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor)
cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_type_info SRCS test_type_info.cc)
......
...@@ -172,7 +172,9 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -172,7 +172,9 @@ TEST(CustomKernel, custom_kernel_dot) {
fake_dot_kernels.end()); fake_dot_kernels.end());
// register // register
phi::RegisterCustomKernels(phi::CustomKernelMap::Instance()); phi::CustomKernelMap::Instance().RegisterCustomKernels();
EXPECT_EQ(0, static_cast<int>(custom_fake_dot_kernels.size()));
EXPECT_TRUE(fake_dot_kernels.find( EXPECT_TRUE(fake_dot_kernels.find(
phi::KernelKey(backend, layout, phi::DataType::FLOAT32)) != phi::KernelKey(backend, layout, phi::DataType::FLOAT32)) !=
......
# for paddle test case # for paddle test case
if(WITH_TESTING) if(WITH_TESTING)
cc_library(paddle_gtest_main SRCS paddle_gtest_main.cc DEPS init device_context memory gtest gflags) cc_library(paddle_gtest_main SRCS paddle_gtest_main.cc DEPS init device_context memory gtest gflags proto_desc)
endif() endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册