Unverified commit 5e0a8ecd, authored by Jiawei Wang, committed by GitHub

Merge pull request #1359 from HexToString/grpc_update

Add handling for an empty Fetch list
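The practical upshot: `fetch` may now be omitted. When a request carries no fetch_var_names, GeneralResponseOp falls back to every FetchVar in the model config, and the client recovers the output names from the tensor alias names carried in the response. A minimal, hedged sketch of the new calling pattern (the config path, endpoint, and feed name `x` are placeholders, not values from this PR):

```python
# Sketch: predicting without a fetch list after this change.
# Assumes a model is already being served; all names are illustrative.
import numpy as np
from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")
client.connect(["127.0.0.1:9393"])

# fetch=None is now accepted: the server returns all FetchVars and the
# client fills in the names via get_tensor_alias_names().
result = client.predict(
    feed={"x": np.ones((1, 13), dtype="float32")}, fetch=None, batch=True)
print(result.keys())  # every fetch var of the model
```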
@@ -53,6 +53,9 @@ class ModelRes {
res._int32_value_map.end());
_shape_map.insert(res._shape_map.begin(), res._shape_map.end());
_lod_map.insert(res._lod_map.begin(), res._lod_map.end());
_tensor_alias_names.insert(_tensor_alias_names.end(),
res._tensor_alias_names.begin(),
res._tensor_alias_names.end());
}
ModelRes(ModelRes&& res) {
_engine_name = std::move(res._engine_name);
@@ -69,6 +72,10 @@ class ModelRes {
std::make_move_iterator(std::end(res._shape_map)));
_lod_map.insert(std::make_move_iterator(std::begin(res._lod_map)),
std::make_move_iterator(std::end(res._lod_map)));
_tensor_alias_names.insert(
_tensor_alias_names.end(),
std::make_move_iterator(std::begin(res._tensor_alias_names)),
std::make_move_iterator(std::end(res._tensor_alias_names)));
}
~ModelRes() {}
const std::vector<int64_t>& get_int64_by_name(const std::string& name) {
@@ -105,6 +112,10 @@ class ModelRes {
_engine_name = engine_name;
}
const std::string& engine_name() { return _engine_name; }
const std::vector<std::string>& tensor_alias_names() {
return _tensor_alias_names;
}
ModelRes& operator=(ModelRes&& res) {
if (this != &res) {
_engine_name = std::move(res._engine_name);
@@ -121,6 +132,10 @@ class ModelRes {
std::make_move_iterator(std::end(res._shape_map)));
_lod_map.insert(std::make_move_iterator(std::begin(res._lod_map)),
std::make_move_iterator(std::end(res._lod_map)));
_tensor_alias_names.insert(
_tensor_alias_names.end(),
std::make_move_iterator(std::begin(res._tensor_alias_names)),
std::make_move_iterator(std::end(res._tensor_alias_names)));
}
return *this;
}
@@ -132,6 +147,7 @@ class ModelRes {
std::map<std::string, std::vector<int32_t>> _int32_value_map;
std::map<std::string, std::vector<int>> _shape_map;
std::map<std::string, std::vector<int>> _lod_map;
std::vector<std::string> _tensor_alias_names;
};
class PredictorRes {
@@ -193,11 +209,16 @@ class PredictorRes {
}
const std::string& variant_tag() { return _variant_tag; }
const std::vector<std::string>& get_engine_names() { return _engine_names; }
const std::vector<std::string>& get_tensor_alias_names(const int model_idx) {
_tensor_alias_names = _models[model_idx].tensor_alias_names();
return _tensor_alias_names;
}
private:
std::vector<ModelRes> _models;
std::string _variant_tag;
std::vector<std::string> _engine_names;
std::vector<std::string> _tensor_alias_names;
};
class PredictorClient {
......
@@ -168,8 +168,6 @@ int PredictorClient::numpy_predict(
Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS();
int fetch_name_num = fetch_name.size();
_api.thrd_initialize();
std::string variant_tag;
_predictor = _api.fetch_predictor("general_model", &variant_tag);
@@ -329,10 +327,12 @@ int PredictorClient::numpy_predict(
auto output = res.outputs(m_idx);
ModelRes model;
model.set_engine_name(output.engine_name());
int idx = 0;
for (auto &name : fetch_name) {
// At the ResponseOp, the outputs have already been arranged according to fetch_name,
// so they correspond strictly to fetch_name and can be processed in order
// (see the sketch after this hunk).
for (int idx = 0; idx < output.tensor_size(); ++idx) {
// int idx = _fetch_name_to_idx[name];
const std::string name = output.tensor(idx).alias_name();
model._tensor_alias_names.push_back(name);
int shape_size = output.tensor(idx).shape_size();
VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
<< shape_size;
@@ -347,13 +347,7 @@ int PredictorClient::numpy_predict(
model._lod_map[name][i] = output.tensor(idx).lod(i);
}
}
idx += 1;
}
idx = 0;
for (auto &name : fetch_name) {
// int idx = _fetch_name_to_idx[name];
if (_fetch_name_to_type[name] == P_INT64) {
VLOG(2) << "ferch var " << name << "type int64";
int size = output.tensor(idx).int64_data_size();
......@@ -373,7 +367,6 @@ int PredictorClient::numpy_predict(
output.tensor(idx).int_data().begin(),
output.tensor(idx).int_data().begin() + size);
}
idx += 1;
}
predict_res_batch.add_model_res(std::move(model));
}
......
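The comments in this hunk carry the invariant that makes the rewrite safe: GeneralResponseOp emits tensors in fetch order, so numpy_predict can walk output.tensor by index instead of looking each fetch name up. A Python analogue of the new loop, as a sketch (`output` stands for one deserialized model-output protobuf message; the field names follow the tensor fields used above):

```python
def parse_model_output(output):
    """Order-based parsing sketch mirroring the C++ loop above."""
    parsed = {}
    for idx in range(len(output.tensor)):
        t = output.tensor[idx]
        # The alias name travels in the response, so no fetch list is needed.
        parsed[t.alias_name] = {
            "shape": list(t.shape),
            "lod": list(t.lod),  # LoD for variable-length outputs
        }
    return parsed
```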
@@ -69,7 +69,10 @@ PYBIND11_MODULE(serving_client, m) {
})
.def("variant_tag", [](PredictorRes &self) { return self.variant_tag(); })
.def("get_engine_names",
[](PredictorRes &self) { return self.get_engine_names(); });
[](PredictorRes &self) { return self.get_engine_names(); })
.def("get_tensor_alias_names", [](PredictorRes &self, int model_idx) {
return self.get_tensor_alias_names(model_idx);
});
py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
.def(py::init())
......
@@ -74,10 +74,19 @@ int GeneralResponseOp::inference() {
// and the order of Output is the same as the prototxt FetchVar.
// otherwise, you can only get the Output by the corresponding of
// Name -- Alias_name.
fetch_index.resize(req->fetch_var_names_size());
for (int i = 0; i < req->fetch_var_names_size(); ++i) {
fetch_index[i] =
model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
if (req->fetch_var_names_size() > 0) {
fetch_index.resize(req->fetch_var_names_size());
for (int i = 0; i < req->fetch_var_names_size(); ++i) {
fetch_index[i] =
model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
}
} else {
fetch_index.resize(model_config->_fetch_alias_name.size());
for (int i = 0; i < model_config->_fetch_alias_name.size(); ++i) {
fetch_index[i] =
model_config
->_fetch_alias_name_to_index[model_config->_fetch_alias_name[i]];
}
}
for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
@@ -105,7 +114,7 @@ int GeneralResponseOp::inference() {
// fetch_index is the real index in FetchVar of Fetchlist
// for example, FetchVar = {0:A, 1:B, 2:C}
// FetchList = {0:C,1:A}, at this situation.
// fetch_index = [2,0], C's index = 2 and A's index = 0
// (worked through in the example after this hunk)
for (auto &idx : fetch_index) {
Tensor *tensor = output->add_tensor();
tensor->set_name(in->at(idx).name);
......
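To make the index bookkeeping concrete, here is the worked example from the comment above plus the new empty-fetch fallback, as a small Python sketch (the plain dict stands in for _fetch_alias_name_to_index):

```python
# FetchVar order from the model config: {0: A, 1: B, 2: C}
fetch_alias_name_to_index = {"A": 0, "B": 1, "C": 2}

# Explicit fetch list, as before: FetchList = {0: C, 1: A}
fetch_var_names = ["C", "A"]
fetch_index = [fetch_alias_name_to_index[n] for n in fetch_var_names]
assert fetch_index == [2, 0]  # C's index = 2, A's index = 0

# New in this commit: an empty fetch list selects every FetchVar
# in config order.
fetch_var_names = []
if not fetch_var_names:
    fetch_index = [fetch_alias_name_to_index[n]
                   for n in fetch_alias_name_to_index]
assert fetch_index == [0, 1, 2]
```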
@@ -289,16 +289,18 @@ class Client(object):
log_id=0):
self.profile_.record('py_prepro_0')
if feed is None or fetch is None:
raise ValueError("You should specify feed and fetch for prediction")
if feed is None:
raise ValueError("You should specify feed for prediction")
fetch_list = []
if isinstance(fetch, str):
fetch_list = [fetch]
elif isinstance(fetch, list):
fetch_list = fetch
elif fetch is None:
pass
else:
raise ValueError("Fetch only accepts string and list of string")
raise ValueError("Fetch only accepts string or list of string")
feed_batch = []
if isinstance(feed, dict):
@@ -339,16 +341,13 @@ class Client(object):
string_feed_names = []
string_lod_slot_batch = []
string_shape = []
fetch_names = []
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
if len(fetch_names) == 0:
raise ValueError(
"Fetch names should not be empty or out of saved fetch list.")
feed_dict = feed_batch[0]
for key in feed_dict:
if ".lod" not in key and key not in self.feed_names_:
@@ -443,6 +442,8 @@ class Client(object):
model_engine_names = result_batch_handle.get_engine_names()
for mi, engine_name in enumerate(model_engine_names):
result_map = {}
if len(fetch_names) == 0:
fetch_names = result_batch_handle.get_tensor_alias_names(mi)
# result map needs to be a numpy array
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int64_type:
......
@@ -93,9 +93,13 @@ class HttpClient(object):
self.try_request_gzip = False
self.try_response_gzip = False
self.total_data_number = 0
self.headers = {}
self.http_proto = True
self.headers["Content-Type"] = "application/proto"
self.max_body_size = 512 * 1024 * 1024
self.use_grpc_client = False
self.url = None
# a connection pool avoids re-establishing the connection for every request
self.requests_session = requests.session()
# initialize the grpc_stub
@@ -190,8 +194,21 @@ class HttpClient(object):
def set_port(self, port):
self.port = port
self.server_port = port
self.init_grpc_stub()
def set_url(self, url):
if isinstance(url, str):
self.url = url
else:
print("url must be str")
def add_http_headers(self, headers):
if isinstance(headers, dict):
self.headers.update(headers)
else:
print("headers must be a dict")
def set_request_compress(self, try_request_gzip):
self.try_request_gzip = try_request_gzip
@@ -200,6 +217,10 @@ class HttpClient(object):
def set_http_proto(self, http_proto):
self.http_proto = http_proto
if self.http_proto:
self.headers["Content-Type"] = "application/proto"
else:
self.headers["Content-Type"] = "application/json"
def set_use_grpc_client(self, use_grpc_client):
self.use_grpc_client = use_grpc_client
@@ -232,31 +253,26 @@ class HttpClient(object):
return self.fetch_names_
def get_legal_fetch(self, fetch):
if fetch is None:
raise ValueError("You should specify feed and fetch for prediction")
fetch_list = []
if isinstance(fetch, str):
fetch_list = [fetch]
elif isinstance(fetch, (list, tuple)):
fetch_list = fetch
elif fetch is None:
pass
else:
raise ValueError("Fetch only accepts string and list of string")
raise ValueError("Fetch only accepts string/list/tuple of string")
fetch_names = []
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
if len(fetch_names) == 0:
raise ValueError(
"Fetch names should not be empty or out of saved fetch list.")
return {}
return fetch_names
def get_feedvar_dict(self, feed):
if feed is None:
raise ValueError("You should specify feed and fetch for prediction")
raise ValueError("You should specify feed for prediction")
feed_dict = {}
if isinstance(feed, dict):
feed_dict = feed
@@ -402,17 +418,19 @@ class HttpClient(object):
# for now, normalize it into a single-element list
# since this input is special, the shape stays as defined in the original feedvar
data_value = []
data_value.append(feed_dict[key])
if isinstance(feed_dict[key], (str, bytes)):
if self.feed_types_[key] != bytes_type:
raise ValueError(
"feedvar is not string-type,feed can`t be a single string."
)
if isinstance(feed_dict[key], bytes):
feed_dict[key] = feed_dict[key].decode()
else:
if self.feed_types_[key] == bytes_type:
raise ValueError(
"feedvar is string-type,feed can`t be a single int or others."
)
data_value.append(feed_dict[key])
# if compression is disabled, there is no need to track the data size
if self.try_request_gzip:
self.total_data_number = self.total_data_number + data_bytes_number(
@@ -453,20 +471,25 @@ class HttpClient(object):
feed_dict = self.get_feedvar_dict(feed)
fetch_list = self.get_legal_fetch(fetch)
headers = {}
postData = ''
if self.http_proto:
postData = self.process_proto_data(feed_dict, fetch_list, batch,
log_id).SerializeToString()
headers["Content-Type"] = "application/proto"
else:
postData = self.process_json_data(feed_dict, fetch_list, batch,
log_id)
headers["Content-Type"] = "application/json"
web_url = "http://" + self.ip + ":" + self.server_port + self.service_name
if self.url is not None:
    if "http" not in self.url:
        self.url = "http://" + self.url
    if self.service_name not in self.url:
        self.url = self.url + self.service_name
web_url = self.url
# compress only when the payload is larger than 512 bytes
self.headers.pop("Content-Encoding", "nokey")
try:
if self.try_request_gzip and self.total_data_number > 512:
@@ -474,20 +497,21 @@
postData = gzip.compress(postData)
else:
postData = gzip.compress(bytes(postData, 'utf-8'))
headers["Content-Encoding"] = "gzip"
self.headers["Content-Encoding"] = "gzip"
if self.try_response_gzip:
headers["Accept-encoding"] = "gzip"
self.headers["Accept-encoding"] = "gzip"
# if compression fails, fall back to the uncompressed data
except:
print("compress error, we will use the no-compress data")
headers.pop("Content-Encoding", "nokey")
self.headers.pop("Content-Encoding", "nokey")
# requests transparently detects and decompresses gzip responses
try:
result = self.requests_session.post(
url=web_url,
headers=headers,
headers=self.headers,
data=postData,
timeout=self.timeout_ms / 1000)
timeout=self.timeout_ms / 1000,
verify=False)
result.raise_for_status()
except:
print("http post error")
......
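Taken together, the HttpClient changes add a persistent requests session, a settable URL, user-supplied headers, and header state that tracks the proto/json switch and gzip compression. A hedged usage sketch (the import path, endpoint, and feed name are assumptions, not taken from this diff):

```python
# Sketch of the new HttpClient knobs; adjust names to your installation.
import numpy as np
from paddle_serving_client.httpclient import HttpClient  # assumed module path

client = HttpClient()
client.load_client_config("serving_client_conf.prototxt")
client.set_url("127.0.0.1:9393")  # "http://" and the service name are appended
client.add_http_headers({"X-Request-Id": "demo-123"})  # merged into self.headers
client.set_http_proto(False)       # Content-Type becomes application/json
client.set_request_compress(True)  # gzip bodies larger than 512 bytes

result = client.predict(
    feed={"x": np.ones((1, 13), dtype="float32")}, fetch=None, batch=True)
```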