Commit 1e71a345 authored by barrierye

fix: result_handle_ not thread-safe in client

Parent 895a6960
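Background for this change: the old code stored a single PredictorRes in self.result_handle_ and reused it for every predict() call, so concurrent calls on one Client could overwrite each other's results; the diff below builds a local PredictorRes per call instead. A minimal sketch of that failure mode (illustrative only, not the Paddle Serving implementation):

```python
# Minimal sketch (illustrative only, not the Paddle Serving implementation)
# of the race this commit fixes: a result handle stored on the client and
# reused by every predict() call can be overwritten by a concurrent call
# before the first caller reads it back.
class SharedHandleClient:
    def __init__(self):
        self.result_handle_ = {}          # one object shared by all calls

    def predict(self, x):
        self.result_handle_["out"] = x    # thread A writes here ...
        # ... thread B's predict() may overwrite "out" at this point
        return self.result_handle_["out"]

class LocalHandleClient:
    def predict(self, x):
        result_handle = {"out": x}        # fresh object per call, no sharing
        return result_handle["out"]
```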
@@ -21,6 +21,7 @@ import google.protobuf.text_format
import numpy as np
import time
import sys
from .serving_client import PredictorRes
int_type = 0
float_type = 1
@@ -108,7 +109,6 @@ class Client(object):
self.feed_names_ = []
self.fetch_names_ = []
self.client_handle_ = None
self.result_handle_ = None
self.feed_shapes_ = {}
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
@@ -122,7 +122,6 @@ class Client(object):
def load_client_config(self, path):
from .serving_client import PredictorClient
from .serving_client import PredictorRes
model_conf = m_config.GeneralModelConfig()
f = open(path, 'r')
model_conf = google.protobuf.text_format.Merge(
@@ -132,7 +131,6 @@ class Client(object):
# get feed vars, fetch vars
# get feed shapes, feed types
# map feed names to index
self.result_handle_ = PredictorRes()
self.client_handle_ = PredictorClient()
self.client_handle_.init(path)
if "FLAGS_max_body_size" not in os.environ:
@@ -302,15 +300,17 @@ class Client(object):
self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0')
result_batch = self.result_handle_
result_batch_handle = PredictorRes()
if self.all_numpy_input:
res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid)
int_feed_names, int_shape, fetch_names, result_batch_handle,
self.pid)
elif self.has_numpy_input == False:
res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid)
int_feed_names, int_shape, fetch_names, result_batch_handle,
self.pid)
else:
raise SystemExit(
"Please make sure the inputs are all in list type or all in numpy.array type"
@@ -323,26 +323,28 @@ class Client(object):
return None
multi_result_map = []
model_engine_names = result_batch.get_engine_names()
model_engine_names = result_batch_handle.get_engine_names()
for mi, engine_name in enumerate(model_engine_names):
result_map = {}
# result map needs to be a numpy array
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
# result_map[name] will be py::array(numpy array)
result_map[name] = result_batch.get_int64_by_name(mi, name)
shape = result_batch.get_shape(mi, name)
result_map[name] = result_batch_handle.get_int64_by_name(
mi, name)
shape = result_batch_handle.get_shape(mi, name)
result_map[name].shape = shape
if name in self.lod_tensor_set:
result_map["{}.lod".format(name)] = np.array(
result_batch.get_lod(mi, name))
result_map["{}.lod".format(
name)] = result_batch_handle.get_lod(mi, name)
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(mi, name)
shape = result_batch.get_shape(mi, name)
result_map[name] = result_batch_handle.get_float_by_name(
mi, name)
shape = result_batch_handle.get_shape(mi, name)
result_map[name].shape = shape
if name in self.lod_tensor_set:
result_map["{}.lod".format(name)] = np.array(
result_batch.get_lod(mi, name))
result_map["{}.lod".format(
name)] = result_batch_handle.get_lod(mi, name)
multi_result_map.append(result_map)
ret = None
if len(model_engine_names) == 1:
@@ -360,7 +362,7 @@ class Client(object):
# When using the A/B test, the tag of variant needs to be returned
return ret if not need_variant_tag else [
ret, self.result_handle_.variant_tag()
ret, result_batch_handle.variant_tag()
]
def release(self):
......
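With the result handle made per-call, parallel predict() calls on one Client no longer write into one shared PredictorRes. A hedged usage sketch follows; the config path, endpoint, and the "x"/"price" feed/fetch names are placeholders in the style of the common uci_housing example, not values taken from this commit.

```python
# Hedged usage sketch: concurrent predict() calls on a single Client.
# Config path, endpoint, and the "x"/"price" names are illustrative
# placeholders (uci_housing-style), not taken from this commit.
import threading
import numpy as np
from paddle_serving_client import Client

client = Client()
client.load_client_config("uci_housing_client/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9393"])

def worker():
    x = np.random.rand(1, 13).astype("float32")
    # Each call now builds its own PredictorRes internally, so results of
    # parallel calls no longer overwrite each other in a shared handle.
    fetch_map = client.predict(feed={"x": x}, fetch=["price"])
    print(fetch_map["price"].shape)

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```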