提交 8aa1cc83 编写于 作者: J jack

adapt encryption in linux and windows

上级 2f8c5cd4
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <omp.h> #include <omp.h>
#include <algorithm> #include <algorithm>
#include <fstream>
#include <cstring> #include <cstring>
#include "include/paddlex/paddlex.h" #include "include/paddlex/paddlex.h"
namespace PaddleX { namespace PaddleX {
...@@ -27,17 +28,28 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -27,17 +28,28 @@ void Model::create_predictor(const std::string& model_dir,
std::string model_file = model_dir + OS_PATH_SEP + "__model__"; std::string model_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__"; std::string params_file = model_dir + OS_PATH_SEP + "__params__";
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml"; std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
std::string yaml_input = "";
#ifdef WITH_ENCRYPTION #ifdef WITH_ENCRYPTION
if (key != "") { if (key != "") {
model_file = model_dir + OS_PATH_SEP + "__model__.encrypted"; model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
params_file = model_dir + OS_PATH_SEP + "__params__.encrypted"; params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted"; yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted";
paddle_security_load_model( paddle_security_load_model(
&config, key.c_str(), model_file.c_str(), params_file.c_str()); &config, key.c_str(), model_file.c_str(), params_file.c_str());
yaml_input = decrypt_file(yaml_file.c_str(), key.c_str());
} }
#endif #endif
// 读取配置文件 if (yaml_input == "") {
if (!load_config(yaml_file)) { // 读取配置文件
std::ifstream yaml_fin(yaml_file);
yaml_fin.seekg(0, std::ios::end);
size_t yaml_file_size = yaml_fin.tellg();
yaml_input.assign(yaml_file_size, ' ');
yaml_fin.seekg(0);
yaml_fin.read(&yaml_input[0], yaml_file_size);
}
// 读取配置文件内容
if (!load_config(yaml_input)) {
std::cerr << "Parse file 'model.yml' failed!" << std::endl; std::cerr << "Parse file 'model.yml' failed!" << std::endl;
exit(-1); exit(-1);
} }
...@@ -67,9 +79,8 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -67,9 +79,8 @@ void Model::create_predictor(const std::string& model_dir,
inputs_batch_.assign(batch_size, ImageBlob()); inputs_batch_.assign(batch_size, ImageBlob());
} }
bool Model::load_config(const std::string& yaml_file) { bool Model::load_config(const std::string& yaml_input) {
// std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml"; YAML::Node config = YAML::Load(yaml_input);
YAML::Node config = YAML::LoadFile(yaml_file);
type = config["_Attributes"]["model_type"].as<std::string>(); type = config["_Attributes"]["model_type"].as<std::string>();
name = config["Model"].as<std::string>(); name = config["Model"].as<std::string>();
std::string version = config["version"].as<std::string>(); std::string version = config["version"].as<std::string>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册