未验证 提交 1e2d713f 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1354 from HexToString/grpc_update

support 3 types model
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <dirent.h>
#include <pthread.h> #include <pthread.h>
#include <fstream> #include <fstream>
#include <map> #include <map>
...@@ -69,6 +70,30 @@ PrecisionType GetPrecision(const std::string& precision_data) { ...@@ -69,6 +70,30 @@ PrecisionType GetPrecision(const std::string& precision_data) {
return PrecisionType::kFloat32; return PrecisionType::kFloat32;
} }
const std::string getFileBySuffix(
const std::string& path, const std::vector<std::string>& suffixVector) {
DIR* dp = nullptr;
std::string fileName = "";
struct dirent* dirp = nullptr;
if ((dp = opendir(path.c_str())) == nullptr) {
return fileName;
}
while ((dirp = readdir(dp)) != nullptr) {
if (dirp->d_type == DT_REG) {
for (int idx = 0; idx < suffixVector.size(); ++idx) {
if (std::string(dirp->d_name).find(suffixVector[idx]) !=
std::string::npos) {
fileName = static_cast<std::string>(dirp->d_name);
break;
}
}
}
if (fileName.length() != 0) break;
}
closedir(dp);
return fileName;
}
// Engine Base // Engine Base
class EngineCore { class EngineCore {
public: public:
...@@ -131,9 +156,21 @@ class PaddleInferenceEngine : public EngineCore { ...@@ -131,9 +156,21 @@ class PaddleInferenceEngine : public EngineCore {
} }
Config config; Config config;
// todo, auto config(zhangjun) std::vector<std::string> suffixParaVector = {".pdiparams", "__params__"};
if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) { std::vector<std::string> suffixModelVector = {".pdmodel", "__model__"};
std::string paraFileName = getFileBySuffix(model_path, suffixParaVector);
std::string modelFileName = getFileBySuffix(model_path, suffixModelVector);
std::string encryParaPath = model_path + "/encrypt_model";
std::string encryModelPath = model_path + "/encrypt_params";
std::string encryKeyPath = model_path + "/key";
// encrypt model
if (access(encryParaPath.c_str(), F_OK) != -1 &&
access(encryModelPath.c_str(), F_OK) != -1 &&
access(encryKeyPath.c_str(), F_OK) != -1) {
// decrypt model // decrypt model
std::string model_buffer, params_buffer, key_buffer; std::string model_buffer, params_buffer, key_buffer;
predictor::ReadBinaryFile(model_path + "/encrypt_model", &model_buffer); predictor::ReadBinaryFile(model_path + "/encrypt_model", &model_buffer);
predictor::ReadBinaryFile(model_path + "/encrypt_params", &params_buffer); predictor::ReadBinaryFile(model_path + "/encrypt_params", &params_buffer);
...@@ -147,16 +184,11 @@ class PaddleInferenceEngine : public EngineCore { ...@@ -147,16 +184,11 @@ class PaddleInferenceEngine : public EngineCore {
real_model_buffer.size(), real_model_buffer.size(),
&real_params_buffer[0], &real_params_buffer[0],
real_params_buffer.size()); real_params_buffer.size());
} else if (engine_conf.has_combined_model()) { } else if (paraFileName.length() != 0 && modelFileName.length() != 0) {
if (!engine_conf.combined_model()) { config.SetParamsFile(model_path + "/" + paraFileName);
config.SetModel(model_path); config.SetProgFile(model_path + "/" + modelFileName);
} else {
config.SetParamsFile(model_path + "/__params__");
config.SetProgFile(model_path + "/__model__");
}
} else { } else {
config.SetParamsFile(model_path + "/__params__"); config.SetModel(model_path);
config.SetProgFile(model_path + "/__model__");
} }
config.SwitchSpecifyInputNames(true); config.SwitchSpecifyInputNames(true);
......
...@@ -403,7 +403,7 @@ class HttpClient(object): ...@@ -403,7 +403,7 @@ class HttpClient(object):
# 由于输入比较特殊,shape保持原feedvar中不变 # 由于输入比较特殊,shape保持原feedvar中不变
data_value = [] data_value = []
data_value.append(feed_dict[key]) data_value.append(feed_dict[key])
if isinstance(feed_dict[key], str): if isinstance(feed_dict[key], (str, bytes)):
if self.feed_types_[key] != bytes_type: if self.feed_types_[key] != bytes_type:
raise ValueError( raise ValueError(
"feedvar is not string-type,feed can`t be a single string." "feedvar is not string-type,feed can`t be a single string."
...@@ -411,7 +411,7 @@ class HttpClient(object): ...@@ -411,7 +411,7 @@ class HttpClient(object):
else: else:
if self.feed_types_[key] == bytes_type: if self.feed_types_[key] == bytes_type:
raise ValueError( raise ValueError(
"feedvar is string-type,feed, feed can`t be a single int or others." "feedvar is string-type,feed can`t be a single int or others."
) )
# 如果不压缩,那么不需要统计数据量。 # 如果不压缩,那么不需要统计数据量。
if self.try_request_gzip: if self.try_request_gzip:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册