未验证 提交 d6e73a73 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #322 from guru4elephant/unify_predict_refined_again

Unify predict refined again
......@@ -79,6 +79,8 @@ class Client(object):
self.feed_names_to_idx_ = {}
self.rpath()
self.pid = os.getpid()
self.producers = []
self.consumer = None
def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__)
......@@ -137,7 +139,6 @@ class Client(object):
predictor_sdk = SDKConfig()
predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc()
print(sdk_desc)
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
))
......@@ -155,44 +156,26 @@ class Client(object):
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
def predict(self, feed={}, fetch=[]):
int_slot = []
float_slot = []
int_feed_names = []
float_feed_names = []
fetch_names = []
for key in feed:
self.shape_check(feed, key)
if key not in self.feed_names_:
continue
if self.feed_types_[key] == int_type:
int_feed_names.append(key)
int_slot.append(feed[key])
elif self.feed_types_[key] == float_type:
float_feed_names.append(key)
float_slot.append(feed[key])
for key in fetch:
if key in self.fetch_names_:
fetch_names.append(key)
def predict(self, feed=None, fetch=None):
if feed is None or fetch is None:
raise ValueError("You should specify feed and fetch for prediction")
fetch_list = []
if isinstance(fetch, str):
fetch_list = [fetch]
elif isinstance(fetch, list):
fetch_list = fetch
else:
raise ValueError("fetch only accepts string and list of string")
feed_batch = []
if isinstance(feed, dict):
feed_batch.append(feed)
elif isinstance(feed, list):
feed_batch = feed
else:
raise ValueError("feed only accepts dict and list of dict")
ret = self.client_handle_.predict(float_slot, float_feed_names,
int_slot, int_feed_names, fetch_names,
self.result_handle_, self.pid)
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = self.result_handle_.get_int64_by_name(name)[
0]
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = self.result_handle_.get_float_by_name(name)[
0]
return result_map
def batch_predict(self, feed_batch=[], fetch=[]):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
......@@ -200,28 +183,33 @@ class Client(object):
fetch_names = []
counter = 0
batch_size = len(feed_batch)
for feed in feed_batch:
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
if len(fetch_names) == 0:
raise ValueError(
"fetch names should not be empty or out of saved fetch list")
return {}
for i, feed_i in enumerate(feed_batch):
int_slot = []
float_slot = []
for key in feed:
for key in feed_i:
if key not in self.feed_names_:
continue
if self.feed_types_[key] == int_type:
if counter == 0:
if i == 0:
int_feed_names.append(key)
int_slot.append(feed[key])
elif self.feed_types_[key] == float_type:
if counter == 0:
if i == 0:
float_feed_names.append(key)
float_slot.append(feed[key])
counter += 1
float_slot.append(feed_i[key])
int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot)
for key in fetch:
if key in self.fetch_names_:
fetch_names.append(key)
result_batch = self.result_handle_
res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
......@@ -240,7 +228,10 @@ class Client(object):
single_result[key] = result_map[key][i]
result_map_batch.append(single_result)
return result_map_batch
if batch_size == 1:
return result_map_batch[0]
else:
return result_map_batch
def release(self):
self.client_handle_.destroy_predictor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册