diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index cc2b7637774ffb774d32e0a3595767849d253318..b91351276489532cefc727888d3b91a7b3a742ce 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -401,11 +401,33 @@ class Op(object): data_id=data_id) else: # transform np format to dict format + var_names = midped_batch.keys() + lod_var_names = set() + lod_offset_names = set() + for name in var_names: + lod_offset_name = "{}.lod".format(name) + if lod_offset_name in var_names: + _LOGGER.debug("(logid={}) {} {} is LodTensor".format( + typical_logid, op_info_prefix, name)) + lod_var_names.add(name) + lod_offset_names.add(lod_offset_name) + for idx, data_id in enumerate(data_ids): - midped_data_dict[data_id] = { - k: v[idx] - for k, v in midped_batch.items() - } + midped_data_dict[data_id] = {} + + for name, value in midped_batch.items(): + if name in lod_offset_names: + continue + if name in lod_var_names: + # lodtensor + lod_offset_name = "{}.lod".format(name) + for idx, data_id in enumerate(data_ids): + left = midped_batch[lod_offset_name][idx] + right = midped_batch[lod_offset_name][idx + 1] + midped_data_dict[data_id][name] = value[left:right] + else: + for idx, data_id in enumerate(data_ids): + midped_data_dict[data_id][name] = value[idx] else: midped_data_dict = preped_data_dict _LOGGER.debug("{} Succ process".format(op_info_prefix))