提交 87c688ed 编写于 作者: J jack

add encrypted model loading

上级 7df7269a
......@@ -81,7 +81,7 @@ int main(int argc, char** argv) {
auto start = system_clock::now();
// 读图像
int im_vec_size =
std::min(static_cat<int>(image_paths.size()), i + FLAGS_batch_size);
std::min(static_cast<int>(image_paths.size()), i + FLAGS_batch_size);
std::vector<cv::Mat> im_vec(im_vec_size - i);
std::vector<PaddleX::ClsResult> results(im_vec_size - i,
PaddleX::ClsResult());
......
......@@ -95,10 +95,10 @@ class Model {
* This method aims to load model configurations which include
* transform steps and label list
*
* @param model_dir: the directory which contains model.yml
* @param yaml_file: model configuration
* @return true if load configuration successfully
* */
bool load_config(const std::string& model_dir);
bool load_config(const std::string& yaml_file);
/*
* @brief
......
......@@ -23,22 +23,25 @@ void Model::create_predictor(const std::string& model_dir,
int gpu_id,
std::string key,
int batch_size) {
// 读取配置文件
if (!load_config(model_dir)) {
std::cerr << "Parse file 'model.yml' failed!" << std::endl;
exit(-1);
}
paddle::AnalysisConfig config;
std::string model_file = model_dir + OS_PATH_SEP + "__model__";
std::string params_file = model_dir + OS_PATH_SEP + "__params__";
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
#ifdef WITH_ENCRYPTION
if (key != "") {
model_file = model_dir + OS_PATH_SEP + "__model__.encrypted";
params_file = model_dir + OS_PATH_SEP + "__params__.encrypted";
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml.encrypted";
paddle_security_load_model(
&config, key.c_str(), model_file.c_str(), params_file.c_str());
}
#endif
// 读取配置文件
if (!load_config(yaml_file)) {
std::cerr << "Parse file 'model.yml' failed!" << std::endl;
exit(-1);
}
if (key == "") {
config.SetModel(model_file, params_file);
}
......@@ -64,8 +67,8 @@ void Model::create_predictor(const std::string& model_dir,
inputs_batch_.assign(batch_size, ImageBlob());
}
bool Model::load_config(const std::string& model_dir) {
std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
bool Model::load_config(const std::string& yaml_file) {
// std::string yaml_file = model_dir + OS_PATH_SEP + "model.yml";
YAML::Node config = YAML::LoadFile(yaml_file);
type = config["_Attributes"]["model_type"].as<std::string>();
name = config["Model"].as<std::string>();
......
#!/bin/bash
set -e
readonly VERSION="3.8"
version=$(clang-format -version)
if ! [[ $version == *"$VERSION"* ]]; then
echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1
fi
clang-format $@
# set -e
#
# readonly VERSION="3.8"
#
# version=$(clang-format -version)
#
# if ! [[ $version == *"$VERSION"* ]]; then
# echo "clang-format version check failed."
# echo "a version contains '$VERSION' is needed, but get '$version'"
# echo "you can install the right version, and make an soft-link to '\$PATH' env"
# exit -1
# fi
#
# clang-format $@
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册