Commit 08f5877f authored by wangjiawei04

adapt grpc to blazeface

Parent f88fc43f
@@ -234,6 +234,7 @@ class PredictorClient {
       const std::vector<std::string>& float_feed_name,
       const std::vector<std::vector<int>>& float_shape,
       const std::vector<std::vector<py::array_t<int64_t>>>& int_feed_batch,
+      const std::vector<std::vector<py::array_t<int64_t>>>& lod_slot_batch,
       const std::vector<std::string>& int_feed_name,
       const std::vector<std::vector<int>>& int_shape,
       const std::vector<std::string>& fetch_name,
...
@@ -352,6 +352,7 @@ int PredictorClient::numpy_predict(
     const std::vector<std::string> &float_feed_name,
     const std::vector<std::vector<int>> &float_shape,
     const std::vector<std::vector<py::array_t<int64_t>>> &int_feed_batch,
+    const std::vector<std::vector<py::array_t<int64_t>>> &lod_slot_batch,
     const std::vector<std::string> &int_feed_name,
     const std::vector<std::vector<int>> &int_shape,
     const std::vector<std::string> &fetch_name,
...
@@ -127,6 +127,7 @@ PYBIND11_MODULE(serving_client, m) {
              const std::vector<std::vector<int>> &float_shape,
              const std::vector<std::vector<py::array_t<int64_t>>>
                  &int_feed_batch,
+             const std::vector<std::vector<py::array_t<int64_t>>>& lod_slot_batch,
              const std::vector<std::string> &int_feed_name,
              const std::vector<std::vector<int>> &int_shape,
              const std::vector<std::string> &fetch_name,
@@ -136,6 +137,7 @@ PYBIND11_MODULE(serving_client, m) {
                  float_feed_name,
                  float_shape,
                  int_feed_batch,
+                 lod_slot_batch,
                  int_feed_name,
                  int_shape,
                  fetch_name,
...
@@ -69,6 +69,17 @@ int conf_check(const Request *req,
   return 0;
 }
+void print_lods(const std::vector<std::vector<size_t>>& lod) {
+  std::cout << "print lod info here" << std::endl;
+  for (size_t i = 0; i < lod.size(); ++i) {
+    std::cout << "the " << i << " th level of lod: " << std::endl;
+    for (size_t j = 0; j < lod[i].size(); ++j) {
+      std::cout << lod[i][j] << " ";
+    }
+    std::cout << std::endl;
+  }
+}
 int GeneralReaderOp::inference() {
   // reade request from client
   const Request *req = dynamic_cast<const Request *>(get_request_message());
@@ -133,13 +144,20 @@ int GeneralReaderOp::inference() {
       elem_size[i] = sizeof(int32_t);
       lod_tensor.dtype = paddle::PaddleDType::INT32;
     }
-    if (model_config->_is_lod_feed[i]) {
-      lod_tensor.lod.resize(1);
-      lod_tensor.lod[0].push_back(0);
-      VLOG(2) << "var[" << i << "] is lod_tensor";
-    } else {
-      lod_tensor.shape.push_back(batch_size);
+    //implement lod tensor here
+    std::cout << "lod size: "<< req->insts(0).tensor_array(i).lod_size() << std::endl;
+    if (req->insts(0).tensor_array(i).lod_size() > 0) {
+      lod_tensor.lod.resize(1);
+      for (int k = 0; k < req->insts(0).tensor_array(i).lod_size(); ++k) {
+        lod_tensor.lod[0].push_back(req->insts(0).tensor_array(i).lod(k));
+      }
+    }
+    //if (model_config->_is_lod_feed[i]) {
+    //  lod_tensor.lod.resize(1);
+    //  lod_tensor.lod[0].push_back(0);
+    //  VLOG(2) << "var[" << i << "] is lod_tensor";
+    //}
+    else {
       capacity[i] = 1;
       for (int k = 0; k < req->insts(0).tensor_array(i).shape_size(); ++k) {
         int dim = req->insts(0).tensor_array(i).shape(k);
@@ -150,6 +168,7 @@ int GeneralReaderOp::inference() {
       VLOG(2) << "var[" << i << "] is tensor, capacity: " << capacity[i];
     }
     lod_tensor.name = model_config->_feed_name[i];
+    print_lods(lod_tensor.lod);
     out->push_back(lod_tensor);
   }
@@ -183,13 +202,13 @@ int GeneralReaderOp::inference() {
        VLOG(2) << "new len: " << cur_len + sample_len;
      }
      out->at(i).data.Resize(tensor_size * elem_size[i]);
-     out->at(i).shape = {out->at(i).lod[0].back()};
+     out->at(i).shape = {};
      for (int j = 1; j < req->insts(0).tensor_array(i).shape_size(); ++j) {
        out->at(i).shape.push_back(req->insts(0).tensor_array(i).shape(j));
      }
-     if (out->at(i).shape.size() == 1) {
-       out->at(i).shape.push_back(1);
-     }
+     //if (out->at(i).shape.size() == 1) {
+     //  out->at(i).shape.push_back(1);
+     //}
      VLOG(2) << "var[" << i
              << "] is lod_tensor and len=" << out->at(i).lod[0].back();
    } else {
@@ -211,11 +230,6 @@ int GeneralReaderOp::inference() {
        for (int k = 0; k < elem_num; ++k) {
          dst_ptr[offset + k] = req->insts(j).tensor_array(i).int64_data(k);
        }
-        if (out->at(i).lod.size() == 1) {
-          offset = out->at(i).lod[0][j + 1];
-        } else {
-          offset += capacity[i];
-        }
      }
    } else if (elem_type[i] == 1) {
      float *dst_ptr = static_cast<float *>(out->at(i).data.data());
@@ -227,11 +241,6 @@ int GeneralReaderOp::inference() {
        for (int k = 0; k < elem_num; ++k) {
          dst_ptr[offset + k] = req->insts(j).tensor_array(i).float_data(k);
        }
-        if (out->at(i).lod.size() == 1) {
-          offset = out->at(i).lod[0][j + 1];
-        } else {
-          offset += capacity[i];
-        }
      }
    } else if (elem_type[i] == 2) {
      int32_t *dst_ptr = static_cast<int32_t *>(out->at(i).data.data());
@@ -243,11 +252,6 @@ int GeneralReaderOp::inference() {
        for (int k = 0; k < elem_num; ++k) {
          dst_ptr[offset + k] = req->insts(j).tensor_array(i).int_data(k);
        }
-        if (out->at(i).lod.size() == 1) {
-          offset = out->at(i).lod[0][j + 1];
-        } else {
-          offset += capacity[i];
-        }
      }
    }
  }
...
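For context, the reader op above now takes the level-0 LoD straight from the request tensor (and drops the per-sample offset bookkeeping), so the client is responsible for sending cumulative offsets alongside the flattened data. A minimal, illustrative sketch of that layout in Python (the sequence lengths and feature width are made up, not from this commit):

    import numpy as np

    # Three variable-length samples of 2, 3 and 1 rows, flattened into one tensor.
    seq_lens = [2, 3, 1]
    lod = [0]
    for n in seq_lens:
        lod.append(lod[-1] + n)          # cumulative offsets: [0, 2, 5, 6]

    data = np.random.rand(sum(seq_lens), 4).astype("float32")  # flattened shape [6, 4]
    # lod marks where each sample starts/ends inside `data`, which is why the
    # op no longer pushes a batch dimension into the tensor shape itself.
    print(lod, data.shape)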
@@ -26,7 +26,7 @@ message Tensor {
   repeated float float_data = 4;
   optional int32 elem_type = 5;
   repeated int32 shape = 6;
-  repeated int32 lod = 7; // only for fetch tensor currently
+  repeated int32 lod = 7;
 };
 message FeedInst { repeated Tensor tensor_array = 1; };
...
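With the "only for fetch tensor currently" restriction removed, a feed-side Tensor can now carry its LoD in the same message, which is what the gRPC client's request packing relies on. A rough sketch of filling one (the import path depends on where the proto stubs are generated; offsets and shapes are placeholders):

    import numpy as np
    import multi_lang_general_model_service_pb2 as pb2  # generated proto module, path assumed

    var = np.random.rand(6, 4).astype("float32")   # flattened variable-length batch
    var_lod = [0, 2, 5, 6]                         # placeholder level-0 offsets

    tensor = pb2.Tensor()
    tensor.float_data.extend(var.reshape(-1).tolist())
    tensor.elem_type = 1                           # float branch, matching elem_type[i] == 1 above
    tensor.shape.extend(list(var.shape))
    tensor.lod.extend(var_lod)                     # new: LoD travels with the feed tensor

    inst = pb2.FeedInst()
    inst.tensor_array.append(tensor)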
@@ -116,7 +116,6 @@ class Debugger(object):
         input_names = self.predictor.get_input_names()
         for name in input_names:
-            print(feed)
             if isinstance(feed[name], list):
                 feed[name] = np.array(feed[name]).reshape(self.feed_shapes_[
                     name])
@@ -131,7 +130,7 @@ class Debugger(object):
             input_tensor = self.predictor.get_input_tensor(name)
             #TODO:set lods
             if "{}.lod".format(name) in feed:
-                input_tensor.set_lod(feed["{}.lod".format(name)])
+                input_tensor.set_lod([feed["{}.lod".format(name)]])
             if batch == True:
                 input_tensor.copy_from_cpu(feed[name][np.newaxis,:])
             else:
...
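On the local-predictor path the same convention applies: a "<name>.lod" entry in the feed dict is now wrapped in an extra list before set_lod, matching Paddle's list-of-levels format. A hedged usage sketch (the model directory, feed/fetch names and the exact Debugger loading API are placeholders that may differ by release):

    import numpy as np
    from paddle_serving_app.local_predict import Debugger  # import path assumed

    debugger = Debugger()
    debugger.load_model_config("serving_server_dir", gpu=False)  # placeholder model dir

    image = np.random.rand(3, 640, 640).astype("float32")        # placeholder input
    fetch_map = debugger.predict(
        feed={"image": image, "image.lod": [0, 1]},               # ".lod" key rides along in the feed
        fetch=["detection_output"])                               # placeholder fetch name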
@@ -14,7 +14,7 @@
 from .chinese_bert_reader import ChineseBertReader
 from .image_reader import ImageReader, File2Image, URL2Image, Sequential, Normalize
 from .image_reader import CenterCrop, Resize, Transpose, Div, RGB2BGR, BGR2RGB, ResizeByFactor
-from .image_reader import RCNNPostprocess, SegPostprocess, PadStride
+from .image_reader import RCNNPostprocess, SegPostprocess, PadStride, BlazeFacePostprocess
 from .image_reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
 from .lac_reader import LACReader
 from .senta_reader import SentaReader
...
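The newly exported BlazeFacePostprocess is what the BlazeFace example is expected to apply to the fetch results. A hedged sketch of wiring it up (it is assumed to follow the RCNNPostprocess-style interface of a label list file plus an output directory; check image_reader.py for the exact signature):

    from paddle_serving_app.reader import BlazeFacePostprocess

    # Assumed constructor, mirroring RCNNPostprocess(label_file, output_dir).
    postprocess = BlazeFacePostprocess("label_list.txt", "output")

    # fetch_map would come from client.predict(); the postprocessor also needs the
    # original image shape info to rescale boxes before writing results to `output`.
    # postprocess(fetch_map)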
@@ -234,6 +234,11 @@ class Client(object):
         pass

     def predict(self, feed=None, fetch=None, need_variant_tag=False):
+        """
+        predict inferface
+        @feed: feed name and its lod
+        """
         self.profile_.record('py_prepro_0')
         if feed is None or fetch is None:
@@ -257,6 +262,7 @@ class Client(object):
         int_slot_batch = []
         float_slot_batch = []
+        lod_slot_batch = []
         int_feed_names = []
         float_feed_names = []
         int_shape = []
@@ -277,9 +283,14 @@ class Client(object):
         for i, feed_i in enumerate(feed_batch):
             int_slot = []
             float_slot = []
+            lod_slot = []
+            #print("feed_i", feed_i)
             for key in feed_i:
-                if key not in self.feed_names_:
+                #print("key", key)
+                if ".lod" not in key and key not in self.feed_names_:
                     raise ValueError("Wrong feed name: {}.".format(key))
+                if ".lod" in key:
+                    continue
                 #if not isinstance(feed_i[key], np.ndarray):
                 self.shape_check(feed_i, key)
                 if self.feed_types_[key] in int_type:
@@ -308,8 +319,14 @@ class Client(object):
                 else:
                     float_slot.append(feed_i[key])
                     self.all_numpy_input = False
+                if ".lod" in key:
+                    lod_slot.append(var.lod)
             int_slot_batch.append(int_slot)
             float_slot_batch.append(float_slot)
+            lod_slot_batch.append(lod_slot)
+        #print("int slot", int_slot_batch)
+        #print("float slot", float_slot_batch)

         self.profile_.record('py_prepro_1')
         self.profile_.record('py_client_infer_0')
@@ -318,7 +335,7 @@ class Client(object):
         if self.all_numpy_input:
             res = self.client_handle_.numpy_predict(
                 float_slot_batch, float_feed_names, float_shape, int_slot_batch,
-                int_feed_names, int_shape, fetch_names, result_batch_handle,
+                int_feed_names, lod_slot_batch, int_shape, fetch_names, result_batch_handle,
                 self.pid)
         elif self.has_numpy_input == False:
             res = self.client_handle_.batch_predict(
@@ -466,7 +483,7 @@ class MultiLangClient(object):
             if var.is_lod_tensor:
                 self.lod_tensor_set_.add(var.alias_name)

-    def _pack_inference_request(self, feed, fetch, is_python):
+    def _pack_inference_request(self, feed, fetch, is_python, batch=False):
         req = multi_lang_general_model_service_pb2.InferenceRequest()
         req.fetch_var_names.extend(fetch)
         req.is_python = is_python
@@ -477,12 +494,17 @@ class MultiLangClient(object):
             feed_batch = feed
         else:
             raise Exception("{} not support".format(type(feed)))
-        req.feed_var_names.extend(feed_batch[0].keys())
+        for x in feed_batch[0].keys():
+            if ".lod" not in x:
+                req.feed_var_names.append(x)
         init_feed_names = False
         for feed_data in feed_batch:
             inst = multi_lang_general_model_service_pb2.FeedInst()
             for name in req.feed_var_names:
                 tensor = multi_lang_general_model_service_pb2.Tensor()
+                if "{}.lod".format(name) in feed_data:
+                    var_lod = feed_data["{}.lod".format(name)]
+                    tensor.lod.extend(var_lod)
                 var = feed_data[name]
                 v_type = self.feed_types_[name]
                 if is_python:
@@ -536,6 +558,8 @@ class MultiLangClient(object):
                         raise Exception("error tensor value type.")
                 else:
                     raise Exception("var must be list or ndarray.")
+                if batch == False:
+                    tensor.shape.append(1)
                 if isinstance(var, np.ndarray):
                     tensor.shape.extend(list(var.shape))
                 else:
@@ -602,12 +626,12 @@ class MultiLangClient(object):
                 fetch,
                 need_variant_tag=False,
                 asyn=False,
-                is_python=True):
+                is_python=True,
+                batch=False):
         if not asyn:
             try:
                 self.profile_.record('py_prepro_0')
-                req = self._pack_inference_request(
-                    feed, fetch, is_python=is_python)
+                req = self._pack_inference_request(feed, fetch, is_python=is_python, batch=batch)
                 self.profile_.record('py_prepro_1')

                 self.profile_.record('py_client_infer_0')
@@ -626,7 +650,7 @@ class MultiLangClient(object):
             except grpc.RpcError as e:
                 return {"serving_status_code": e.code()}
         else:
-            req = self._pack_inference_request(feed, fetch, is_python=is_python)
+            req = self._pack_inference_request(feed, fetch, is_python=is_python, batch=batch)
             call_future = self.stub_.Inference.future(
                 req, timeout=self.rpc_timeout_s_)
             return MultiLangPredictFuture(
...
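Putting the client-side changes together: "<name>.lod" keys are stripped from feed_var_names but packed into the tensor's lod field, and the new batch flag controls whether a leading dimension of 1 is prepended. A minimal usage sketch (endpoint, feed and fetch names are placeholders for a BlazeFace-style deployment):

    import numpy as np
    from paddle_serving_client import MultiLangClient

    client = MultiLangClient()
    client.connect(["127.0.0.1:9393"])               # placeholder gRPC endpoint

    image = np.random.rand(3, 640, 640).astype("float32")   # preprocessed image, placeholder shape
    fetch_map = client.predict(
        feed={"image": image, "image.lod": [0, 1]},  # LoD is passed via the "<name>.lod" key
        fetch=["detection_output"],                  # placeholder fetch name
        batch=False)                                 # batch=False prepends a batch dim of 1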
@@ -533,6 +533,8 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
                     raise Exception("error type.")
                 data.shape = list(feed_inst.tensor_array[idx].shape)
                 feed_dict[name] = data
+                if len(var.lod) > 0:
+                    feed_dict["{}.lod".format(name)] = var.lod
             feed_batch.append(feed_dict)
         return feed_batch, fetch_names, is_python
@@ -569,12 +571,16 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
                     raise Exception("error type.")
                 tensor.shape.extend(list(model_result[name].shape))
                 if name in self.lod_tensor_set_:
-                    tensor.lod.extend(model_result["{}.lod".format(name)]
-                                      .tolist())
+                    tmp_lod = model_result["{}.lod".format(name)]
+                    if isinstance(tmp_lod, list):
+                        tensor.lod.extend(tmp_lod)
+                    else:
+                        tensor.lod.extend(tmp_lod.tolist())
                 inst.tensor_array.append(tensor)
             model_output.insts.append(inst)
             model_output.engine_name = model_name
             resp.outputs.append(model_output)
+        #print("resp", resp)
         return resp

     def SetTimeout(self, request, context):
@@ -587,15 +593,21 @@ class MultiLangServerServiceServicer(multi_lang_general_model_service_pb2_grpc.
         return resp

     def Inference(self, request, context):
-        feed_dict, fetch_names, is_python = self._unpack_inference_request(
-            request)
-        if self.local_predictor == None:
-            ret = self.bclient_.predict(
-                feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
-        else:
-            ret = [self.local_predictor.predict(
-                feed=feed_dict[0], fetch=fetch_names), "VariantTagNeeded"]
-        return self._pack_inference_response(ret, fetch_names, is_python)
+        try:
+            feed_dict, fetch_names, is_python = self._unpack_inference_request(
+                request)
+            if self.local_predictor == None:
+                ret = self.bclient_.predict(
+                    feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
+            else:
+                ret = [self.local_predictor.predict(
+                    feed=feed_dict[0], fetch=fetch_names), "VariantTagNeeded"]
+            #print("ret", ret)
+            res = self._pack_inference_response(ret, fetch_names, is_python)
+            return res
+        except Exception as e:
+            import traceback
+            print(traceback.format_exc())

     def GetClientConfig(self, request, context):
         resp = multi_lang_general_model_service_pb2.GetClientConfigResponse()
...