未验证 提交 1e2d713f 编写于 作者: T TeslaZhao 提交者: GitHub

Merge pull request #1354 from HexToString/grpc_update

support 3 types model
......@@ -14,6 +14,7 @@
#pragma once
#include <dirent.h>
#include <pthread.h>
#include <fstream>
#include <map>
......@@ -69,6 +70,30 @@ PrecisionType GetPrecision(const std::string& precision_data) {
return PrecisionType::kFloat32;
}
const std::string getFileBySuffix(
const std::string& path, const std::vector<std::string>& suffixVector) {
DIR* dp = nullptr;
std::string fileName = "";
struct dirent* dirp = nullptr;
if ((dp = opendir(path.c_str())) == nullptr) {
return fileName;
}
while ((dirp = readdir(dp)) != nullptr) {
if (dirp->d_type == DT_REG) {
for (int idx = 0; idx < suffixVector.size(); ++idx) {
if (std::string(dirp->d_name).find(suffixVector[idx]) !=
std::string::npos) {
fileName = static_cast<std::string>(dirp->d_name);
break;
}
}
}
if (fileName.length() != 0) break;
}
closedir(dp);
return fileName;
}
// Engine Base
class EngineCore {
public:
......@@ -131,9 +156,21 @@ class PaddleInferenceEngine : public EngineCore {
}
Config config;
// todo, auto config(zhangjun)
if (engine_conf.has_encrypted_model() && engine_conf.encrypted_model()) {
std::vector<std::string> suffixParaVector = {".pdiparams", "__params__"};
std::vector<std::string> suffixModelVector = {".pdmodel", "__model__"};
std::string paraFileName = getFileBySuffix(model_path, suffixParaVector);
std::string modelFileName = getFileBySuffix(model_path, suffixModelVector);
std::string encryParaPath = model_path + "/encrypt_model";
std::string encryModelPath = model_path + "/encrypt_params";
std::string encryKeyPath = model_path + "/key";
// encrypt model
if (access(encryParaPath.c_str(), F_OK) != -1 &&
access(encryModelPath.c_str(), F_OK) != -1 &&
access(encryKeyPath.c_str(), F_OK) != -1) {
// decrypt model
std::string model_buffer, params_buffer, key_buffer;
predictor::ReadBinaryFile(model_path + "/encrypt_model", &model_buffer);
predictor::ReadBinaryFile(model_path + "/encrypt_params", &params_buffer);
......@@ -147,16 +184,11 @@ class PaddleInferenceEngine : public EngineCore {
real_model_buffer.size(),
&real_params_buffer[0],
real_params_buffer.size());
} else if (engine_conf.has_combined_model()) {
if (!engine_conf.combined_model()) {
config.SetModel(model_path);
} else if (paraFileName.length() != 0 && modelFileName.length() != 0) {
config.SetParamsFile(model_path + "/" + paraFileName);
config.SetProgFile(model_path + "/" + modelFileName);
} else {
config.SetParamsFile(model_path + "/__params__");
config.SetProgFile(model_path + "/__model__");
}
} else {
config.SetParamsFile(model_path + "/__params__");
config.SetProgFile(model_path + "/__model__");
config.SetModel(model_path);
}
config.SwitchSpecifyInputNames(true);
......
......@@ -403,7 +403,7 @@ class HttpClient(object):
# 由于输入比较特殊,shape保持原feedvar中不变
data_value = []
data_value.append(feed_dict[key])
if isinstance(feed_dict[key], str):
if isinstance(feed_dict[key], (str, bytes)):
if self.feed_types_[key] != bytes_type:
raise ValueError(
"feedvar is not string-type,feed can`t be a single string."
......@@ -411,7 +411,7 @@ class HttpClient(object):
else:
if self.feed_types_[key] == bytes_type:
raise ValueError(
"feedvar is string-type,feed, feed can`t be a single int or others."
"feedvar is string-type,feed can`t be a single int or others."
)
# 如果不压缩,那么不需要统计数据量。
if self.try_request_gzip:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册