Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
80b7ef6f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
80b7ef6f
编写于
8月 12, 2019
作者:
W
wopeizl
提交者:
GitHub
8月 12, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add tensorrt support for windows (#19084)
* add tensorrt support for windows
上级
744279fe
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
60 addition
and
19 deletion
+60
-19
cmake/configure.cmake
cmake/configure.cmake
+14
-8
cmake/inference_lib.cmake
cmake/inference_lib.cmake
+1
-1
cmake/tensorrt.cmake
cmake/tensorrt.cmake
+16
-2
paddle/fluid/inference/anakin/convert/op_converter.h
paddle/fluid/inference/anakin/convert/op_converter.h
+1
-1
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
+12
-0
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+3
-3
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+2
-2
paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h
paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h
+1
-1
paddle/fluid/platform/dynload/dynamic_loader.cc
paddle/fluid/platform/dynload/dynamic_loader.cc
+2
-0
paddle/fluid/platform/dynload/tensorrt.h
paddle/fluid/platform/dynload/tensorrt.h
+3
-1
python/setup.py.in
python/setup.py.in
+5
-0
未找到文件。
cmake/configure.cmake
浏览文件 @
80b7ef6f
...
...
@@ -88,14 +88,20 @@ if(WITH_GPU)
include_directories
(
${
CUDA_TOOLKIT_INCLUDE
}
)
if
(
TENSORRT_FOUND
)
if
(
${
CUDA_VERSION_MAJOR
}
VERSION_LESS 8
)
message
(
FATAL_ERROR
"TensorRT needs CUDA >= 8.0 to compile"
)
endif
()
if
(
${
CUDNN_MAJOR_VERSION
}
VERSION_LESS 7
)
message
(
FATAL_ERROR
"TensorRT needs CUDNN >= 7.0 to compile"
)
endif
()
if
(
${
TENSORRT_MAJOR_VERSION
}
VERSION_LESS 4
)
message
(
FATAL_ERROR
"Paddle needs TensorRT >= 4.0 to compile"
)
if
(
WIN32
)
if
(
${
CUDA_VERSION_MAJOR
}
VERSION_LESS 9
)
message
(
FATAL_ERROR
"TensorRT needs CUDA >= 9.0 to compile on Windows"
)
endif
()
else
()
if
(
${
CUDA_VERSION_MAJOR
}
VERSION_LESS 8
)
message
(
FATAL_ERROR
"TensorRT needs CUDA >= 8.0 to compile"
)
endif
()
if
(
${
CUDNN_MAJOR_VERSION
}
VERSION_LESS 7
)
message
(
FATAL_ERROR
"TensorRT needs CUDNN >= 7.0 to compile"
)
endif
()
if
(
${
TENSORRT_MAJOR_VERSION
}
VERSION_LESS 4
)
message
(
FATAL_ERROR
"Paddle needs TensorRT >= 4.0 to compile"
)
endif
()
endif
()
include_directories
(
${
TENSORRT_INCLUDE_DIR
}
)
endif
()
...
...
cmake/inference_lib.cmake
浏览文件 @
80b7ef6f
...
...
@@ -211,7 +211,7 @@ set(module "inference/api")
if
(
TENSORRT_FOUND
)
copy
(
tensorrt_lib DEPS
${
inference_deps
}
SRCS
${
TENSORRT_ROOT
}
/include/Nv*.h
${
TENSORRT_ROOT
}
/lib/
lib
nvinfer*
SRCS
${
TENSORRT_ROOT
}
/include/Nv*.h
${
TENSORRT_ROOT
}
/lib/
*
nvinfer*
DSTS
${
FLUID_INSTALL_DIR
}
/third_party/install/tensorrt/include
${
FLUID_INSTALL_DIR
}
/third_party/install/tensorrt/lib
)
endif
()
...
...
cmake/tensorrt.cmake
浏览文件 @
80b7ef6f
...
...
@@ -2,14 +2,28 @@ if(NOT WITH_GPU)
return
()
endif
()
set
(
TENSORRT_ROOT
"/usr"
CACHE PATH
"TENSORRT ROOT"
)
if
(
WIN32
)
if
(
"
${
TENSORRT_ROOT
}
"
STREQUAL
""
)
message
(
WARNING
"Please specify the TensorRT root path: TENSORRT_ROOT."
)
endif
()
string
(
REPLACE
"
\\
"
"/"
TENSORRT_ROOT
"
${
TENSORRT_ROOT
}
"
)
set
(
TR_INFER_LIB nvinfer.lib
)
set
(
TR_INFER_RT nvinfer.dll
)
set
(
TR_INFER_PLUGIN_RT nvinfer_plugin.dll
)
else
()
set
(
TENSORRT_ROOT
"/usr"
CACHE PATH
"TENSORRT ROOT"
)
set
(
TR_INFER_LIB libnvinfer.a
)
set
(
TR_INFER_RT libnvinfer.so
)
set
(
TR_INFER_PLUGIN_RT libnvinfer_plugin.so
)
endif
()
find_path
(
TENSORRT_INCLUDE_DIR NvInfer.h
PATHS
${
TENSORRT_ROOT
}
${
TENSORRT_ROOT
}
/include
$ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/include
NO_DEFAULT_PATH
)
find_library
(
TENSORRT_LIBRARY NAMES
libnvinfer.so libnvinfer.a
find_library
(
TENSORRT_LIBRARY NAMES
${
TR_INFER_LIB
}
${
TR_INFER_RT
}
PATHS
${
TENSORRT_ROOT
}
${
TENSORRT_ROOT
}
/lib
$ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/lib
NO_DEFAULT_PATH
...
...
paddle/fluid/inference/anakin/convert/op_converter.h
浏览文件 @
80b7ef6f
...
...
@@ -219,7 +219,7 @@ template class AnakinOpConverter<::anakin::saber::X86,
#define USE_ANAKIN_CONVERTER_BASE(op_type__, place_type__, precision_type__) \
extern int Touch_anakin_##op_type__##_##place_type__##_##precision_type__(); \
int use_converter_anakin_##op_type__##_##place_type__##_##precision_type__ \
__attribute__((unused)) =
\
UNUSED =
\
Touch_anakin_##op_type__##_##place_type__##_##precision_type__();
#if defined(PADDLE_WITH_CUDA) && defined(ANAKIN_X86_PLACE)
...
...
paddle/fluid/inference/api/demo_ci/CMakeLists.txt
浏览文件 @
80b7ef6f
...
...
@@ -144,6 +144,10 @@ if(WITH_GPU)
endif
()
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/libcudart
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
)
else
()
if
(
USE_TENSORRT
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/nvinfer
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
TENSORRT_LIB_DIR
}
/nvinfer_plugin
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
endif
()
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cudart
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cublas
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
set
(
DEPS
${
DEPS
}
${
CUDA_LIB
}
/cudnn
${
CMAKE_STATIC_LIBRARY_SUFFIX
}
)
...
...
@@ -153,6 +157,14 @@ endif()
add_executable
(
${
DEMO_NAME
}
${
DEMO_NAME
}
.cc
)
target_link_libraries
(
${
DEMO_NAME
}
${
DEPS
}
)
if
(
WIN32
)
if
(
USE_TENSORRT
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
TENSORRT_LIB_DIR
}
/nvinfer
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
TENSORRT_LIB_DIR
}
/nvinfer_plugin
${
CMAKE_SHARED_LIBRARY_SUFFIX
}
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
)
endif
()
if
(
WITH_MKL
)
add_custom_command
(
TARGET
${
DEMO_NAME
}
POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E copy
${
MKLDNN_PATH
}
/lib/mkldnn.dll
${
CMAKE_BINARY_DIR
}
/
${
CMAKE_BUILD_TYPE
}
...
...
paddle/fluid/inference/tensorrt/convert/op_converter.h
浏览文件 @
80b7ef6f
...
...
@@ -225,7 +225,7 @@ class OpConverter {
return 0; \
}
#define USE_TRT_CONVERTER(op_type__)
\
extern int TouchConverterRegister_##op_type__();
\
static int use_op_converter_trt_##op_type__
__attribute__((unused))
= \
#define USE_TRT_CONVERTER(op_type__) \
extern int TouchConverterRegister_##op_type__(); \
static int use_op_converter_trt_##op_type__
UNUSED
= \
TouchConverterRegister_##op_type__();
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
80b7ef6f
...
...
@@ -219,8 +219,8 @@ class TensorRTEngine {
// TensorRT has too many layers, so that is not wise to add member functions for
// them, and an macro like this is more extensible when underlying TensorRT
// library add new layer supports.
#define TRT_ENGINE_ADD_LAYER(engine__, layer__,
ARGS
...) \
engine__->network()->add##layer__(
ARGS
);
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ...) \
engine__->network()->add##layer__(
__VA_ARGS__
);
class
TRTEngineManager
{
public:
...
...
paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h
浏览文件 @
80b7ef6f
...
...
@@ -68,7 +68,7 @@ class TrtPluginRegistrar {
#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func) \
static paddle::inference::tensorrt::plugin::TrtPluginRegistrar \
trt_plugin_registrar##ctr
__attribute__((unused)) =
\
trt_plugin_registrar##ctr
UNUSED =
\
paddle::inference::tensorrt::plugin::TrtPluginRegistrar( \
name, deserialize_func)
...
...
paddle/fluid/platform/dynload/dynamic_loader.cc
浏览文件 @
80b7ef6f
...
...
@@ -247,6 +247,8 @@ void* GetNCCLDsoHandle() {
void
*
GetTensorRtDsoHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
return
GetDsoHandleFromSearchPath
(
FLAGS_tensorrt_dir
,
"libnvinfer.dylib"
);
#elif defined(_WIN32)
return
GetDsoHandleFromSearchPath
(
FLAGS_mklml_dir
,
"nvinfer.dll"
);
#else
return
GetDsoHandleFromSearchPath
(
FLAGS_tensorrt_dir
,
"libnvinfer.so"
);
#endif
...
...
paddle/fluid/platform/dynload/tensorrt.h
浏览文件 @
80b7ef6f
...
...
@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <NvInfer.h>
#if !defined(_WIN32)
#include <dlfcn.h>
#endif
#include <mutex> // NOLINT
...
...
@@ -34,7 +36,7 @@ extern void* tensorrt_dso_handle;
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
using tensorrt_func = decltype(
__name(args...)) (*)(Args...);
\
using tensorrt_func = decltype(
&::__name);
\
std::call_once(tensorrt_dso_flag, []() { \
tensorrt_dso_handle = \
paddle::platform::dynload::GetTensorRtDsoHandle(); \
...
...
python/setup.py.in
浏览文件 @
80b7ef6f
...
...
@@ -166,6 +166,11 @@ package_data['paddle.libs']= []
package_data['paddle.libs']=[('libwarpctc' if os.name != 'nt' else 'warpctc') + ext_name]
shutil.copy('${WARPCTC_LIBRARIES}', libs_path)
if '${TENSORRT_FOUND}' == 'ON' and os.name == 'nt':
shutil.copy(os.path.join('${TENSORRT_ROOT}', 'lib', '${TR_INFER_RT}'), libs_path)
shutil.copy(os.path.join('${TENSORRT_ROOT}', 'lib', '${TR_INFER_PLUGIN_RT}'), libs_path)
package_data['paddle.libs'] += ['${TR_INFER_RT}', '${TR_INFER_PLUGIN_RT}']
if '${WITH_MKL}' == 'ON':
shutil.copy('${MKLML_SHARED_LIB}', libs_path)
shutil.copy('${MKLML_SHARED_IOMP_LIB}', libs_path)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录