提交 d7b53704 编写于 作者: B barrierye

fix batch input

上级 9f963bb4
......@@ -345,7 +345,29 @@ class Op(object):
" we selected logid={} (from batch: {}) as a "
"representative for logging.".format(
data_id, op_info_prefix, typical_logid, data_ids))
feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
# combine samples to batch
one_input = preped_data_dict[data_ids[0]]
feed_batch = []
input_offset = None
if isinstance(one_input, dict):
# sample input
feed_batch = [preped_data_dict[data_id] for data_id in data_ids]
input_offset = list(range(len(data_ids) + 1))
elif isinstance(one_input, list):
# batch input
input_offset = [0]
for data_id in data_ids:
batch_input = preped_data_dict[data_id]
offset = input_offset[-1] + len(batch_input)
feed_batch += batch_input
input_offset.append(offset)
else:
_LOGGER.critical(
"{} Failed to process: expect input type is dict or listi"
", but get {}".format(op_info_prefix, type(one_input)))
os._exit(-1)
midped_batch = None
ecode = ChannelDataEcode.OK.value
if self._timeout <= 0:
......@@ -414,20 +436,28 @@ class Op(object):
for idx, data_id in enumerate(data_ids):
midped_data_dict[data_id] = {}
for name, value in midped_batch.items():
if name in lod_offset_names:
continue
if name in lod_var_names:
# lodtensor
lod_offset_name = "{}.lod".format(name)
lod_offset = midped_batch[lod_offset_name]
for idx, data_id in enumerate(data_ids):
left = midped_batch[lod_offset_name][idx]
right = midped_batch[lod_offset_name][idx + 1]
midped_data_dict[data_id][name] = value[left:right]
data_offset_left = input_offset[idx]
data_offset_right = input_offset[idx + 1]
lod_offset_left = lod_offset[data_offset_left]
lod_offset_right = lod_offset[data_offset_right]
midped_data_dict[data_id][name] = value[lod_offset_left:lod_offset_right]
midped_data_dict[data_id][lod_offset_name] = \
lod_offset[data_offset_left:data_offset_right + 1] - lod_offset[data_offset_left]
else:
# normal tensor
for idx, data_id in enumerate(data_ids):
midped_data_dict[data_id][name] = value[idx]
left = input_offset[idx]
right = input_offset[idx + 1]
midped_data_dict[data_id][name] = value[left:right]
else:
midped_data_dict = preped_data_dict
_LOGGER.debug("{} Succ process".format(op_info_prefix))
......@@ -772,7 +802,7 @@ class ResponseOp(Op):
feed = channeldata.parse()
# ndarray to string:
# https://stackoverflow.com/questions/30167538/convert-a-numpy-ndarray-to-stringor-bytes-and-convert-it-back-to-numpy-ndarray
np.set_printoptions(threshold=np.nan)
np.set_printoptions(threshold=sys.maxsize)
for name, var in feed.items():
resp.value.append(var.__repr__())
resp.key.append(name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册