未验证 提交 7635d699 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

support build C++ cuda shared lib (#2401)

* support build C++ cuda shared lib
上级 dde12f0d
...@@ -490,6 +490,9 @@ function(nv_binary TARGET_NAME) ...@@ -490,6 +490,9 @@ function(nv_binary TARGET_NAME)
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(nv_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cuda_add_executable(${TARGET_NAME} ${nv_binary_SRCS}) cuda_add_executable(${TARGET_NAME} ${nv_binary_SRCS})
target_link_libraries(${TARGET_NAME} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${os_dependency_modules})
if(nv_binary_DEPS) if(nv_binary_DEPS)
target_link_libraries(${TARGET_NAME} ${nv_binary_DEPS}) target_link_libraries(${TARGET_NAME} ${nv_binary_DEPS})
add_dependencies(${TARGET_NAME} ${nv_binary_DEPS}) add_dependencies(${TARGET_NAME} ${nv_binary_DEPS})
...@@ -507,7 +510,7 @@ function(nv_test TARGET_NAME) ...@@ -507,7 +510,7 @@ function(nv_test TARGET_NAME)
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest
gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES} ) gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES} )
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog)
common_link(${TARGET_NAME}) common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME})
......
...@@ -248,6 +248,7 @@ endfunction() ...@@ -248,6 +248,7 @@ endfunction()
set(arm_kernels CACHE INTERNAL "arm kernels") set(arm_kernels CACHE INTERNAL "arm kernels")
set(x86_kernels CACHE INTERNAL "x86 kernels") set(x86_kernels CACHE INTERNAL "x86 kernels")
set(cuda_kernels CACHE INTERNAL "cuda kernels")
set(fpga_kernels CACHE INTERNAL "fpga kernels") set(fpga_kernels CACHE INTERNAL "fpga kernels")
set(npu_kernels CACHE INTERNAL "npu kernels") set(npu_kernels CACHE INTERNAL "npu kernels")
set(xpu_kernels CACHE INTERNAL "xpu kernels") set(xpu_kernels CACHE INTERNAL "xpu kernels")
......
...@@ -117,6 +117,9 @@ if (LITE_WITH_X86) ...@@ -117,6 +117,9 @@ if (LITE_WITH_X86)
add_dependencies(publish_inference_x86_cxx_demos paddle_full_api_shared eigen3) add_dependencies(publish_inference_x86_cxx_demos paddle_full_api_shared eigen3)
endif() endif()
if(LITE_WITH_CUDA)
add_dependencies(publish_inference paddle_full_api_shared)
endif(LITE_WITH_CUDA)
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND LITE_WITH_ARM)
if (NOT LITE_ON_TINY_PUBLISH) if (NOT LITE_ON_TINY_PUBLISH)
# add cxx lib # add cxx lib
......
...@@ -9,7 +9,7 @@ if (LITE_ON_TINY_PUBLISH) ...@@ -9,7 +9,7 @@ if (LITE_ON_TINY_PUBLISH)
set(CMAKE_C_FLAGS_RELEASE "-Os -DNDEBUG") set(CMAKE_C_FLAGS_RELEASE "-Os -DNDEBUG")
endif() endif()
set(light_lib_DEPS light_api paddle_api paddle_api_light optimizer) set(light_lib_DEPS light_api paddle_api paddle_api_light optimizer)
if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_X86 OR ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux")) if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGET_OS STREQUAL "android" OR ARM_TARGET_OS STREQUAL "armlinux"))
#full api dynamic library #full api dynamic library
add_library(paddle_full_api_shared SHARED "") add_library(paddle_full_api_shared SHARED "")
target_sources(paddle_full_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc) target_sources(paddle_full_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc)
...@@ -19,7 +19,9 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_X86 OR ARM_TARGET_OS STREQUAL "and ...@@ -19,7 +19,9 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_X86 OR ARM_TARGET_OS STREQUAL "and
add_dependencies(paddle_full_api_shared xxhash) add_dependencies(paddle_full_api_shared xxhash)
target_link_libraries(paddle_full_api_shared xxhash) target_link_libraries(paddle_full_api_shared xxhash)
endif() endif()
if(LITE_WITH_CUDA)
target_link_libraries(paddle_full_api_shared ${math_cuda} "-Wl,--whole-archive" ${cuda_kernels} "-Wl,--no-whole-archive")
endif(LITE_WITH_CUDA)
#light api dynamic library #light api dynamic library
lite_cc_library(paddle_light_api_shared MODULE lite_cc_library(paddle_light_api_shared MODULE
SRCS light_api_shared.cc SRCS light_api_shared.cc
...@@ -59,6 +61,7 @@ endif() ...@@ -59,6 +61,7 @@ endif()
message(STATUS "get ops ${ops}") message(STATUS "get ops ${ops}")
message(STATUS "get X86 kernels ${x86_kernels}") message(STATUS "get X86 kernels ${x86_kernels}")
message(STATUS "get CUDA kernels ${cuda_kernels}")
message(STATUS "get Host kernels ${host_kernels}") message(STATUS "get Host kernels ${host_kernels}")
message(STATUS "get ARM kernels ${arm_kernels}") message(STATUS "get ARM kernels ${arm_kernels}")
message(STATUS "get NPU kernels ${npu_kernels}") message(STATUS "get NPU kernels ${npu_kernels}")
...@@ -289,7 +292,8 @@ if(NOT IOS) ...@@ -289,7 +292,8 @@ if(NOT IOS)
XPU_DEPS ${xpu_kernels} XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels} CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels} FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels}) X86_DEPS ${x86_kernels}
CUDA_DEPS ${cuda_kernels})
lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags utils lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${ops} ${host_kernels}
ARM_DEPS ${arm_kernels} ARM_DEPS ${arm_kernels}
...@@ -297,7 +301,9 @@ if(NOT IOS) ...@@ -297,7 +301,9 @@ if(NOT IOS)
XPU_DEPS ${xpu_kernels} XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels} CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels} FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels}) X86_DEPS ${x86_kernels}
CUDA_DEPS ${cuda_kernels})
endif() endif()
#lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc #lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc
......
...@@ -207,6 +207,13 @@ class Context<TargetType::kCUDA> { ...@@ -207,6 +207,13 @@ class Context<TargetType::kCUDA> {
ctx->cublas_fp32_ = cublas_fp32_; ctx->cublas_fp32_ = cublas_fp32_;
} }
CUDAContext& operator=(const CUDAContext& context) {
this->Init(
context.device_id_, context.exec_stream_id_, context.io_stream_id_);
this->cublas_fp32_ = context.cublas_fp32_;
return *this;
}
const cudaStream_t& exec_stream() const { return exec_stream_; } const cudaStream_t& exec_stream() const { return exec_stream_; }
void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; } void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; }
......
...@@ -237,10 +237,10 @@ function make_cuda { ...@@ -237,10 +237,10 @@ function make_cuda {
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \ -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \
-DWITH_TESTING=OFF \ -DWITH_TESTING=OFF \
-DLITE_WITH_ARM=OFF \ -DLITE_WITH_ARM=OFF \
-DLITE_WITH_PYTHON=ON \ -DLITE_WITH_PYTHON=${BUILD_PYTHON} \
-DLITE_BUILD_EXTRA=ON -DLITE_BUILD_EXTRA=ON
make publish_inference_python_lib -j8 make publish_inference -j4
cd - cd -
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册