Commit ca8a42ad authored by zhangjun

add model encrypt,test=develop

Parent a0ca45bd
@@ -17,12 +17,13 @@
#include <pthread.h>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "core/configure/include/configure_parser.h"
#include "core/configure/inferencer_configure.pb.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/common/utils.h"
#include "core/predictor/framework/infer.h"
#include "paddle_inference_api.h" // NOLINT
namespace baidu {
@@ -37,31 +38,31 @@ using paddle_infer::CreatePredictor;
DECLARE_int32(gpuid);
-const static int max_batch = 32;
-const static int min_subgraph_size = 3;
+static const int max_batch = 32;
+static const int min_subgraph_size = 3;
// Engine Base
class PaddleEngineBase {
public:
virtual ~PaddleEngineBase() {}
virtual std::vector<std::string> GetInputNames() {
-return _predictor -> GetInputNames();
+return _predictor->GetInputNames();
}
virtual std::unique_ptr<Tensor> GetInputHandle(const std::string& name) {
-return _predictor -> GetInputHandle(name);
+return _predictor->GetInputHandle(name);
}
virtual std::vector<std::string> GetOutputNames() {
-return _predictor -> GetOutputNames();
+return _predictor->GetOutputNames();
}
virtual std::unique_ptr<Tensor> GetOutputHandle(const std::string& name) {
-return _predictor -> GetOutputHandle(name);
+return _predictor->GetOutputHandle(name);
}
virtual bool Run() {
-if (!_predictor -> Run()) {
+if (!_predictor->Run()) {
LOG(ERROR) << "Failed call Run with paddle predictor";
return false;
}
@@ -75,8 +76,8 @@ class PaddleEngineBase {
LOG(ERROR) << "origin paddle Predictor is null.";
return -1;
}
-Predictor* prep = static_cast<Predictor*>(predictor);
-_predictor = prep -> Clone();
+Predictor* prep = static_cast<Predictor*>(predictor);
+_predictor = prep->Clone();
if (_predictor.get() == NULL) {
LOG(ERROR) << "fail to clone paddle predictor: " << predictor;
return -1;
@@ -103,8 +104,8 @@ class PaddleInferenceEngine : public PaddleEngineBase {
Config config;
// todo, auto config(zhangjun)
-if(engine_conf.has_combined_model()) {
-if(!engine_conf.combined_model()) {
+if (engine_conf.has_combined_model()) {
+if (!engine_conf.combined_model()) {
config.SetModel(model_path);
} else {
config.SetParamsFile(model_path + "/__params__");
@@ -114,14 +115,14 @@ class PaddleInferenceEngine : public PaddleEngineBase {
config.SetParamsFile(model_path + "/__params__");
config.SetProgFile(model_path + "/__model__");
}
config.SwitchSpecifyInputNames(true);
config.SetCpuMathLibraryNumThreads(1);
if (engine_conf.has_use_gpu() && engine_conf.use_gpu()) {
// 2000MB GPU memory
config.EnableUseGpu(2000, FLAGS_gpuid);
}
if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
config.EnableTensorRtEngine(1 << 20,
max_batch,
@@ -140,19 +141,33 @@ class PaddleInferenceEngine : public PaddleEngineBase {
// 2 MB l3 cache
config.EnableXpu(2 * 1024 * 1024);
}
-if (engine_conf.has_enable_ir_optimization() && !engine_conf.enable_ir_optimization()) {
+if (engine_conf.has_enable_ir_optimization() &&
+    !engine_conf.enable_ir_optimization()) {
config.SwitchIrOptim(false);
} else {
config.SwitchIrOptim(true);
}
-if (engine_conf.has_enable_memory_optimization() && engine_conf.enable_memory_optimization()) {
+if (engine_conf.has_enable_memory_optimization() &&
+    engine_conf.enable_memory_optimization()) {
config.EnableMemoryOptim();
}
-if (false) {
-// todo, encrypt model
-//analysis_config.SetModelBuffer();
+if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) {
+// decrypt the model and params in memory, then hand the plain
+// buffers to the predictor via SetModelBuffer
+std::string model_buffer, params_buffer, key_buffer;
+ReadBinaryFile(model_path + "/encrypt_model", &model_buffer);
+ReadBinaryFile(model_path + "/encrypt_params", &params_buffer);
+ReadBinaryFile(model_path + "/key", &key_buffer);
+auto cipher = paddle::MakeCipher("");
+std::string real_model_buffer = cipher->Decrypt(model_buffer, key_buffer);
+std::string real_params_buffer =
+    cipher->Decrypt(params_buffer, key_buffer);
+config.SetModelBuffer(&real_model_buffer[0],
+                      real_model_buffer.size(),
+                      &real_params_buffer[0],
+                      real_params_buffer.size());
}
predictor::AutoLock lock(predictor::GlobalCreateMutex::instance());
......
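The new branch reads encrypted copies of the model and parameters from disk, decrypts them in memory with paddle::MakeCipher, and hands the plain buffers to the predictor through config.SetModelBuffer, so the decrypted model is never written back to disk. For reference, below is a minimal sketch of the offline counterpart that would produce the encrypt_model, encrypt_params, and key files this engine expects. It is not part of this commit: the file-reading helpers, the command-line interface, and the key handling are illustrative assumptions; only the paddle::MakeCipher("") / Encrypt calls mirror the Decrypt calls above.

// offline_encrypt.cc -- hypothetical helper, not part of this commit.
// Encrypts __model__/__params__ with the same default cipher the
// engine above uses for decryption, and writes the three files the
// encrypted_model branch reads: encrypt_model, encrypt_params, key.
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include "paddle_inference_api.h"  // assumed to expose paddle::MakeCipher,
                                   // as the engine code above relies on

// Minimal stand-ins for the serving-side file helpers.
static std::string ReadFile(const std::string& path) {
  std::ifstream fin(path, std::ios::binary);
  std::ostringstream buf;
  buf << fin.rdbuf();
  return buf.str();
}

static void WriteFile(const std::string& path, const std::string& data) {
  std::ofstream fout(path, std::ios::binary);
  fout.write(data.data(), static_cast<std::streamsize>(data.size()));
}

int main(int argc, char** argv) {
  if (argc != 3) {
    std::cerr << "usage: offline_encrypt <model_dir> <key_file>" << std::endl;
    return 1;
  }
  const std::string model_dir = argv[1];
  // The key bytes are produced elsewhere (e.g. by Paddle's key
  // generation tooling); this sketch uses whatever the caller
  // supplies and copies it to <model_dir>/key for the engine to read.
  const std::string key = ReadFile(argv[2]);

  auto cipher = paddle::MakeCipher("");  // same default cipher as the engine
  WriteFile(model_dir + "/encrypt_model",
            cipher->Encrypt(ReadFile(model_dir + "/__model__"), key));
  WriteFile(model_dir + "/encrypt_params",
            cipher->Encrypt(ReadFile(model_dir + "/__params__"), key));
  WriteFile(model_dir + "/key", key);
  return 0;
}

Note that storing the key next to the ciphertext, as both this sketch and the engine's model_path + "/key" lookup do, only obfuscates the model rather than protecting it; a hardened deployment would deliver the key through a separate channel.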