Commit 2891070c authored by nhzlx, committed by ceci3

Fix: could not pass CI

Add a `use_static` option controlling whether the TensorRT engine is serialized as a static engine.
test=develop
Parent 717bbc08
@@ -23,8 +23,12 @@
 #pragma once
+#include <memory>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
@@ -133,6 +137,8 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
                       AnalysisConfig::Precision);
+  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
+                      bool);

   // Memory optimized related.
   DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
......
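For context, `DECL_ARGUMENT_FIELD` is the helper macro that declares a typed field on `Argument` together with a getter and a `Set...` mutator; the rest of this commit relies on both spellings (`argument->tensorrt_use_static_engine()` and `argument_.SetTensorRtUseStaticEngine(...)`). Below is a deliberately simplified, self-contained sketch of that naming pattern, not the real macro (which also tracks whether a field is valid):

```cpp
#include <iostream>

// Simplified sketch of a DECL_ARGUMENT_FIELD-style macro: declares a private
// member plus a lower_snake getter and a SetPascalCase mutator. The real
// macro in argument.h does more; this only shows the naming convention.
#define DECL_ARGUMENT_FIELD(field__, Field, type__)       \
 public:                                                  \
  type__& field__() { return field__##_; }                \
  void Set##Field(const type__& x) { field__##_ = x; }    \
                                                          \
 private:                                                 \
  type__ field__##_;

struct Argument {
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine, bool);
};

int main() {
  Argument argument;
  argument.SetTensorRtUseStaticEngine(true);                    // writer side
  std::cout << argument.tensorrt_use_static_engine() << "\n";   // reader side
}
```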
@@ -82,6 +82,8 @@ void IRPassManager::CreatePasses(Argument *argument,
           "model_opt_cache_dir",
           new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
       pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
+      pass->Set("use_static_engine",
+                new bool(argument->tensorrt_use_static_engine()));
     }
     pre_pass = pass_name;
......
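The flag travels from `Argument` into the subgraph pass through the pass attribute map: `IRPassManager` stores a heap-allocated value under a string key with `pass->Set(...)`, and the pass later reads it back with `Get<bool>(...)` (see the next hunk). Here is a minimal, self-contained sketch of that string-keyed attribute round trip, using a toy `Pass` class rather than Paddle's real `framework::ir::Pass`:

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Toy stand-in for the pass attribute map: values are stored type-erased
// under string keys and fetched back with an explicit template type.
// Paddle's real Pass also checks types and ownership; this sketch only
// demonstrates the Set/Get round trip used by the diff.
class Pass {
 public:
  template <typename T>
  void Set(const std::string& key, T* value) {
    attrs_[key] = std::shared_ptr<void>(value);  // take ownership of value
  }
  template <typename T>
  T& Get(const std::string& key) {
    return *static_cast<T*>(attrs_.at(key).get());
  }

 private:
  std::map<std::string, std::shared_ptr<void>> attrs_;
};

int main() {
  Pass pass;
  pass.Set("use_static_engine", new bool(true));               // producer side
  bool use_static_engine = pass.Get<bool>("use_static_engine"); // consumer side
  assert(use_static_engine);
}
```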
@@ -226,10 +226,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
     calibrator.reset(new tensorrt::TRTInt8Calibrator(calibration_data));
   }

+  bool use_static_engine = Get<bool>("use_static_engine");
   // When in int8 mode and calibration_mode, the program just produce the
   // calibration table data.
   bool calibration_mode = (enable_int8 && calibration_data.size() == 0);
-  if (!calibration_mode) {
+  if (!calibration_mode && use_static_engine) {
     std::copy(params.begin(), params.end(),
               std::back_inserter(*repetitive_params));
     std::string trt_engine_serialized_data = GetTrtEngineSerializedData(
......
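The net effect of this hunk: in int8 mode with no calibration table yet, the pass is in calibration mode and only produces table data, so nothing can be serialized; otherwise the engine is serialized for reuse only when the new `use_static_engine` flag is on. A small self-contained sketch of just that decision logic (the names mirror the diff; `SerializeEngine` and `BuildAtRuntime` are hypothetical placeholders for the real work):

```cpp
#include <iostream>
#include <string>

// Placeholder actions; in Paddle these roughly correspond to writing the
// built TRT engine into the model-opt cache dir vs. building it at run time.
void SerializeEngine() { std::cout << "serialize static engine\n"; }
void BuildAtRuntime() { std::cout << "build engine at runtime\n"; }

void CreateTensorRTOp(bool enable_int8, const std::string& calibration_data,
                      bool use_static_engine) {
  // In int8 mode with no calibration table yet we are in calibration mode:
  // this run only produces calibration table data, never an engine.
  bool calibration_mode = (enable_int8 && calibration_data.size() == 0);
  if (!calibration_mode && use_static_engine) {
    SerializeEngine();   // cache a pre-built engine for reuse across runs
  } else if (!calibration_mode) {
    BuildAtRuntime();    // use_static == false: rebuild on every load
  }
}

int main() {
  CreateTensorRTOp(/*enable_int8=*/false, "", /*use_static_engine=*/true);
  CreateTensorRTOp(/*enable_int8=*/true, "", /*use_static_engine=*/true);
  CreateTensorRTOp(/*enable_int8=*/false, "", /*use_static_engine=*/false);
}
```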
@@ -103,6 +103,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(tensorrt_max_batchsize_);
   CP_MEMBER(tensorrt_min_subgraph_size_);
   CP_MEMBER(tensorrt_precision_mode_);
+  CP_MEMBER(trt_use_static_engine_);
   // MKLDNN related.
   CP_MEMBER(use_mkldnn_);
   CP_MEMBER(mkldnn_enabled_op_types_);
@@ -144,7 +145,7 @@ void AnalysisConfig::EnableMKLDNN() {
 void AnalysisConfig::EnableTensorRtEngine(
     int workspace_size, int max_batch_size, int min_subgraph_size,
-    AnalysisConfig::Precision precision_mode) {
+    AnalysisConfig::Precision precision_mode, bool use_static) {
 #ifdef PADDLE_WITH_CUDA
   if (!use_gpu()) {
     LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
@@ -156,6 +157,7 @@ void AnalysisConfig::EnableTensorRtEngine(
   tensorrt_max_batchsize_ = max_batch_size;
   tensorrt_min_subgraph_size_ = min_subgraph_size;
   tensorrt_precision_mode_ = precision_mode;
+  trt_use_static_engine_ = use_static;
   Update();
 #else
......
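`CP_MEMBER` is the copy-constructor helper in analysis_config.cc: every new member must be added to this list, or a copied config would silently drop the setting. A sketch of the pattern, assuming the macro is a plain member-wise copy (see analysis_config.cc for its real definition):

```cpp
#include <cassert>

// Assumed shape of the CP_MEMBER helper: copy one member from `other`
// inside the copy constructor. Forgetting a CP_MEMBER silently loses state,
// which is why this commit adds CP_MEMBER(trt_use_static_engine_).
#define CP_MEMBER(member__) member__ = other.member__;

struct AnalysisConfig {
  AnalysisConfig() = default;
  AnalysisConfig(const AnalysisConfig& other) {
    CP_MEMBER(trt_use_static_engine_);
  }
  bool trt_use_static_engine_{true};
};

int main() {
  AnalysisConfig a;
  a.trt_use_static_engine_ = false;
  AnalysisConfig b(a);
  assert(!b.trt_use_static_engine_);  // the flag survives copying
}
```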
@@ -370,6 +370,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
     argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
     argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
+    argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
   }

   if (config_.use_mkldnn_) {
......
@@ -135,7 +135,8 @@ struct AnalysisConfig {
    */
   void EnableTensorRtEngine(int workspace_size = 1 << 20,
                             int max_batch_size = 1, int min_subgraph_size = 3,
-                            Precision precision = Precision::kFloat32);
+                            Precision precision = Precision::kFloat32,
+                            bool use_static = true);
   /** A boolean state telling whether the TensorRT engine is used.
    */
   bool tensorrt_engine_enabled() const { return use_tensorrt_; }
@@ -233,6 +234,7 @@ struct AnalysisConfig {
   // subgraph, 3 as default value.
   int tensorrt_min_subgraph_size_{3};
   Precision tensorrt_precision_mode_;
+  bool trt_use_static_engine_;

   // memory reuse related.
   bool enable_memory_optim_{false};
......
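Putting the C++ API together: after this change a caller opts in or out of the static engine via the last parameter of `EnableTensorRtEngine`, which defaults to `true`. A usage sketch against the public header shown above (the model path is made up for illustration; this compiles against the Paddle inference headers rather than standalone):

```cpp
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("/path/to/model");  // hypothetical model directory
  config.EnableUseGpu(100, 0);        // TensorRT requires the GPU path
  // Arguments: workspace_size, max_batch_size, min_subgraph_size, precision,
  // use_static. With use_static = true the built engine is serialized so
  // later runs can skip the slow TensorRT build; pass false to rebuild.
  config.EnableTensorRtEngine(1 << 20, /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle::AnalysisConfig::Precision::kFloat32,
                              /*use_static=*/true);
}
```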
@@ -54,7 +54,8 @@ void SetConfig<AnalysisConfig>(AnalysisConfig* config, std::string model_dir,
   if (use_gpu) {
     config->EnableUseGpu(100, 0);
     if (use_tensorrt) {
-      config->EnableTensorRtEngine(1 << 10, batch_size);
+      config->EnableTensorRtEngine(1 << 10, batch_size, 3,
+                                   AnalysisConfig::Precision::kFloat32, false);
       config->pass_builder()->DeletePass("conv_bn_fuse_pass");
       config->pass_builder()->DeletePass("fc_fuse_pass");
       config->pass_builder()->TurnOnDebug();
......
@@ -227,7 +227,8 @@ void BindAnalysisConfig(py::module *m) {
       .def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine,
            py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
            py::arg("min_subgraph_size") = 3,
-           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
+           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
+           py::arg("use_static") = true)
       .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
       .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
            py::arg("x") = true)
......
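With the extra `py::arg`, the pybind11 binding gains the matching keyword argument, so Python callers can write, for example, `config.enable_tensorrt_engine(workspace_size=1 << 20, max_batch_size=1, use_static=False)`; omitting `use_static` keeps the C++ default of `True`.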