提交 ff597a1d 编写于 作者: G guru4elephant

make predict return value right

上级 a52ecc75
......@@ -108,10 +108,16 @@ class Client(object):
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.feed_shapes_ = [var.shape for var in model_conf.feed_var]
self.feed_names_to_idx_ = {}
self.fetch_names_to_type_ = {}
self.fetch_names_to_idx_ = {}
for i, var in enumerate(model_conf.feed_var):
self.feed_names_to_idx_[var.alias_name] = i
self.feed_types_[var.alias_name] = var.feed_type
for i, var in enumerate(model_conf.fetch_var):
self.fetch_names_to_idx_[var.alias_name] = i
self.fetch_names_to_type_[var.alias_name] = var.fetch_type
return
def connect(self, endpoints):
......@@ -121,8 +127,8 @@ class Client(object):
predictor_sdk = SDKConfig()
predictor_sdk.set_server_endpoints(endpoints)
sdk_desc = predictor_sdk.gen_desc()
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
))
self.client_handle_.create_predictor_by_desc(
sdk_desc.SerializeToString())
def get_feed_names(self):
......@@ -131,7 +137,7 @@ class Client(object):
def get_fetch_names(self):
return self.fetch_names_
def predict(self, feed={}, fetch=[], profile=False):
def predict(self, feed={}, fetch=[]):
int_slot = []
float_slot = []
int_feed_names = []
......@@ -151,23 +157,19 @@ class Client(object):
if key in self.fetch_names_:
fetch_names.append(key)
'''
result = self.client_handle_.predict(
float_slot, float_feed_names, int_slot, int_feed_names, fetch_names)
'''
ret = self.client_handle_.predict(
float_slot, float_feed_names, int_slot, int_feed_names, fetch_names, self.result_handle_)
float_slot, float_feed_names, int_slot,
int_feed_names, fetch_names, self.result_handle_)
# TODO(guru4elephant): the order of fetch var name should be consistent with
# general_model_config, this is not friendly
# In the future, we need make the number of fetched variable changable
result_map = {}
for i, name in enumerate(fetch_names):
result_map[name] = self.result_handle_.get_float_by_name(name)
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = self.result_handle_.get_int64_by_name(name)[0]
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = self.result_handle_.get_float_by_name(name)[0]
return result_map
def batch_predict(self, feed_batch=[], fetch=[], profile=False):
def batch_predict(self, feed_batch=[], fetch=[]):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册