提交 43d8782c 编写于 作者: B barrierye

update code for multi-respose-op

上级 b1968e92
...@@ -246,26 +246,38 @@ class Client(object): ...@@ -246,26 +246,38 @@ class Client(object):
if res == -1: if res == -1:
return None return None
result_map_batch = [] multi_result_map_batch = []
result_map = {} model_num = result_batch.model_num()
for i, name in enumerate(fetch_names): for i in range(model_num):
if self.fetch_names_to_type_[name] == int_type: result_map_batch = []
result_map[name] = result_batch.get_int64_by_name(name) result_map = {}
elif self.fetch_names_to_type_[name] == float_type: for i, name in enumerate(fetch_names):
result_map[name] = result_batch.get_float_by_name(name) if self.fetch_names_to_type_[name] == int_type:
for i in range(batch_size): result_map[name] = result_batch.get_int64_by_name(i, name)
single_result = {} elif self.fetch_names_to_type_[name] == float_type:
for key in result_map: result_map[name] = result_batch.get_float_by_name(i, name)
single_result[key] = result_map[key][i] for i in range(batch_size):
result_map_batch.append(single_result) single_result = {}
for key in result_map:
if batch_size == 1: single_result[key] = result_map[key][i]
return [result_map_batch[0], self.result_handle_.variant_tag() result_map_batch.append(single_result)
] if need_variant_tag else result_map_batch[0] multi_result_map_batch.append(result_map_batch)
if model_num == 1:
if batch_size == 1:
return [multi_result_map_batch[0][0], self.result_handle_.variant_tag()
] if need_variant_tag else multi_result_map_batch[0][0]
else:
return [multi_result_map_batch[0], self.result_handle_.variant_tag()
] if need_variant_tag else multi_result_map_batch[0]
else: else:
return [result_map_batch, self.result_handle_.variant_tag() if batch_size == 1:
] if need_variant_tag else result_map_batch multi_result_map = [result_map_batch[0] for result_map_batch in multi_result_map_batch]
return [multi_result_map, self.result_handle_.variant_tag()
] if need_variant_tag else multi_result_map
else:
return [multi_result_map_batch, self.result_handle_.variant_tag()
] if need_variant_tag else multi_result_map_batch
def release(self): def release(self):
self.client_handle_.destroy_predictor() self.client_handle_.destroy_predictor()
self.client_handle_ = None self.client_handle_ = None
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册