diff --git a/deploy/cpp/src/paddlex.cpp b/deploy/cpp/src/paddlex.cpp index 92249006eeb6a50bc27ef0446a1d50b126a9b933..e7fd9402b8ec6daa87dbba701699659a36416cad 100644 --- a/deploy/cpp/src/paddlex.cpp +++ b/deploy/cpp/src/paddlex.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include #include +#include #include #include "include/paddlex/paddlex.h" namespace PaddleX { @@ -27,17 +28,28 @@ void Model::create_predictor(const std::string& model_dir, std::string model_file = model_dir + OS_PATH_SEP + "__model__"; std::string params_file = model_dir + OS_PATH_SEP + "__params__"; std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml"; + std::string yaml_input = ""; #ifdef WITH_ENCRYPTION if (key != "") { model_file = model_dir + OS_PATH_SEP + "__model__.encrypted"; params_file = model_dir + OS_PATH_SEP + "__params__.encrypted"; - std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted"; + yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted"; paddle_security_load_model( &config, key.c_str(), model_file.c_str(), params_file.c_str()); + yaml_input = decrypt_file(yaml_file.c_str(), key.c_str()); } #endif - // 读取配置文件 - if (!load_config(yaml_file)) { + if (yaml_input == "") { + // 读取配置文件 + std::ifstream yaml_fin(yaml_file); + yaml_fin.seekg(0, std::ios::end); + size_t yaml_file_size = yaml_fin.tellg(); + yaml_input.assign(yaml_file_size, ' '); + yaml_fin.seekg(0); + yaml_fin.read(&yaml_input[0], yaml_file_size); + } + // 读取配置文件内容 + if (!load_config(yaml_input)) { std::cerr << "Parse file 'model.yml' failed!" << std::endl; exit(-1); } @@ -67,9 +79,8 @@ void Model::create_predictor(const std::string& model_dir, inputs_batch_.assign(batch_size, ImageBlob()); } -bool Model::load_config(const std::string& yaml_file) { - // std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml"; - YAML::Node config = YAML::LoadFile(yaml_file); +bool Model::load_config(const std::string& yaml_input) { + YAML::Node config = YAML::Load(yaml_input); type = config["_Attributes"]["model_type"].as(); name = config["Model"].as(); std::string version = config["version"].as();