提交 ca754e46 编写于 作者: B barriery 提交者: GitHub

Merge pull request #601 from barrierye/fix-not-thread-safe-in-client

Fix batch_handle_ in Client not thread safe
...@@ -78,12 +78,18 @@ class ModelRes { ...@@ -78,12 +78,18 @@ class ModelRes {
std::vector<float>&& get_float_by_name_with_rv(const std::string& name) { std::vector<float>&& get_float_by_name_with_rv(const std::string& name) {
return std::move(_float_value_map[name]); return std::move(_float_value_map[name]);
} }
const std::vector<int>& get_shape(const std::string& name) { const std::vector<int>& get_shape_by_name(const std::string& name) {
return _shape_map[name]; return _shape_map[name];
} }
const std::vector<int>& get_lod(const std::string& name) { std::vector<int>&& get_shape_by_name_with_rv(const std::string& name) {
return std::move(_shape_map[name]);
}
const std::vector<int>& get_lod_by_name(const std::string& name) {
return _lod_map[name]; return _lod_map[name];
} }
std::vector<int>&& get_lod_by_name_with_rv(const std::string& name) {
return std::move(_lod_map[name]);
}
void set_engine_name(const std::string& engine_name) { void set_engine_name(const std::string& engine_name) {
_engine_name = engine_name; _engine_name = engine_name;
} }
...@@ -139,13 +145,21 @@ class PredictorRes { ...@@ -139,13 +145,21 @@ class PredictorRes {
const std::string& name) { const std::string& name) {
return std::move(_models[model_idx].get_float_by_name_with_rv(name)); return std::move(_models[model_idx].get_float_by_name_with_rv(name));
} }
const std::vector<int>& get_shape(const int model_idx, const std::vector<int>& get_shape_by_name(const int model_idx,
const std::string& name) { const std::string& name) {
return _models[model_idx].get_shape(name); return _models[model_idx].get_shape_by_name(name);
}
const std::vector<int>&& get_shape_by_name_with_rv(const int model_idx,
const std::string& name) {
return std::move(_models[model_idx].get_shape_by_name_with_rv(name));
} }
const std::vector<int>& get_lod(const int model_idx, const std::vector<int>& get_lod_by_name(const int model_idx,
const std::string& name) { const std::string& name) {
return _models[model_idx].get_lod(name); return _models[model_idx].get_lod_by_name(name);
}
const std::vector<int>&& get_lod_by_name_with_rv(const int model_idx,
const std::string& name) {
return std::move(_models[model_idx].get_lod_by_name_with_rv(name));
} }
void add_model_res(ModelRes&& res) { void add_model_res(ModelRes&& res) {
_engine_names.push_back(res.engine_name()); _engine_names.push_back(res.engine_name());
......
...@@ -51,14 +51,22 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -51,14 +51,22 @@ PYBIND11_MODULE(serving_client, m) {
}) })
.def("get_shape", .def("get_shape",
[](PredictorRes &self, int model_idx, std::string &name) { [](PredictorRes &self, int model_idx, std::string &name) {
return self.get_shape(model_idx, name); std::vector<int> *ptr = new std::vector<int>(
}, std::move(self.get_shape_by_name_with_rv(model_idx, name)));
py::return_value_policy::reference) auto capsule = py::capsule(ptr, [](void *p) {
delete reinterpret_cast<std::vector<int> *>(p);
});
return py::array(ptr->size(), ptr->data(), capsule);
})
.def("get_lod", .def("get_lod",
[](PredictorRes &self, int model_idx, std::string &name) { [](PredictorRes &self, int model_idx, std::string &name) {
return self.get_lod(model_idx, name); std::vector<int> *ptr = new std::vector<int>(
}, std::move(self.get_lod_by_name_with_rv(model_idx, name)));
py::return_value_policy::reference) auto capsule = py::capsule(ptr, [](void *p) {
delete reinterpret_cast<std::vector<int> *>(p);
});
return py::array(ptr->size(), ptr->data(), capsule);
})
.def("variant_tag", [](PredictorRes &self) { return self.variant_tag(); }) .def("variant_tag", [](PredictorRes &self) { return self.variant_tag(); })
.def("get_engine_names", .def("get_engine_names",
[](PredictorRes &self) { return self.get_engine_names(); }); [](PredictorRes &self) { return self.get_engine_names(); });
...@@ -109,7 +117,8 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -109,7 +117,8 @@ PYBIND11_MODULE(serving_client, m) {
fetch_name, fetch_name,
predict_res_batch, predict_res_batch,
pid); pid);
}) },
py::call_guard<py::gil_scoped_release>())
.def("numpy_predict", .def("numpy_predict",
[](PredictorClient &self, [](PredictorClient &self,
const std::vector<std::vector<py::array_t<float>>> const std::vector<std::vector<py::array_t<float>>>
......
...@@ -21,6 +21,7 @@ import google.protobuf.text_format ...@@ -21,6 +21,7 @@ import google.protobuf.text_format
import numpy as np import numpy as np
import time import time
import sys import sys
from .serving_client import PredictorRes
int_type = 0 int_type = 0
float_type = 1 float_type = 1
...@@ -108,7 +109,6 @@ class Client(object): ...@@ -108,7 +109,6 @@ class Client(object):
self.feed_names_ = [] self.feed_names_ = []
self.fetch_names_ = [] self.fetch_names_ = []
self.client_handle_ = None self.client_handle_ = None
self.result_handle_ = None
self.feed_shapes_ = {} self.feed_shapes_ = {}
self.feed_types_ = {} self.feed_types_ = {}
self.feed_names_to_idx_ = {} self.feed_names_to_idx_ = {}
...@@ -122,7 +122,6 @@ class Client(object): ...@@ -122,7 +122,6 @@ class Client(object):
def load_client_config(self, path): def load_client_config(self, path):
from .serving_client import PredictorClient from .serving_client import PredictorClient
from .serving_client import PredictorRes
model_conf = m_config.GeneralModelConfig() model_conf = m_config.GeneralModelConfig()
f = open(path, 'r') f = open(path, 'r')
model_conf = google.protobuf.text_format.Merge( model_conf = google.protobuf.text_format.Merge(
...@@ -132,7 +131,6 @@ class Client(object): ...@@ -132,7 +131,6 @@ class Client(object):
# get feed vars, fetch vars # get feed vars, fetch vars
# get feed shapes, feed types # get feed shapes, feed types
# map feed names to index # map feed names to index
self.result_handle_ = PredictorRes()
self.client_handle_ = PredictorClient() self.client_handle_ = PredictorClient()
self.client_handle_.init(path) self.client_handle_.init(path)
if "FLAGS_max_body_size" not in os.environ: if "FLAGS_max_body_size" not in os.environ:
...@@ -293,15 +291,17 @@ class Client(object): ...@@ -293,15 +291,17 @@ class Client(object):
self.profile_.record('py_prepro_1') self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0') self.profile_.record('py_client_infer_0')
result_batch = self.result_handle_ result_batch_handle = PredictorRes()
if self.all_numpy_input: if self.all_numpy_input:
res = self.client_handle_.numpy_predict( res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid) int_feed_names, int_shape, fetch_names, result_batch_handle,
self.pid)
elif self.has_numpy_input == False: elif self.has_numpy_input == False:
res = self.client_handle_.batch_predict( res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid) int_feed_names, int_shape, fetch_names, result_batch_handle,
self.pid)
else: else:
raise SystemExit( raise SystemExit(
"Please make sure the inputs are all in list type or all in numpy.array type" "Please make sure the inputs are all in list type or all in numpy.array type"
...@@ -314,26 +314,28 @@ class Client(object): ...@@ -314,26 +314,28 @@ class Client(object):
return None return None
multi_result_map = [] multi_result_map = []
model_engine_names = result_batch.get_engine_names() model_engine_names = result_batch_handle.get_engine_names()
for mi, engine_name in enumerate(model_engine_names): for mi, engine_name in enumerate(model_engine_names):
result_map = {} result_map = {}
# result map needs to be a numpy array # result map needs to be a numpy array
for i, name in enumerate(fetch_names): for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type: if self.fetch_names_to_type_[name] == int_type:
# result_map[name] will be py::array(numpy array) # result_map[name] will be py::array(numpy array)
result_map[name] = result_batch.get_int64_by_name(mi, name) result_map[name] = result_batch_handle.get_int64_by_name(
shape = result_batch.get_shape(mi, name) mi, name)
shape = result_batch_handle.get_shape(mi, name)
result_map[name].shape = shape result_map[name].shape = shape
if name in self.lod_tensor_set: if name in self.lod_tensor_set:
result_map["{}.lod".format(name)] = np.array( result_map["{}.lod".format(
result_batch.get_lod(mi, name)) name)] = result_batch_handle.get_lod(mi, name)
elif self.fetch_names_to_type_[name] == float_type: elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(mi, name) result_map[name] = result_batch_handle.get_float_by_name(
shape = result_batch.get_shape(mi, name) mi, name)
shape = result_batch_handle.get_shape(mi, name)
result_map[name].shape = shape result_map[name].shape = shape
if name in self.lod_tensor_set: if name in self.lod_tensor_set:
result_map["{}.lod".format(name)] = np.array( result_map["{}.lod".format(
result_batch.get_lod(mi, name)) name)] = result_batch_handle.get_lod(mi, name)
multi_result_map.append(result_map) multi_result_map.append(result_map)
ret = None ret = None
if len(model_engine_names) == 1: if len(model_engine_names) == 1:
...@@ -351,7 +353,7 @@ class Client(object): ...@@ -351,7 +353,7 @@ class Client(object):
# When using the A/B test, the tag of variant needs to be returned # When using the A/B test, the tag of variant needs to be returned
return ret if not need_variant_tag else [ return ret if not need_variant_tag else [
ret, self.result_handle_.variant_tag() ret, result_batch_handle.variant_tag()
] ]
def release(self): def release(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册