From e91711da4806661f260fdd8026a5f50d50990a10 Mon Sep 17 00:00:00 2001 From: barrierye Date: Mon, 10 Aug 2020 20:53:17 +0800 Subject: [PATCH] fix bug: lodtensor --- python/pipeline/operator.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index cc2b7637..b9135127 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -401,11 +401,33 @@ class Op(object): data_id=data_id) else: # transform np format to dict format + var_names = midped_batch.keys() + lod_var_names = set() + lod_offset_names = set() + for name in var_names: + lod_offset_name = "{}.lod".format(name) + if lod_offset_name in var_names: + _LOGGER.debug("(logid={}) {} {} is LodTensor".format( + typical_logid, op_info_prefix, name)) + lod_var_names.add(name) + lod_offset_names.add(lod_offset_name) + for idx, data_id in enumerate(data_ids): - midped_data_dict[data_id] = { - k: v[idx] - for k, v in midped_batch.items() - } + midped_data_dict[data_id] = {} + + for name, value in midped_batch.items(): + if name in lod_offset_names: + continue + if name in lod_var_names: + # lodtensor + lod_offset_name = "{}.lod".format(name) + for idx, data_id in enumerate(data_ids): + left = midped_batch[lod_offset_name][idx] + right = midped_batch[lod_offset_name][idx + 1] + midped_data_dict[data_id][name] = value[left:right] + else: + for idx, data_id in enumerate(data_ids): + midped_data_dict[data_id][name] = value[idx] else: midped_data_dict = preped_data_dict _LOGGER.debug("{} Succ process".format(op_info_prefix)) -- GitLab