提交 1e71a345 编写于 作者: B barrierye

fix: result_handle_ not thread-safe in client

上级 895a6960
...@@ -21,6 +21,7 @@ import google.protobuf.text_format ...@@ -21,6 +21,7 @@ import google.protobuf.text_format
import numpy as np import numpy as np
import time import time
import sys import sys
from .serving_client import PredictorRes
int_type = 0 int_type = 0
float_type = 1 float_type = 1
...@@ -108,7 +109,6 @@ class Client(object): ...@@ -108,7 +109,6 @@ class Client(object):
self.feed_names_ = [] self.feed_names_ = []
self.fetch_names_ = [] self.fetch_names_ = []
self.client_handle_ = None self.client_handle_ = None
self.result_handle_ = None
self.feed_shapes_ = {} self.feed_shapes_ = {}
self.feed_types_ = {} self.feed_types_ = {}
self.feed_names_to_idx_ = {} self.feed_names_to_idx_ = {}
...@@ -122,7 +122,6 @@ class Client(object): ...@@ -122,7 +122,6 @@ class Client(object):
def load_client_config(self, path): def load_client_config(self, path):
from .serving_client import PredictorClient from .serving_client import PredictorClient
from .serving_client import PredictorRes
model_conf = m_config.GeneralModelConfig() model_conf = m_config.GeneralModelConfig()
f = open(path, 'r') f = open(path, 'r')
model_conf = google.protobuf.text_format.Merge( model_conf = google.protobuf.text_format.Merge(
...@@ -132,7 +131,6 @@ class Client(object): ...@@ -132,7 +131,6 @@ class Client(object):
# get feed vars, fetch vars # get feed vars, fetch vars
# get feed shapes, feed types # get feed shapes, feed types
# map feed names to index # map feed names to index
self.result_handle_ = PredictorRes()
self.client_handle_ = PredictorClient() self.client_handle_ = PredictorClient()
self.client_handle_.init(path) self.client_handle_.init(path)
if "FLAGS_max_body_size" not in os.environ: if "FLAGS_max_body_size" not in os.environ:
...@@ -302,15 +300,17 @@ class Client(object): ...@@ -302,15 +300,17 @@ class Client(object):
self.profile_.record('py_prepro_1') self.profile_.record('py_prepro_1')
self.profile_.record('py_client_infer_0') self.profile_.record('py_client_infer_0')
result_batch = self.result_handle_ result_batch_handle = PredictorRes()
if self.all_numpy_input: if self.all_numpy_input:
res = self.client_handle_.numpy_predict( res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid) int_feed_names, int_shape, fetch_names, result_batch_handle,
self.pid)
elif self.has_numpy_input == False: elif self.has_numpy_input == False:
res = self.client_handle_.batch_predict( res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch, self.pid) int_feed_names, int_shape, fetch_names, result_batch_handle,
self.pid)
else: else:
raise SystemExit( raise SystemExit(
"Please make sure the inputs are all in list type or all in numpy.array type" "Please make sure the inputs are all in list type or all in numpy.array type"
...@@ -323,26 +323,28 @@ class Client(object): ...@@ -323,26 +323,28 @@ class Client(object):
return None return None
multi_result_map = [] multi_result_map = []
model_engine_names = result_batch.get_engine_names() model_engine_names = result_batch_handle.get_engine_names()
for mi, engine_name in enumerate(model_engine_names): for mi, engine_name in enumerate(model_engine_names):
result_map = {} result_map = {}
# result map needs to be a numpy array # result map needs to be a numpy array
for i, name in enumerate(fetch_names): for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type: if self.fetch_names_to_type_[name] == int_type:
# result_map[name] will be py::array(numpy array) # result_map[name] will be py::array(numpy array)
result_map[name] = result_batch.get_int64_by_name(mi, name) result_map[name] = result_batch_handle.get_int64_by_name(
shape = result_batch.get_shape(mi, name) mi, name)
shape = result_batch_handle.get_shape(mi, name)
result_map[name].shape = shape result_map[name].shape = shape
if name in self.lod_tensor_set: if name in self.lod_tensor_set:
result_map["{}.lod".format(name)] = np.array( result_map["{}.lod".format(
result_batch.get_lod(mi, name)) name)] = result_batch_handle.get_lod(mi, name)
elif self.fetch_names_to_type_[name] == float_type: elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(mi, name) result_map[name] = result_batch_handle.get_float_by_name(
shape = result_batch.get_shape(mi, name) mi, name)
shape = result_batch_handle.get_shape(mi, name)
result_map[name].shape = shape result_map[name].shape = shape
if name in self.lod_tensor_set: if name in self.lod_tensor_set:
result_map["{}.lod".format(name)] = np.array( result_map["{}.lod".format(
result_batch.get_lod(mi, name)) name)] = result_batch_handle.get_lod(mi, name)
multi_result_map.append(result_map) multi_result_map.append(result_map)
ret = None ret = None
if len(model_engine_names) == 1: if len(model_engine_names) == 1:
...@@ -360,7 +362,7 @@ class Client(object): ...@@ -360,7 +362,7 @@ class Client(object):
# When using the A/B test, the tag of variant needs to be returned # When using the A/B test, the tag of variant needs to be returned
return ret if not need_variant_tag else [ return ret if not need_variant_tag else [
ret, self.result_handle_.variant_tag() ret, result_batch_handle.variant_tag()
] ]
def release(self): def release(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册