Commit ca8a42ad authored by zhangjun

add model encrypt,test=develop

Parent a0ca45bd
@@ -17,12 +17,13 @@
 #include <pthread.h>
 #include <fstream>
 #include <map>
+#include <memory>
 #include <string>
 #include <vector>
 #include "core/configure/include/configure_parser.h"
 #include "core/configure/inferencer_configure.pb.h"
-#include "core/predictor/framework/infer.h"
 #include "core/predictor/common/utils.h"
+#include "core/predictor/framework/infer.h"
 #include "paddle_inference_api.h"  // NOLINT
 namespace baidu {
@@ -37,31 +38,31 @@ using paddle_infer::CreatePredictor;
 DECLARE_int32(gpuid);
-const static int max_batch = 32;
-const static int min_subgraph_size = 3;
+static const int max_batch = 32;
+static const int min_subgraph_size = 3;
 // Engine Base
 class PaddleEngineBase {
  public:
   virtual ~PaddleEngineBase() {}
   virtual std::vector<std::string> GetInputNames() {
-    return _predictor -> GetInputNames();
+    return _predictor->GetInputNames();
   }
   virtual std::unique_ptr<Tensor> GetInputHandle(const std::string& name) {
-    return _predictor -> GetInputHandle(name);
+    return _predictor->GetInputHandle(name);
   }
   virtual std::vector<std::string> GetOutputNames() {
-    return _predictor -> GetOutputNames();
+    return _predictor->GetOutputNames();
   }
   virtual std::unique_ptr<Tensor> GetOutputHandle(const std::string& name) {
-    return _predictor -> GetOutputHandle(name);
+    return _predictor->GetOutputHandle(name);
   }
   virtual bool Run() {
-    if (!_predictor -> Run()) {
+    if (!_predictor->Run()) {
       LOG(ERROR) << "Failed call Run with paddle predictor";
       return false;
     }
@@ -75,8 +76,8 @@ class PaddleEngineBase {
       LOG(ERROR) << "origin paddle Predictor is null.";
       return -1;
     }
     Predictor* prep = static_cast<Predictor*>(predictor);
-    _predictor = prep -> Clone();
+    _predictor = prep->Clone();
     if (_predictor.get() == NULL) {
       LOG(ERROR) << "fail to clone paddle predictor: " << predictor;
       return -1;
@@ -103,8 +104,8 @@ class PaddleInferenceEngine : public PaddleEngineBase {
     Config config;
     // todo, auto config(zhangjun)
-    if(engine_conf.has_combined_model()) {
-      if(!engine_conf.combined_model()) {
+    if (engine_conf.has_combined_model()) {
+      if (!engine_conf.combined_model()) {
         config.SetModel(model_path);
       } else {
         config.SetParamsFile(model_path + "/__params__");
@@ -114,14 +115,14 @@ class PaddleInferenceEngine : public PaddleEngineBase {
       config.SetParamsFile(model_path + "/__params__");
       config.SetProgFile(model_path + "/__model__");
     }
     config.SwitchSpecifyInputNames(true);
     config.SetCpuMathLibraryNumThreads(1);
     if (engine_conf.has_use_gpu() && engine_conf.use_gpu()) {
       // 2000MB GPU memory
       config.EnableUseGpu(2000, FLAGS_gpuid);
     }
     if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
       config.EnableTensorRtEngine(1 << 20,
                                   max_batch,
@@ -140,19 +141,33 @@ class PaddleInferenceEngine : public PaddleEngineBase {
       // 2 MB l3 cache
       config.EnableXpu(2 * 1024 * 1024);
     }
-    if (engine_conf.has_enable_ir_optimization() && !engine_conf.enable_ir_optimization()) {
+    if (engine_conf.has_enable_ir_optimization() &&
+        !engine_conf.enable_ir_optimization()) {
       config.SwitchIrOptim(false);
     } else {
       config.SwitchIrOptim(true);
     }
-    if (engine_conf.has_enable_memory_optimization() && engine_conf.enable_memory_optimization()) {
+    if (engine_conf.has_enable_memory_optimization() &&
+        engine_conf.enable_memory_optimization()) {
       config.EnableMemoryOptim();
     }
-    if (false) {
-      // todo, encrypt model
-      // analysis_config.SetModelBuffer();
+    if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) {
+      // decrypt model
+      std::string model_buffer, params_buffer, key_buffer;
+      ReadBinaryFile(model_path + "encrypt_model", &model_buffer);
+      ReadBinaryFile(model_path + "encrypt_params", &params_buffer);
+      ReadBinaryFile(model_path + "key", &key_buffer);
+      auto cipher = paddle::MakeCipher("");
+      std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
+      std::string real_params_buffer =
+          cipher->Decrypt(params_buffer, key_buffer);
+      config.SetModelBuffer(&real_model_buffer[0],
+                            real_model_buffer.size(),
+                            &real_params_buffer[0],
+                            real_params_buffer.size());
     }
     predictor::AutoLock lock(predictor::GlobalCreateMutex::instance());
...
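For context: the decryption branch added in this commit expects three files alongside the model, encrypt_model, encrypt_params, and key. Below is a minimal sketch of the offline step that could produce them. It assumes the paddle::MakeCipher("") cipher used in the patch also exposes Encrypt(plaintext, key) as the counterpart of its Decrypt(ciphertext, key) call, and that paddle::MakeCipher is reachable through paddle_inference_api.h as in the patch; ReadFile and WriteFile are local stand-ins, not Serving APIs.

// Hypothetical offline encryption step (not part of this commit): produces
// the encrypt_model / encrypt_params / key files that the patched engine
// reads. Encrypt(plaintext, key) is assumed to mirror the Decrypt call above.
#include <fstream>
#include <sstream>
#include <string>
#include "paddle_inference_api.h"  // NOLINT

static std::string ReadFile(const std::string& path) {
  std::ifstream fin(path, std::ios::binary);
  std::ostringstream buf;
  buf << fin.rdbuf();
  return buf.str();
}

static void WriteFile(const std::string& path, const std::string& data) {
  std::ofstream fout(path, std::ios::binary);
  fout.write(data.data(), static_cast<std::streamsize>(data.size()));
}

int main() {
  // Example layout; the real model_path comes from the engine configuration.
  const std::string model_path = "./model_dir/";
  // Key material; how the key is generated is outside this sketch.
  const std::string key = ReadFile(model_path + "key");

  auto cipher = paddle::MakeCipher("");
  // Encrypt the combined-model files that SetProgFile/SetParamsFile
  // would otherwise load as plaintext.
  WriteFile(model_path + "encrypt_model",
            cipher->Encrypt(ReadFile(model_path + "__model__"), key));
  WriteFile(model_path + "encrypt_params",
            cipher->Encrypt(ReadFile(model_path + "__params__"), key));
  return 0;
}

With these files in place and the engine conf's encrypted_model flag enabled (the field implied by the has_encrypted_model() accessor), the new branch loads the model from decrypted in-memory buffers via config.SetModelBuffer instead of from plaintext files on disk.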