未验证 提交 d950d02c 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #303 from MRXLT/general-server-fix-v1

fix bug for batch predict
......@@ -228,16 +228,17 @@ class Client(object):
fetch_names, result_batch, self.pid)
result_map_batch = []
for index in range(batch_size):
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = result_batch.get_int64_by_name(name)[
index]
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(name)[
index]
result_map_batch.append(result_map)
result_map = {}
for i, name in enumerate(fetch_names):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = result_batch.get_int64_by_name(name)
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(name)
for i in range(batch_size):
single_result = {}
for key in result_map:
single_result[key] = result_map[key][i]
result_map_batch.append(single_result)
return result_map_batch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册