Unverified · Commit ea0abf93, authored by Wilber, committed by GitHub

Support trt cuda graph. (#53406)

Parent 72cb09e3
......@@ -231,6 +231,7 @@ struct Argument {
TensorRtUseStaticEngine,
bool);
DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
DECL_ARGUMENT_FIELD(tensorrt_use_cuda_graph, TensorRtUseCudaGraph, bool);
DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
DECL_ARGUMENT_FIELD(tensorrt_transformer_posid,
......
......@@ -165,6 +165,8 @@ void IRPassManager::CreatePasses(Argument *argument,
new AnalysisConfig::Precision(precision_mode));
pass->Set("context_memory_sharing",
new bool(argument->trt_engine_memory_sharing()));
pass->Set("use_cuda_graph",
new bool(argument->tensorrt_use_cuda_graph()));
bool use_static_engine = argument->tensorrt_use_static_engine();
bool model_from_memory = argument->model_from_memory();
std::string optim_cache_dir = argument->optim_cache_dir();
......
......@@ -101,6 +101,22 @@ void OutputProcess(framework::ir::Graph *graph,
}
}
// Determine whether the whole graph is offloaded to TensorRT. If so, we can
// try to enable optimizations such as cudaGraph.
bool AllNodesLowerToTrtPostProcess(framework::ir::Graph *graph) {
std::unordered_set<std::string> trt_nodes_set{
"feed", "fetch", "tensorrt_engine"};
bool all_nodes_offload_to_trt = true;
for (auto *node : graph->Nodes()) {
if (node->IsOp()) {
if (!trt_nodes_set.count(node->Op()->Type())) {
all_nodes_offload_to_trt = false;
break;
}
}
}
return all_nodes_offload_to_trt;
}
} // namespace
using framework::ir::Node;
......@@ -124,6 +140,7 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
auto enable_int8 = Get<bool>("enable_int8");
auto use_calib_mode = Get<bool>("use_calib_mode");
bool use_cuda_graph = Get<bool>("use_cuda_graph");
bool no_calib_int8 = enable_int8 && !(use_calib_mode);
auto trt_disabled_ops = Get<std::vector<std::string>>("trt_disabled_ops");
auto with_dynamic_shape = Get<bool>("with_dynamic_shape");
......@@ -165,13 +182,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
// those parameters already exist in trt and should not have another copy in
// fluid.
std::vector<std::string> repetitive_params;
std::vector<std::string> engine_names;
for (auto *node : graph->Nodes()) {
if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) {
CreateTensorRTOp(node, graph, graph_param_names, &repetitive_params);
std::unordered_set<const Node *> nodes2remove(
framework::ir::Agent(node).subgraph()->begin(),
framework::ir::Agent(node).subgraph()->end());
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
engine_names.push_back(CreateTensorRTOp(
node, graph, graph_param_names, &repetitive_params, use_cuda_graph));
}
}
......@@ -184,6 +199,32 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
graph->Set(framework::ir::kRepetitiveParamAttr,
new std::vector<std::string>(repetitive_params));
bool all_nodes_offload_to_trt = AllNodesLowerToTrtPostProcess(graph);
if (all_nodes_offload_to_trt) {
LOG(INFO) << "The entire graph is offloaded to TensorRT.";
}
if (use_cuda_graph && !all_nodes_offload_to_trt) {
LOG_FIRST_N(WARNING, 1)
<< "You have enabled CUDA Graph, but the entire graph is not offloaded "
"to trt; falling back to normal mode.";
use_cuda_graph = false;
}
if (use_cuda_graph && all_nodes_offload_to_trt) {
for (auto &name : engine_names) {
PADDLE_ENFORCE_EQ(
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Has(name),
true,
platform::errors::PreconditionNotMet(
"TRTEnegineManager shoud has engine %s, but not found.", name));
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Get(name)
->SetAllNodesLowerToTrt(use_cuda_graph);
}
}
}
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
......@@ -191,6 +232,7 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
const std::string &predictor_id,
const std::string &max_batch_size,
const std::string &precision,
bool use_cuda_graph,
const bool for_calibration) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
......@@ -209,17 +251,21 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
engine_hash_key += "#";
engine_hash_key += precision;
engine_hash_key += "#";
engine_hash_key += use_cuda_graph ? "true" : "false";
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
VLOG(2) << "TRT engine hash key: " << engine_hash_key;
VLOG(2) << "TRT engine key: " << engine_key;
return engine_key;
}
void TensorRtSubgraphPass::CreateTensorRTOp(
std::string TensorRtSubgraphPass::CreateTensorRTOp(
framework::ir::Node *node,
framework::ir::Graph *graph,
const std::vector<std::string> &graph_params,
std::vector<std::string> *repetitive_params) const {
std::vector<std::string> *repetitive_params,
bool use_cuda_graph) const {
auto *op_desc = node->Op();
auto &subgraph = *framework::ir::Agent(node).subgraph();
PADDLE_ENFORCE_EQ(subgraph.empty(),
......@@ -506,6 +552,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::to_string(0),
std::to_string(max_batch_size),
std::to_string(static_cast<int>(precision_mode)),
use_cuda_graph,
false);
auto calibration_engine_key =
GenerateEngineKey(input_names_with_id,
......@@ -513,6 +560,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
std::to_string(0),
std::to_string(max_batch_size),
std::to_string(static_cast<int>(precision_mode)),
use_cuda_graph,
true);
auto predictor_id = Get<int>("predictor_id");
......@@ -547,7 +595,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
(enable_int8 && calibration_data.size() == 0 && use_calib_mode);
if (calibration_mode) {
// Calibration mode means generating the int8 calibration table data.
return;
return calibration_engine_key;
}
std::copy(params_not_shared.begin(),
......@@ -582,6 +630,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
"recommend using the same TRT version at runtime.";
}
std::unordered_set<const Node *> nodes2remove(
framework::ir::Agent(node).subgraph()->begin(),
framework::ir::Agent(node).subgraph()->end());
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
// Setting disable_trt_plugin_fp16 to true means that the TRT plugin will not
// run in fp16.
// When running in fp16, the output accuracy of the model will be affected,
......@@ -628,7 +681,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
LOG(INFO) << "Load TRT Optimized Info from "
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
return;
return engine_key + std::to_string(predictor_id);
} catch (const std::exception &exp) {
LOG(WARNING)
<< "Fail to load TRT Optimized Info from "
......@@ -643,7 +696,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// If with_dynamic_shape is configured, but min_input_shape is empty,
// create the trt engine at runtime instead of in the pass.
if (with_dynamic_shape && min_input_shape.empty()) {
return;
return engine_key + std::to_string(predictor_id);
}
// the following code will NOT run in the following situation:
......@@ -676,6 +729,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
}
return engine_key + std::to_string(predictor_id);
}
} // namespace analysis
......
......@@ -42,10 +42,11 @@ class TensorRtSubgraphPass : public framework::ir::FusePassBase {
void ApplyImpl(framework::ir::Graph *graph) const override;
private:
void CreateTensorRTOp(framework::ir::Node *x,
framework::ir::Graph *graph,
const std::vector<std::string> &graph_params,
std::vector<std::string> *repetitive_params) const;
std::string CreateTensorRTOp(framework::ir::Node *x,
framework::ir::Graph *graph,
const std::vector<std::string> &graph_params,
std::vector<std::string> *repetitive_params,
bool use_cuda_graph) const;
void CleanIntermediateOutputs(framework::ir::Node *node);
};
......
......@@ -16,6 +16,7 @@
#include <string>
#include <tuple>
#include "glog/logging.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
......@@ -442,6 +443,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(trt_dla_core_);
CP_MEMBER(trt_use_static_engine_);
CP_MEMBER(trt_use_calib_mode_);
CP_MEMBER(trt_use_cuda_graph_);
CP_MEMBER(trt_use_varseqlen_);
CP_MEMBER(trt_with_interleaved_);
CP_MEMBER(tensorrt_transformer_posid_);
......@@ -719,7 +721,8 @@ void AnalysisConfig::EnableTensorRtEngine(
int min_subgraph_size,
AnalysisConfig::Precision precision_mode,
bool use_static,
bool use_calib_mode) {
bool use_calib_mode,
bool use_cuda_graph) {
#ifdef PADDLE_WITH_TENSORRT
if (!use_gpu()) {
LOG(ERROR) << "To use TensorRT engine, please call EnableUseGpu() first";
......@@ -733,6 +736,11 @@ void AnalysisConfig::EnableTensorRtEngine(
tensorrt_precision_mode_ = precision_mode;
trt_use_static_engine_ = use_static;
trt_use_calib_mode_ = use_calib_mode;
trt_use_cuda_graph_ = use_cuda_graph;
if (use_cuda_graph) {
LOG_FIRST_N(INFO, 1) << "You have enabled TRT CUDA Graph; you must ensure "
"that the input shape remains unchanged.";
}
Update();
#else
......@@ -1313,6 +1321,8 @@ std::string AnalysisConfig::Summary() {
trt_use_static_engine_ ? "true" : "false"});
os.InsertRow(
{"tensorrt_use_calib_mode", trt_use_calib_mode_ ? "true" : "false"});
os.InsertRow(
{"tensorrt_use_cuda_graph", trt_use_cuda_graph_ ? "true" : "false"});
// dynamic_shape
os.InsertRow({"tensorrt_enable_dynamic_shape",
......
......@@ -1352,6 +1352,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetTensorRtDLACore(config_.trt_dla_core_);
argument_->SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
argument_->SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
argument_->SetTensorRtUseCudaGraph(config_.trt_use_cuda_graph_);
argument_->SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
argument_->SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path());
argument_->SetTensorRtAllowBuildAtRuntime(
......
......@@ -586,6 +586,9 @@ struct PD_INFER_DECL AnalysisConfig {
/// \param use_static Serialize optimization information to disk for reusing.
/// \param use_calib_mode Use TRT int8 calibration(post training
/// quantization).
/// \param use_cuda_graph Use CUDA Graph to reduce the time consumed by
/// enqueue. Note that this option can only be enabled when your input shape
/// is constant (including the batch dimension).
///
///
void EnableTensorRtEngine(int64_t workspace_size = 1 << 30,
......@@ -593,7 +596,8 @@ struct PD_INFER_DECL AnalysisConfig {
int min_subgraph_size = 3,
Precision precision = Precision::kFloat32,
bool use_static = false,
bool use_calib_mode = true);
bool use_calib_mode = true,
bool use_cuda_graph = false);
///
/// \brief A boolean state telling whether the TensorRT engine is used.
///
......@@ -1114,6 +1118,7 @@ struct PD_INFER_DECL AnalysisConfig {
Precision tensorrt_precision_mode_{Precision::kFloat32};
bool trt_use_static_engine_{false};
bool trt_use_calib_mode_{true};
bool trt_use_cuda_graph_{false};
bool trt_use_varseqlen_{false};
bool trt_with_interleaved_{false};
std::string tensorrt_transformer_posid_{""};
......
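For reference, a minimal caller-side sketch of the new option follows (illustrative only, not part of this diff; the model path, input shape, and workspace size are placeholders, and error handling is omitted). It passes use_cuda_graph = true as the last argument of the EnableTensorRtEngine overload declared above; per the note in the header, the input shape must remain unchanged between runs for the captured graph to be replayed.

#include <vector>

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.EnableUseGpu(100, 0);     // 100 MB initial GPU memory pool, GPU 0
  config.SetModel("./mobilenet");  // placeholder model directory
  // workspace, max_batch, min_subgraph_size, precision, use_static,
  // use_calib_mode, use_cuda_graph
  config.EnableTensorRtEngine(1 << 30,
                              1,
                              3,
                              paddle_infer::PrecisionType::kFloat32,
                              false,
                              false,
                              /*use_cuda_graph=*/true);

  auto predictor = paddle_infer::CreatePredictor(config);
  auto input_names = predictor->GetInputNames();
  auto input = predictor->GetInputHandle(input_names[0]);

  // The input shape (including the batch dimension) must stay fixed across
  // Run() calls so the captured CUDA graph can be replayed.
  std::vector<int> shape{1, 3, 224, 224};
  std::vector<float> data(1 * 3 * 224 * 224, 0.f);
  input->Reshape(shape);
  input->CopyFromCpu(data.data());
  predictor->Run();
  return 0;
}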
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace inference {
......@@ -129,12 +130,60 @@ void TensorRTEngine::Execute(int batch_size,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
infer_context->setDeviceMemory(context_memory);
}
// TODO(wilber): Does cudaGraph conflict with memory sharing?
if (startup_with_cudagraph_ && !cudagraph_inited_) {
// Avoid capturing initialization calls by executing the enqueue function at
// least once before starting CUDA graph capture.
const auto ret = Enqueue(infer_context, buffers, batch_size, stream);
PADDLE_ENFORCE_EQ(
ret,
true,
phi::errors::PreconditionNotMet("Trt CudaGraph test run failed."));
cudaStreamSynchronize(stream);
cuda_graph_.BeginCapture(stream);
// The built TRT engine may contain operations that are not permitted under
// CUDA graph capture mode. When the stream is capturing, the call may
// return false if the current CUDA graph capture fails.
if (Enqueue(infer_context, buffers, batch_size, stream)) {
cuda_graph_.EndCapture(stream);
cudagraph_inited_ = true;
} else {
cuda_graph_.EndCaptureOnError(stream);
// Ensure any CUDA error has been cleaned up.
PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());
LOG(WARNING) << "The built TensorRT engine contains operations that are "
"not permitted under "
"CUDA graph capture mode. The specified UseCudaGraph "
"flag has been ignored. The inference will be "
"launched without using CUDA graph launch.";
cudagraph_inited_ = false;
}
startup_with_cudagraph_ = false;
}
Enqueue(infer_context, buffers, batch_size, stream);
}
bool TensorRTEngine::Enqueue(nvinfer1::IExecutionContext *context,
std::vector<void *> *buffers,
int batch_size,
cudaStream_t stream) {
if (cudagraph_inited_) {
VLOG(1) << "cuda_graph initialized successfully, so we will use cuda graph "
"to launch the entire graph.";
return cuda_graph_.Launch(stream);
}
bool ret;
if (!with_dynamic_shape()) {
infer_context->enqueue(batch_size, buffers->data(), stream, nullptr);
ret = context->enqueue(batch_size, buffers->data(), stream, nullptr);
} else {
infer_context->enqueueV2(buffers->data(), stream, nullptr);
ret = context->enqueueV2(buffers->data(), stream, nullptr);
}
SetRuntimeBatch(batch_size);
return ret;
}
void TensorRTEngine::FreezeNetwork() {
......
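To make the capture/replay protocol implemented in Execute() above easier to follow, here is a self-contained sketch using only the CUDA runtime API (illustrative, not part of this diff). A cudaMemsetAsync call stands in for the engine's enqueueV2(), and error checking is omitted: the workload runs once as a warm-up, is captured once into a graph, and is then replayed, which is the same sequence Execute() performs through the TrtCudaGraph helper defined below.

#include <cuda_runtime_api.h>

// Stand-in workload; in this patch the captured work is TRT's enqueueV2().
static void RunStep(cudaStream_t stream, void* buf, size_t bytes) {
  cudaMemsetAsync(buf, 0, bytes, stream);
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  void* buf = nullptr;
  cudaMalloc(&buf, 1 << 20);

  // 1. Warm up outside capture so lazy initialization is not recorded.
  RunStep(stream, buf, 1 << 20);
  cudaStreamSynchronize(stream);

  // 2. Capture exactly one iteration into a CUDA graph.
  cudaGraph_t graph = nullptr;
  cudaGraphExec_t graph_exec = nullptr;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  RunStep(stream, buf, 1 << 20);
  cudaStreamEndCapture(stream, &graph);

  // 3. Instantiate once, then replay cheaply; buffer addresses and sizes must
  //    stay exactly as they were during capture.
  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);
  cudaGraphDestroy(graph);
  for (int i = 0; i < 10; ++i) {
    cudaGraphLaunch(graph_exec, stream);
  }
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaFree(buf);
  cudaStreamDestroy(stream);
  return 0;
}

The thread-local capture mode here mirrors the one used by TrtCudaGraph::BeginCapture further down in this diff.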
......@@ -49,6 +49,64 @@ namespace paddle {
namespace inference {
namespace tensorrt {
// The code is mainly from TensorRT, thanks to the project.
class TrtCudaGraph {
public:
TrtCudaGraph() = default;
~TrtCudaGraph() {
if (cuda_graph_exec_) {
cudaGraphExecDestroy(cuda_graph_exec_);
}
}
void BeginCapture(cudaStream_t stream) {
PADDLE_ENFORCE_GPU_SUCCESS(
cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
}
bool Launch(cudaStream_t stream) {
return cudaGraphLaunch(cuda_graph_exec_, stream) == cudaSuccess;
}
void EndCapture(cudaStream_t stream) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamEndCapture(stream, &cuda_graph_));
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphInstantiate(
&cuda_graph_exec_, cuda_graph_, nullptr, nullptr, 0));
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(cuda_graph_));
}
void EndCaptureOnError(cudaStream_t stream) {
// There are two possibilities why stream capture would fail:
// (1) stream is in cudaErrorStreamCaptureInvalidated state.
// (2) TRT reports a failure.
// In case (1), the returned cuda_graph_ should be nullptr.
// In case (2), the returned cuda_graph_ is not nullptr, but it should not
// be used.
const auto ret = cudaStreamEndCapture(stream, &cuda_graph_);
if (ret == cudaErrorStreamCaptureInvalidated) {
PADDLE_ENFORCE_EQ(cuda_graph_ == nullptr,
true,
platform::errors::PreconditionNotMet(
"CudaGraph capture stream failed."));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(ret);
PADDLE_ENFORCE_NOT_NULL(
cuda_graph_,
phi::errors::PreconditionNotMet("CudaGraph capture stream failed."));
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphDestroy(cuda_graph_));
cuda_graph_ = nullptr;
}
// Clean up any cuda error.
cudaGetLastError();
LOG(WARNING) << "The TRT CUDA graph capture on the stream has failed.";
}
private:
DISABLE_COPY_AND_ASSIGN(TrtCudaGraph);
cudaGraph_t cuda_graph_{};
cudaGraphExec_t cuda_graph_exec_{};
};
namespace plugin {
class PluginTensorRT;
} // namespace plugin
......@@ -445,6 +503,11 @@ class TensorRTEngine {
std::vector<void*>* buffers,
cudaStream_t stream = nullptr);
bool Enqueue(nvinfer1::IExecutionContext* context,
std::vector<void*>* buffers,
int batch,
cudaStream_t stream);
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
ShapeMapType min_input_shape() { return min_input_shape_; }
......@@ -682,6 +745,11 @@ class TensorRTEngine {
context_memory_sharing_ = context_memory_sharing;
}
void SetAllNodesLowerToTrt(bool all_nodes_offload_to_trt) {
// all nodes are in trt, so we can use cudaGraph to optimize runtime.
startup_with_cudagraph_ = all_nodes_offload_to_trt;
}
private:
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
......@@ -744,6 +812,11 @@ class TensorRTEngine {
infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
std::unordered_map<nvinfer1::ITensor*, float> quant_dynamic_range_;
// cudagraph related
TrtCudaGraph cuda_graph_;
bool cudagraph_inited_{false};
bool startup_with_cudagraph_{false};
std::unordered_map<std::string, paddle::any> attrs_;
std::unordered_map<std::string, std::function<void(void)>> attr_dels_;
#if IS_TRT_VERSION_GE(6000)
......
......@@ -274,6 +274,7 @@ TEST_F(TensorRTEngineTest, test_pool2d) {
buffers[0] = reinterpret_cast<void *>(x_v_gpu_data);
buffers[1] = reinterpret_cast<void *>(y_gpu_data);
engine_->SetAllNodesLowerToTrt(true);
engine_->Execute(2, &buffers, ctx_->stream());
LOG(INFO) << "to get output";
......
......@@ -840,7 +840,8 @@ void BindAnalysisConfig(py::module *m) {
py::arg("min_subgraph_size") = 3,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
py::arg("use_static") = false,
py::arg("use_calib_mode") = true)
py::arg("use_calib_mode") = true,
py::arg("use_cuda_graph") = false)
.def("enable_tensorrt_memory_optim",
&AnalysisConfig::EnableTensorRTMemoryOptim,
py::arg("engine_memory_sharing") = true,
......
......@@ -1377,7 +1377,7 @@ if(WITH_TESTING AND WITH_INFERENCE_API_TEST)
set_tests_properties(test_analyzer_ernie PROPERTIES TIMEOUT 120)
endif()
if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(trt_mobilenet_test PROPERTIES TIMEOUT 120)
set_tests_properties(trt_mobilenet_test PROPERTIES TIMEOUT 240)
if(WITH_MKLDNN)
set_tests_properties(test_analyzer_bfloat16_resnet50 PROPERTIES TIMEOUT
120)
......
......@@ -99,4 +99,30 @@ TEST(PredictorPool, use_gpu) {
predictor->Run();
}
TEST(PredictorPool, use_trt_cuda_graph) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
Config config;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.EnableTensorRtEngine(
1 << 20, 1, 3, PrecisionType::kFloat32, false, false, true);
config.Exp_DisableTensorRtOPs({"fc"});
config.EnableTensorRtDLA(0);
services::PredictorPool pred_pool(config, 1);
auto predictor = pred_pool.Retrive(0);
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
std::vector<int> in_shape = {1, 3, 224, 224};
int in_num =
std::accumulate(in_shape.begin(), in_shape.end(), 1, [](int &a, int &b) {
return a * b;
});
std::vector<float> input(in_num, 0);
input_t->Reshape(in_shape);
input_t->CopyFromCpu(input.data());
predictor->Run();
}
} // namespace paddle_infer