Unverified commit 6404c438, authored by Shang Zhizhou, committed by GitHub

support trt serialize when load model from memory (#31342)

* support trt serialize when load model from memory

* delete conv_bn_fuse_pass before tensorrt, with which trt serialize engine id is not stable

* Revert "delete conv_bn_fuse_pass before tensorrt, with which trt serialize engine id is not stable"

Reverted due to a performance degradation; a proper fix is left for the future.

This reverts commit fa6cd17e60b15df351efda379ddd00e9e9c1fea9.

* add delete conv_bn

* delete path in delete_cache_files
Parent a2c0b604
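For context, here is roughly the flow this commit enables, as a hedged sketch assembled from the test changes below (model_dir and opt_cache_dir stand for the test's own paths; error handling omitted):

    // Load the model from memory buffers yet still use a static
    // (serialized) TensorRT engine, cached under an explicit cache dir.
    AnalysisConfig config;
    config.EnableUseGpu(100, 0);
    std::string buffer_prog, buffer_param;
    ReadBinaryFile(model_dir + "/model", &buffer_prog);
    ReadBinaryFile(model_dir + "/params", &buffer_param);
    config.SetModelBuffer(&buffer_prog[0], buffer_prog.size(),
                          &buffer_param[0], buffer_param.size());
    config.SetOptimCacheDir(opt_cache_dir);  // now required when use_static=true
    config.EnableTensorRtEngine(1 << 30, 1, 1,
                                AnalysisConfig::Precision::kFloat32,
                                /*use_static=*/true, /*use_calib_mode=*/true);

Previously this combination was rejected outright; the diffs below relax that check and make the serialize path independent of how the model was loaded.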
@@ -114,13 +114,25 @@ void IRPassManager::CreatePasses(Argument *argument,
               "When you are in TRT INT8 mode, and load model from "
               "memory, you should set optim_cache_dir using "
               "config.SetOptimCacheDir()"));
-      PADDLE_ENFORCE_EQ(
-          !(model_from_memory && use_static_engine), true,
-          platform::errors::PreconditionNotMet(
-              "When you are using Paddle-TRT, and also using load model "
-              "from memory, you should set the use_static to false."));
+      if (model_from_memory && use_static_engine) {
+        PADDLE_ENFORCE_EQ(
+            optim_cache_dir.empty(), false,
+            platform::errors::PreconditionNotMet(
+                "When you are using Paddle-TRT, and using load model "
+                "from memory, and also set the use_static to true. "
+                "you must set optim_cache_dir using "
+                "config.SetOptimCacheDir()."));
+      }
       if (!optim_cache_dir.empty()) {
+        if (!PathExists(optim_cache_dir)) {
+          PADDLE_ENFORCE_NE(
+              MKDIR(optim_cache_dir.c_str()), -1,
+              platform::errors::PreconditionNotMet(
+                  "Can not create optimize cache directory: %s, Make sure you "
+                  "have permission to write",
+                  optim_cache_dir));
+        }
         pass->Set("model_opt_cache_dir", new std::string(optim_cache_dir));
       } else if (use_static_engine || enable_int8) {
         std::string model_opt_cache_dir =
......
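In short, the hunk above turns a hard rejection of model_from_memory together with use_static_engine into a conditional requirement, and creates the cache directory on demand. A condensed sketch of the new rule in plain C++ (function name assumed; not the Paddle source):

    #include <stdexcept>
    #include <string>

    // A serialized engine needs a writable location; a model loaded from
    // memory has no model directory to fall back on, so an explicit cache
    // dir becomes mandatory in that case.
    void CheckTrtStaticConfig(bool model_from_memory, bool use_static_engine,
                              const std::string &optim_cache_dir) {
      if (model_from_memory && use_static_engine && optim_cache_dir.empty()) {
        throw std::runtime_error(
            "use_static=true with a model loaded from memory requires "
            "config.SetOptimCacheDir()");
      }
    }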
@@ -250,7 +250,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   auto predictor_id = Get<int>("predictor_id");
   // Get "" when there is no cached calibration table data.
-  bool load_from_memory = Get<bool>("model_from_memory");
   std::string calibration_data = "";
   if (enable_int8 && use_calib_mode) {
     calibration_data = GetTrtCalibTableData(
@@ -323,8 +322,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
       graph->Has(framework::ir::kMultiheadMatmulPass));
-  bool need_serialize = (use_static_engine && !load_from_memory);
-  if (need_serialize) {
+  if (use_static_engine) {
     trt_engine_serialized_data = GetTrtEngineSerializedData(
         Get<std::string>("model_opt_cache_dir"), engine_key);
     // we can load the engine info serialized before from the disk.
@@ -352,7 +350,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
       std::vector<std::string>(input_names.begin(), input_names.end()),
       param_set, output_mapping, trt_engine);
-  if (need_serialize) {
+  if (use_static_engine) {
     nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
     trt_engine_serialized_data =
         std::string((const char *)serialized_engine_data->data(),
......
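With need_serialize gone, use_static_engine alone drives both branches above: first try to read an engine serialized on a previous run, then write the freshly built one back. A hedged sketch of that round trip, with plain file I/O standing in for GetTrtEngineSerializedData and its writing counterpart (the trt_serialized_<engine_key> file name is an assumption):

    #include <fstream>
    #include <sstream>
    #include <string>

    // Returns the cached engine bytes if present, otherwise an empty string.
    std::string LoadSerializedEngine(const std::string &cache_dir,
                                     const std::string &engine_key) {
      std::ifstream fin(cache_dir + "/trt_serialized_" + engine_key,
                        std::ios::binary);
      if (!fin.is_open()) return "";
      std::ostringstream ss;
      ss << fin.rdbuf();
      return ss.str();
    }

    // Persist the bytes produced by trt_engine->Serialize() for reuse.
    void SaveSerializedEngine(const std::string &cache_dir,
                              const std::string &engine_key,
                              const std::string &blob) {
      std::ofstream fout(cache_dir + "/trt_serialized_" + engine_key,
                         std::ios::binary);
      fout << blob;
    }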
@@ -21,17 +21,32 @@ limitations under the License. */
 namespace paddle {
 namespace inference {
-void TestDynamic(bool with_dynamic = true) {
+void TestDynamic(bool with_dynamic = true, bool delete_cache = true,
+                 bool delete_conv_bn = false) {
   std::string model_dir =
       FLAGS_infer_model + "/conv_bn_swish_split_gelu/conv_bn_swish_split_gelu";
+  std::string opt_cache_dir = model_dir + "/my_cache";
+  if (delete_cache) {
+    delete_cache_files(opt_cache_dir);
+  }
   AnalysisConfig config;
   config.EnableUseGpu(100, 0);
-  config.SetModel(model_dir + "/model", model_dir + "/params");
+  std::string buffer_prog, buffer_param;
+  ReadBinaryFile(model_dir + "/model", &buffer_prog);
+  ReadBinaryFile(model_dir + "/params", &buffer_param);
+  config.SetModelBuffer(&buffer_prog[0], buffer_prog.size(), &buffer_param[0],
+                        buffer_param.size());
+  config.SetOptimCacheDir(opt_cache_dir);
   config.SwitchUseFeedFetchOps(false);
   // Set the input's min, max, opt shape
   config.EnableTensorRtEngine(1 << 30, 1, 1,
-                              AnalysisConfig::Precision::kFloat32, false, true);
+                              AnalysisConfig::Precision::kFloat32, true, true);
+  if (delete_conv_bn) {
+    config.pass_builder()->DeletePass("conv_bn_fuse_pass");
+  }
   if (with_dynamic) {
     std::map<std::string, std::vector<int>> min_input_shape = {
         {"image", {1, 1, 3, 3}}};
@@ -130,6 +145,12 @@ void TestDynamic2() {
 TEST(AnalysisPredictor, trt_dynamic) { TestDynamic(true); }
 TEST(AnalysisPredictor, trt_static) { TestDynamic(false); }
+TEST(AnalysisPredictor, trt_memory_serialize) {
+  // serialize
+  TestDynamic(false, true, true);
+  // deserialize
+  TestDynamic(false, false, true);
+}
 TEST(AnalysisPredictor, trt_dynamic2) { TestDynamic2(); }
 }  // namespace inference
......
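The test feeds SetModelBuffer with whole-file strings. A minimal sketch of a ReadBinaryFile-style helper as used above (assumed implementation; the real helper lives in Paddle's inference utilities):

    #include <fstream>
    #include <string>

    // Slurp an entire binary file into *contents (no error handling; sketch only).
    void ReadBinaryFile(const std::string &fname, std::string *contents) {
      std::ifstream fin(fname, std::ios::in | std::ios::binary);
      fin.seekg(0, std::ios::end);
      contents->resize(static_cast<size_t>(fin.tellg()));
      fin.seekg(0, std::ios::beg);
      fin.read(&(*contents)[0], static_cast<std::streamsize>(contents->size()));
    }

The trt_memory_serialize test then calls TestDynamic twice: the first call clears my_cache and serializes a fresh engine; the second keeps the cache and exercises the deserialize path.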
@@ -148,6 +148,7 @@ void delete_cache_files(std::string path) {
       remove(file_rm.c_str());
     }
   }
+  remove(path.c_str());
 }
 }  // namespace inference
......
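The one-line addition above matters because remove() on a directory only succeeds once the directory is empty. A hedged reconstruction of the patched helper (the directory walk is assumed from the visible fragment):

    #include <dirent.h>
    #include <cstdio>
    #include <string>

    void delete_cache_files(std::string path) {
      DIR *dir = opendir(path.c_str());
      if (dir == nullptr) return;
      struct dirent *entry;
      while ((entry = readdir(dir)) != nullptr) {
        std::string name(entry->d_name);
        if (name == "." || name == "..") continue;
        std::string file_rm = path + "/" + name;
        remove(file_rm.c_str());  // unlink each cached file first
      }
      closedir(dir);
      remove(path.c_str());  // the now-empty directory itself; the new line
    }

Without the final remove(), the emptied my_cache directory would be left behind after cleanup, which is what the serialize/deserialize test relies on being gone.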