提交 19575971 编写于 作者: H HexToString

support 3 types and fix httpClient

上级 243d5992
......@@ -70,7 +70,7 @@ PrecisionType GetPrecision(const std::string& precision_data) {
return PrecisionType::kFloat32;
}
const std::string& getFileBySuffix(
const std::string getFileBySuffix(
const std::string& path, const std::vector<std::string>& suffixVector) {
DIR* dp = nullptr;
std::string fileName = "";
......@@ -156,9 +156,21 @@ class PaddleInferenceEngine : public EngineCore {
}
Config config;
// todo, auto config(zhangjun)
if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) {
std::vector<std::string> suffixParaVector = {".pdiparams", "__params__"};
std::vector<std::string> suffixModelVector = {".pdmodel", "__model__"};
std::string paraFileName = getFileBySuffix(model_path, suffixParaVector);
std::string modelFileName = getFileBySuffix(model_path, suffixModelVector);
std::string encryParaPath = model_path + "/encrypt_model";
std::string encryModelPath = model_path + "/encrypt_params";
std::string encryKeyPath = model_path + "/key";
// encrypt model
if (access(encryParaPath.c_str(), F_OK) != -1 &&
access(encryModelPath.c_str(), F_OK) != -1 &&
access(encryKeyPath.c_str(), F_OK) != -1) {
// decrypt model
std::string model_buffer, params_buffer, key_buffer;
predictor::ReadBinaryFile(model_path + "/encrypt_model", &model_buffer);
predictor::ReadBinaryFile(model_path + "/encrypt_params", &params_buffer);
......@@ -172,19 +184,11 @@ class PaddleInferenceEngine : public EngineCore {
real_model_buffer.size(),
&real_params_buffer[0],
real_params_buffer.size());
} else if (engine_conf.has_combined_model() &&
(!engine_conf.combined_model())) {
config.SetModel(model_path);
} else if (paraFileName.length() != 0 && modelFileName.length() != 0) {
config.SetParamsFile(model_path + "/" + paraFileName);
config.SetProgFile(model_path + "/" + modelFileName);
} else {
std::vector<std::string> suffixParaVector = {".pdiparams", "__params__"};
std::vector<std::string> suffixModelVector = {".pdmodel", "__model__"};
std::string paraFileName = getFileBySuffix(model_path, suffixParaVector);
std::string modelFileName =
getFileBySuffix(model_path, suffixModelVector);
if (paraFileName.length() != 0 && modelFileName.length() != 0) {
config.SetParamsFile(model_path + "/" + paraFileName);
config.SetProgFile(model_path + "/" + modelFileName);
}
config.SetModel(model_path);
}
config.SwitchSpecifyInputNames(true);
......
......@@ -403,7 +403,7 @@ class HttpClient(object):
# 由于输入比较特殊,shape保持原feedvar中不变
data_value = []
data_value.append(feed_dict[key])
if isinstance(feed_dict[key], str):
if isinstance(feed_dict[key], (str, bytes)):
if self.feed_types_[key] != bytes_type:
raise ValueError(
"feedvar is not string-type,feed can`t be a single string."
......@@ -411,7 +411,7 @@ class HttpClient(object):
else:
if self.feed_types_[key] == bytes_type:
raise ValueError(
"feedvar is string-type,feed, feed can`t be a single int or others."
"feedvar is string-type,feed can`t be a single int or others."
)
# 如果不压缩,那么不需要统计数据量。
if self.try_request_gzip:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册