diff --git a/paddle_inference/paddle/include/paddle_engine.h b/paddle_inference/paddle/include/paddle_engine.h
index 71cc9fc285900e5851e1ef2796b07f503f1522ef..c9233c0488b48d85a2bec7f687865090bd7a6507 100644
--- a/paddle_inference/paddle/include/paddle_engine.h
+++ b/paddle_inference/paddle/include/paddle_engine.h
@@ -17,12 +17,13 @@
 #include <pthread.h>
 #include <fstream>
 #include <map>
+#include <memory>
 #include <string>
 #include <vector>
 #include "core/configure/include/configure_parser.h"
 #include "core/configure/inferencer_configure.pb.h"
-#include "core/predictor/framework/infer.h"
 #include "core/predictor/common/utils.h"
+#include "core/predictor/framework/infer.h"
 #include "paddle_inference_api.h"  // NOLINT
 
 namespace baidu {
@@ -37,31 +38,31 @@
 using paddle_infer::CreatePredictor;
 DECLARE_int32(gpuid);
 
-const static int max_batch = 32;
-const static int min_subgraph_size = 3;
+static const int max_batch = 32;
+static const int min_subgraph_size = 3;
 
 // Engine Base
 class PaddleEngineBase {
  public:
   virtual ~PaddleEngineBase() {}
   virtual std::vector<std::string> GetInputNames() {
-    return _predictor -> GetInputNames();
+    return _predictor->GetInputNames();
   }
 
   virtual std::unique_ptr<Tensor> GetInputHandle(const std::string& name) {
-    return _predictor -> GetInputHandle(name);
+    return _predictor->GetInputHandle(name);
   }
 
   virtual std::vector<std::string> GetOutputNames() {
-    return _predictor -> GetOutputNames();
+    return _predictor->GetOutputNames();
   }
 
   virtual std::unique_ptr<Tensor> GetOutputHandle(const std::string& name) {
-    return _predictor -> GetOutputHandle(name);
+    return _predictor->GetOutputHandle(name);
   }
 
   virtual bool Run() {
-    if (!_predictor -> Run()) {
+    if (!_predictor->Run()) {
       LOG(ERROR) << "Failed call Run with paddle predictor";
       return false;
     }
@@ -75,8 +76,8 @@
       LOG(ERROR) << "origin paddle Predictor is null.";
       return -1;
     }
-    Predictor* prep = static_cast<Predictor*>(predictor);
-    _predictor = prep -> Clone();
+    Predictor* prep = static_cast<Predictor*>(predictor);
+    _predictor = prep->Clone();
     if (_predictor.get() == NULL) {
       LOG(ERROR) << "fail to clone paddle predictor: " << predictor;
       return -1;
@@ -103,8 +104,8 @@
 
     Config config;
     // todo, auto config(zhangjun)
-    if(engine_conf.has_combined_model()) {
-      if(!engine_conf.combined_model()) {
+    if (engine_conf.has_combined_model()) {
+      if (!engine_conf.combined_model()) {
         config.SetModel(model_path);
       } else {
         config.SetParamsFile(model_path + "/__params__");
@@ -114,14 +115,14 @@
       config.SetParamsFile(model_path + "/__params__");
       config.SetProgFile(model_path + "/__model__");
     }
-    
+
     config.SwitchSpecifyInputNames(true);
     config.SetCpuMathLibraryNumThreads(1);
     if (engine_conf.has_use_gpu() && engine_conf.use_gpu()) {
       // 2000MB GPU memory
       config.EnableUseGpu(2000, FLAGS_gpuid);
     }
-    
+
     if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
       config.EnableTensorRtEngine(1 << 20,
                                   max_batch,
@@ -140,19 +141,33 @@
       // 2 MB l3 cache
       config.EnableXpu(2 * 1024 * 1024);
     }
-    if (engine_conf.has_enable_ir_optimization() && !engine_conf.enable_ir_optimization()) {
+    if (engine_conf.has_enable_ir_optimization() &&
+        !engine_conf.enable_ir_optimization()) {
       config.SwitchIrOptim(false);
     } else {
       config.SwitchIrOptim(true);
     }
-    if (engine_conf.has_enable_memory_optimization() && engine_conf.enable_memory_optimization()) {
+    if (engine_conf.has_enable_memory_optimization() &&
+        engine_conf.enable_memory_optimization()) {
       config.EnableMemoryOptim();
     }
-
-    if (false) {
-      // todo, encrypt model
-      //analysis_config.SetModelBuffer();
+
+    if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) {
+      // decrypt model
+      std::string model_buffer, params_buffer, key_buffer;
+      ReadBinaryFile(model_path + "encrypt_model", &model_buffer);
+      ReadBinaryFile(model_path + "encrypt_params", &params_buffer);
+      ReadBinaryFile(model_path + "key", &key_buffer);
+
+      auto cipher = paddle::MakeCipher("");
+      std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
+      std::string real_params_buffer =
+          cipher->Decrypt(params_buffer, key_buffer);
+      config.SetModelBuffer(&real_model_buffer[0],
+                            real_model_buffer.size(),
+                            &real_params_buffer[0],
+                            real_params_buffer.size());
     }
 
     predictor::AutoLock lock(predictor::GlobalCreateMutex::instance());
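Note on the new encrypted-model branch: the patch replaces the dead `if (false)` placeholder with a working decryption path. The sketch below is a minimal standalone illustration of that flow, not part of the patch itself; `read_file()` is a hypothetical stand-in for Serving's `ReadBinaryFile()` helper, and it assumes `paddle::MakeCipher`, `Decrypt`, and `Config::SetModelBuffer` behave exactly as used in the diff above. The point of `SetModelBuffer` here is that the decrypted program and params stay in memory and are never written back to disk.

// Illustrative sketch only (assumptions noted above), not the patch itself.
#include <fstream>
#include <sstream>
#include <string>

#include "paddle_inference_api.h"  // NOLINT

// Hypothetical stand-in for ReadBinaryFile() from
// core/predictor/common/utils.h: slurp a whole file as bytes.
static void read_file(const std::string& path, std::string* out) {
  std::ifstream fin(path, std::ios::binary);
  std::ostringstream ss;
  ss << fin.rdbuf();
  *out = ss.str();
}

// Mirrors the engine_conf.encrypted_model() branch added by this patch:
// read the encrypted program/params plus the key, decrypt both buffers,
// then hand the plaintext buffers to the config in memory.
static void ConfigureEncryptedModel(const std::string& model_path,
                                    paddle_infer::Config* config) {
  std::string model_buffer, params_buffer, key_buffer;
  read_file(model_path + "encrypt_model", &model_buffer);
  read_file(model_path + "encrypt_params", &params_buffer);
  read_file(model_path + "key", &key_buffer);

  auto cipher = paddle::MakeCipher("");  // default cipher, as in the patch
  std::string real_model = cipher->Decrypt(model_buffer, key_buffer);
  std::string real_params = cipher->Decrypt(params_buffer, key_buffer);

  // SetModelBuffer copies the buffers into the config, so the locals may
  // safely go out of scope afterwards (the patch relies on the same thing).
  config->SetModelBuffer(&real_model[0], real_model.size(),
                         &real_params[0], real_params.size());
}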