提交 87c688ed 编写于 作者: J jack

add encrypted model loading

上级 7df7269a
...@@ -81,7 +81,7 @@ int main(int argc, char** argv) { ...@@ -81,7 +81,7 @@ int main(int argc, char** argv) {
auto start = system_clock::now(); auto start = system_clock::now();
// 读图像 // 读图像
int im_vec_size = int im_vec_size =
std::min(static_cat<int>(image_paths.size()), i + FLAGS_batch_size); std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i); std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::ClsResult> results(im_vec_size - i, std::vector<PaddleX::ClsResult> results(im_vec_size - i,
PaddleX::ClsResult()); PaddleX::ClsResult());
......
...@@ -95,10 +95,10 @@ class Model { ...@@ -95,10 +95,10 @@ class Model {
* This method aims to load model configurations which include * This method aims to load model configurations which include
* transform steps and label list * transform steps and label list
* *
* @param model_dir: the directory which contains model.yml * @param yaml_file: model configuration
* @return true if load configuration successfully * @return true if load configuration successfully
* */ * */
bool load_config(const std::string& model_dir); bool load_config(const std::string& yaml_file);
/* /*
* @brief * @brief
......
...@@ -23,22 +23,25 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -23,22 +23,25 @@ void Model::create_predictor(const std::string& model_dir,
int gpu_id, int gpu_id,
std::string key, std::string key,
int batch_size) { int batch_size) {
// 读取配置文件
if (!load_config(model_dir)) {
std::cerr << "Parse file 'model.yml' failed!" << std::endl;
exit(-1);
}
paddle::AnalysisConfig config; paddle::AnalysisConfig config;
std::string model_file = model_dir + OS_PATH_SEP + "__model__"; std::string model_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__"; std::string params_file = model_dir + OS_PATH_SEP + "__params__";
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
#ifdef WITH_ENCRYPTION #ifdef WITH_ENCRYPTION
if (key != "") { if (key != "") {
model_file = model_dir + OS_PATH_SEP + "__model__.encrypted"; model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
params_file = model_dir + OS_PATH_SEP + "__params__.encrypted"; params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted";
paddle_security_load_model( paddle_security_load_model(
&config, key.c_str(), model_file.c_str(), params_file.c_str()); &config, key.c_str(), model_file.c_str(), params_file.c_str());
} }
#endif #endif
// 读取配置文件
if (!load_config(yaml_file)) {
std::cerr << "Parse file 'model.yml' failed!" << std::endl;
exit(-1);
}
if (key == "") { if (key == "") {
config.SetModel(model_file, params_file); config.SetModel(model_file, params_file);
} }
...@@ -64,8 +67,8 @@ void Model::create_predictor(const std::string& model_dir, ...@@ -64,8 +67,8 @@ void Model::create_predictor(const std::string& model_dir,
inputs_batch_.assign(batch_size, ImageBlob()); inputs_batch_.assign(batch_size, ImageBlob());
} }
bool Model::load_config(const std::string& model_dir) { bool Model::load_config(const std::string& yaml_file) {
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml"; // std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
YAML::Node config = YAML::LoadFile(yaml_file); YAML::Node config = YAML::LoadFile(yaml_file);
type = config["_Attributes"]["model_type"].as<std::string>(); type = config["_Attributes"]["model_type"].as<std::string>();
name = config["Model"].as<std::string>(); name = config["Model"].as<std::string>();
......
#!/bin/bash #!/bin/bash
set -e # set -e
#
readonly VERSION="3.8" # readonly VERSION="3.8"
#
version=$(clang-format -version) # version=$(clang-format -version)
#
if ! [[ $version == *"$VERSION"* ]]; then # if ! [[ $version == *"$VERSION"* ]]; then
echo "clang-format version check failed." # echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'" # echo "a version contains '$VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env" # echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1 # exit -1
fi # fi
#
clang-format $@ # clang-format $@
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册