提交 e91711da 编写于 作者: B barrierye

fix bug: lodtensor

上级 6cc15d66
......@@ -401,11 +401,33 @@ class Op(object):
data_id=data_id)
else:
# transform np format to dict format
var_names = midped_batch.keys()
lod_var_names = set()
lod_offset_names = set()
for name in var_names:
lod_offset_name = "{}.lod".format(name)
if lod_offset_name in var_names:
_LOGGER.debug("(logid={}) {} {} is LodTensor".format(
typical_logid, op_info_prefix, name))
lod_var_names.add(name)
lod_offset_names.add(lod_offset_name)
for idx, data_id in enumerate(data_ids):
midped_data_dict[data_id] = {
k: v[idx]
for k, v in midped_batch.items()
}
midped_data_dict[data_id] = {}
for name, value in midped_batch.items():
if name in lod_offset_names:
continue
if name in lod_var_names:
# lodtensor
lod_offset_name = "{}.lod".format(name)
for idx, data_id in enumerate(data_ids):
left = midped_batch[lod_offset_name][idx]
right = midped_batch[lod_offset_name][idx + 1]
midped_data_dict[data_id][name] = value[left:right]
else:
for idx, data_id in enumerate(data_ids):
midped_data_dict[data_id][name] = value[idx]
else:
midped_data_dict = preped_data_dict
_LOGGER.debug("{} Succ process".format(op_info_prefix))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册