未验证 提交 a8e39a97 编写于 作者: J Jiawei Wang 提交者: GitHub

Merge pull request #1008 from TeslaZhao/develop

Fix Bugs for batching predict of pipeline
......@@ -709,11 +709,39 @@ class Op(object):
# combine samples to batch
one_input = preped_data_dict[data_ids[0]]
feed_batch = []
feed_dict = {}
input_offset = None
cur_offset = 0
input_offset_dict = {}
if isinstance(one_input, dict):
# sample input
feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
input_offset = list(range(len(data_ids) + 1))
if len(data_ids) == 1:
feed_batch = [
preped_data_dict[data_id] for data_id in data_ids
]
else:
for data_id in data_ids:
for key, val in preped_data_dict[data_id].items():
has_val = feed_dict.get(key)
if has_val is None:
feed_dict[key] = val
continue
# merge 2 np.arrray
if isinstance(val, np.ndarray):
feed_dict[key] = np.append(
feed_dict[key], val, axis=0)
feed_batch.append(feed_dict)
for data_id in data_ids:
start = cur_offset
for key, val in preped_data_dict[data_id].items():
if isinstance(val, (list, np.ndarray)):
cur_offset += len(val)
else:
cur_offset += 1
break
input_offset_dict[data_id] = [start, cur_offset]
elif isinstance(one_input, list):
# batch input
input_offset = [0]
......@@ -796,8 +824,10 @@ class Op(object):
for name in var_names:
lod_offset_name = "{}.lod".format(name)
if lod_offset_name in var_names:
_LOGGER.debug("(log_id={}) {} {} is LodTensor".format(
typical_logid, op_info_prefix, name))
_LOGGER.debug(
"(log_id={}) {} {} is LodTensor. lod_offset_name:{}".
format(typical_logid, op_info_prefix, name,
lod_offset_name))
lod_var_names.add(name)
lod_offset_names.add(lod_offset_name)
......@@ -812,8 +842,8 @@ class Op(object):
lod_offset_name = "{}.lod".format(name)
lod_offset = midped_batch[lod_offset_name]
for idx, data_id in enumerate(data_ids):
data_offset_left = input_offset[idx]
data_offset_right = input_offset[idx + 1]
data_offset_left = input_offset_dict[data_id][0]
data_offset_right = input_offset_dict[data_id][1]
lod_offset_left = lod_offset[data_offset_left]
lod_offset_right = lod_offset[data_offset_right]
midped_data_dict[data_id][name] = value[
......@@ -823,8 +853,8 @@ class Op(object):
else:
# normal tensor
for idx, data_id in enumerate(data_ids):
left = input_offset[idx]
right = input_offset[idx + 1]
left = input_offset_dict[data_id][0]
right = input_offset_dict[data_id][1]
midped_data_dict[data_id][name] = value[left:right]
else:
midped_data_dict = preped_data_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册