未验证 提交 18665979 编写于 作者: Y Yan Chunwei 提交者: GitHub

add tensorrt build support(#9891)

上级 0032b4a4
......@@ -39,6 +39,7 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F
option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND})
option(WITH_TENSORRT "Compile PaddlePaddle with TensorRT support." OFF)
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
......@@ -181,6 +182,11 @@ if(WITH_GPU)
include(cuda)
endif(WITH_GPU)
# TensorRT depends on GPU.
if (NOT WITH_GPU)
set(WITH_TENSORRT OFF)
endif()
if(WITH_AMD_GPU)
find_package(HIP)
include(hip)
......
......@@ -45,6 +45,13 @@ ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
# install glide
RUN curl -s -q https://glide.sh/get | sh
# Install TensorRT
# The unnecessary files has been removed to make the library small.
RUN wget -qO- http://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \
tar -xz -C /usr/local && \
cp -rf /usr/local/TensorRT/include /usr && \
cp -rf /usr/local/TensorRT/lib /usr
# git credential to skip password typing
RUN git config --global credential.helper store
......
......@@ -21,4 +21,7 @@ endif()
if(WITH_TESTING)
add_subdirectory(tests/book)
if (WITH_TENSORRT)
add_subdirectory(tensorrt)
endif()
endif()
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "NvInfer.h"
#include "cuda.h"
#include "cuda_runtime_api.h"
#include "paddle/fluid/platform/dynload/tensorrt.h"
namespace dy = paddle::platform::dynload;
class Logger : public nvinfer1::ILogger {
public:
void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
switch (severity) {
case Severity::kINFO:
LOG(INFO) << msg;
break;
case Severity::kWARNING:
LOG(WARNING) << msg;
break;
case Severity::kINTERNAL_ERROR:
case Severity::kERROR:
LOG(ERROR) << msg;
break;
default:
break;
}
}
};
class ScopedWeights {
public:
ScopedWeights(float value) : value_(value) {
w.type = nvinfer1::DataType::kFLOAT;
w.values = &value_;
w.count = 1;
}
const nvinfer1::Weights& get() { return w; }
private:
float value_;
nvinfer1::Weights w;
};
// The following two API are implemented in TensorRT's header file, cannot load
// from the dynamic library. So create our own implementation and directly
// trigger the method from the dynamic library.
nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
return static_cast<nvinfer1::IBuilder*>(
dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
}
nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
return static_cast<nvinfer1::IRuntime*>(
dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
}
const char* kInputTensor = "input";
const char* kOutputTensor = "output";
// Creates a network to compute y = 2x + 3
nvinfer1::IHostMemory* CreateNetwork() {
Logger logger;
// Create the engine.
nvinfer1::IBuilder* builder = createInferBuilder(logger);
ScopedWeights weights(2.);
ScopedWeights bias(3.);
nvinfer1::INetworkDefinition* network = builder->createNetwork();
// Add the input
auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
EXPECT_NE(input, nullptr);
// Add the hidden layer.
auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
EXPECT_NE(layer, nullptr);
// Mark the output.
auto output = layer->getOutput(0);
output->setName(kOutputTensor);
network->markOutput(*output);
// Build the engine.
builder->setMaxBatchSize(1);
builder->setMaxWorkspaceSize(1 << 10);
auto engine = builder->buildCudaEngine(*network);
EXPECT_NE(engine, nullptr);
// Serialize the engine to create a model, then close.
nvinfer1::IHostMemory* model = engine->serialize();
network->destroy();
engine->destroy();
builder->destroy();
return model;
}
void Execute(nvinfer1::IExecutionContext& context, const float* input,
float* output) {
const nvinfer1::ICudaEngine& engine = context.getEngine();
// Two binds, input and output
ASSERT_EQ(engine.getNbBindings(), 2);
const int input_index = engine.getBindingIndex(kInputTensor);
const int output_index = engine.getBindingIndex(kOutputTensor);
// Create GPU buffers and a stream
void* buffers[2];
ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float)));
ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float)));
cudaStream_t stream;
ASSERT_EQ(0, cudaStreamCreate(&stream));
// Copy the input to the GPU, execute the network, and copy the output back.
ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
cudaMemcpyHostToDevice, stream));
context.enqueue(1, buffers, stream, nullptr);
ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
cudaMemcpyDeviceToHost, stream));
cudaStreamSynchronize(stream);
// Release the stream and the buffers
cudaStreamDestroy(stream);
ASSERT_EQ(0, cudaFree(buffers[input_index]));
ASSERT_EQ(0, cudaFree(buffers[output_index]));
}
TEST(TensorrtTest, BasicFunction) {
// Create the network serialized model.
nvinfer1::IHostMemory* model = CreateNetwork();
// Use the model to create an engine and an execution context.
Logger logger;
nvinfer1::IRuntime* runtime = createInferRuntime(logger);
nvinfer1::ICudaEngine* engine =
runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
model->destroy();
nvinfer1::IExecutionContext* context = engine->createExecutionContext();
// Execute the network.
float input = 1234;
float output;
Execute(*context, &input, &output);
EXPECT_EQ(output, input * 2 + 3);
// Destroy the engine.
context->destroy();
engine->destroy();
runtime->destroy();
}
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc nccl.cc)
if (WITH_TENSORRT)
list(APPEND CUDA_SRCS tensorrt.cc)
endif()
configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h)
if (CUPTI_FOUND)
list(APPEND CUDA_SRCS cupti.cc)
......
......@@ -45,6 +45,10 @@ DEFINE_string(nccl_dir, "",
DEFINE_string(cupti_dir, "", "Specify path for loading cupti.so.");
DEFINE_string(
tensorrt_dir, "",
"Specify path for loading tensorrt library, such as libnvinfer.so.");
namespace paddle {
namespace platform {
namespace dynload {
......@@ -194,6 +198,14 @@ void* GetNCCLDsoHandle() {
#endif
}
void* GetTensorRtDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.dylib");
#else
return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.so");
#endif
}
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -25,6 +25,7 @@ void* GetCurandDsoHandle();
void* GetWarpCTCDsoHandle();
void* GetLapackDsoHandle();
void* GetNCCLDsoHandle();
void* GetTensorRtDsoHandle();
} // namespace dynload
} // namespace platform
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/tensorrt.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag tensorrt_dso_flag;
void *tensorrt_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <NvInfer.h>
#include <dlfcn.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag tensorrt_dso_flag;
extern void* tensorrt_dso_handle;
#ifdef PADDLE_USE_DSO
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using tensorrt_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(tensorrt_dso_flag, []() { \
tensorrt_dso_handle = \
paddle::platform::dynload::GetTensorRtDsoHandle(); \
PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \
}); \
void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \
PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \
return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#else
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
tensorrtResult_t operator()(Args... args) { \
return __name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \
__macro(createInferRuntime_INTERNAL);
TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP)
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -32,6 +32,8 @@ DEFINE_string(warpctc_dir, "", "Specify path for loading libwarpctc.so.");
DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");
DEFINE_string(tensorrt_dir, "", "Specify path for loading libnvinfer.so.");
static inline std::string join(const std::string& part1,
const std::string& part2) {
// directory separator
......@@ -157,3 +159,12 @@ void GetLapackDsoHandle(void** dso_handle) {
GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle);
#endif
}
void GetTensorRtDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(
FLAGS_tensorrt_dir, "libnvinfer.dylib", dso_handle);
#else
GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.so", dso_handle);
#endif
}
......@@ -58,3 +58,11 @@ void GetWarpCTCDsoHandle(void** dso_handle);
*
*/
void GetLapackDsoHandle(void** dso_handle);
/**
* @brief load the DSO of tensorrt
*
* @param **dso_handle dso handler
*
*/
void GetTensorRtDsoHandle(void** dso_handle);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册