提交 43d8782c 编写于 作者: B barrierye

update code for multi-respose-op

上级 b1968e92
......@@ -246,26 +246,38 @@ class Client(object):
if res == -1:
return None
result_map_batch = []
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = result_batch.get_int64_by_name(name)
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(name)
for i in range(batch_size):
single_result = {}
for key in result_map:
single_result[key] = result_map[key][i]
result_map_batch.append(single_result)
if batch_size == 1:
return [result_map_batch[0], self.result_handle_.variant_tag()
] if need_variant_tag else result_map_batch[0]
multi_result_map_batch = []
model_num = result_batch.model_num()
for i in range(model_num):
result_map_batch = []
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = result_batch.get_int64_by_name(i, name)
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(i, name)
for i in range(batch_size):
single_result = {}
for key in result_map:
single_result[key] = result_map[key][i]
result_map_batch.append(single_result)
multi_result_map_batch.append(result_map_batch)
if model_num == 1:
if batch_size == 1:
return [multi_result_map_batch[0][0], self.result_handle_.variant_tag()
] if need_variant_tag else multi_result_map_batch[0][0]
else:
return [multi_result_map_batch[0], self.result_handle_.variant_tag()
] if need_variant_tag else multi_result_map_batch[0]
else:
return [result_map_batch, self.result_handle_.variant_tag()
] if need_variant_tag else result_map_batch
if batch_size == 1:
multi_result_map = [result_map_batch[0] for result_map_batch in multi_result_map_batch]
return [multi_result_map, self.result_handle_.variant_tag()
] if need_variant_tag else multi_result_map
else:
return [multi_result_map_batch, self.result_handle_.variant_tag()
] if need_variant_tag else multi_result_map_batch
def release(self):
self.client_handle_.destroy_predictor()
self.client_handle_ = None
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册