未验证 提交 d6e73a73 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #322 from guru4elephant/unify_predict_refined_again

Unify predict refined again
...@@ -79,6 +79,8 @@ class Client(object): ...@@ -79,6 +79,8 @@ class Client(object):
self.feed_names_to_idx_ = {} self.feed_names_to_idx_ = {}
self.rpath() self.rpath()
self.pid = os.getpid() self.pid = os.getpid()
self.producers = []
self.consumer = None
def rpath(self): def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__) lib_path = os.path.dirname(paddle_serving_client.__file__)
...@@ -137,7 +139,6 @@ class Client(object): ...@@ -137,7 +139,6 @@ class Client(object):
predictor_sdk = SDKConfig() predictor_sdk = SDKConfig()
predictor_sdk.set_server_endpoints(endpoints) predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc() sdk_desc = predictor_sdk.gen_desc()
print(sdk_desc)
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString( self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
)) ))
...@@ -155,44 +156,26 @@ class Client(object): ...@@ -155,44 +156,26 @@ class Client(object):
raise SystemExit("The shape of feed tensor {} not match.".format( raise SystemExit("The shape of feed tensor {} not match.".format(
key)) key))
def predict(self, feed={}, fetch=[]): def predict(self, feed=None, fetch=None):
int_slot = [] if feed is None or fetch is None:
float_slot = [] raise ValueError("You should specify feed and fetch for prediction")
int_feed_names = []
float_feed_names = [] fetch_list = []
fetch_names = [] if isinstance(fetch, str):
fetch_list = [fetch]
for key in feed: elif isinstance(fetch, list):
self.shape_check(feed, key) fetch_list = fetch
if key not in self.feed_names_: else:
continue raise ValueError("fetch only accepts string and list of string")
if self.feed_types_[key] == int_type:
int_feed_names.append(key) feed_batch = []
int_slot.append(feed[key]) if isinstance(feed, dict):
elif self.feed_types_[key] == float_type: feed_batch.append(feed)
float_feed_names.append(key) elif isinstance(feed, list):
float_slot.append(feed[key]) feed_batch = feed
else:
for key in fetch: raise ValueError("feed only accepts dict and list of dict")
if key in self.fetch_names_:
fetch_names.append(key)
ret = self.client_handle_.predict(float_slot, float_feed_names,
int_slot, int_feed_names, fetch_names,
self.result_handle_, self.pid)
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = self.result_handle_.get_int64_by_name(name)[
0]
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = self.result_handle_.get_float_by_name(name)[
0]
return result_map
def batch_predict(self, feed_batch=[], fetch=[]):
int_slot_batch = [] int_slot_batch = []
float_slot_batch = [] float_slot_batch = []
int_feed_names = [] int_feed_names = []
...@@ -200,28 +183,33 @@ class Client(object): ...@@ -200,28 +183,33 @@ class Client(object):
fetch_names = [] fetch_names = []
counter = 0 counter = 0
batch_size = len(feed_batch) batch_size = len(feed_batch)
for feed in feed_batch:
for key in fetch_list:
if key in self.fetch_names_:
fetch_names.append(key)
if len(fetch_names) == 0:
raise ValueError(
"fetch names should not be empty or out of saved fetch list")
return {}
for i, feed_i in enumerate(feed_batch):
int_slot = [] int_slot = []
float_slot = [] float_slot = []
for key in feed: for key in feed_i:
if key not in self.feed_names_: if key not in self.feed_names_:
continue continue
if self.feed_types_[key] == int_type: if self.feed_types_[key] == int_type:
if counter == 0: if i == 0:
int_feed_names.append(key) int_feed_names.append(key)
int_slot.append(feed[key]) int_slot.append(feed[key])
elif self.feed_types_[key] == float_type: elif self.feed_types_[key] == float_type:
if counter == 0: if i == 0:
float_feed_names.append(key) float_feed_names.append(key)
float_slot.append(feed[key]) float_slot.append(feed_i[key])
counter += 1
int_slot_batch.append(int_slot) int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot) float_slot_batch.append(float_slot)
for key in fetch:
if key in self.fetch_names_:
fetch_names.append(key)
result_batch = self.result_handle_ result_batch = self.result_handle_
res = self.client_handle_.batch_predict( res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, int_slot_batch, int_feed_names, float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
...@@ -240,7 +228,10 @@ class Client(object): ...@@ -240,7 +228,10 @@ class Client(object):
single_result[key] = result_map[key][i] single_result[key] = result_map[key][i]
result_map_batch.append(single_result) result_map_batch.append(single_result)
return result_map_batch if batch_size == 1:
return result_map_batch[0]
else:
return result_map_batch
def release(self): def release(self):
self.client_handle_.destroy_predictor() self.client_handle_.destroy_predictor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册