diff --git a/python/pipeline/operator.py b/python/pipeline/operator.py index dda992c7d8adc6b73cb0d156c4a30a0badcc41b1..7670af3dfe57aa8d8bd4d546631bd5f3a3be1de6 100644 --- a/python/pipeline/operator.py +++ b/python/pipeline/operator.py @@ -709,11 +709,39 @@ class Op(object): # combine samples to batch one_input = preped_data_dict[data_ids[0]] feed_batch = [] + feed_dict = {} input_offset = None + cur_offset = 0 + input_offset_dict = {} + if isinstance(one_input, dict): # sample input - feed_batch = [preped_data_dict[data_id] for data_id in data_ids] - input_offset = list(range(len(data_ids) + 1)) + if len(data_ids) == 1: + feed_batch = [ + preped_data_dict[data_id] for data_id in data_ids + ] + else: + for data_id in data_ids: + for key, val in preped_data_dict[data_id].items(): + has_val = feed_dict.get(key) + if has_val is None: + feed_dict[key] = val + continue + # merge 2 np.arrray + if isinstance(val, np.ndarray): + feed_dict[key] = np.append( + feed_dict[key], val, axis=0) + feed_batch.append(feed_dict) + + for data_id in data_ids: + start = cur_offset + for key, val in preped_data_dict[data_id].items(): + if isinstance(val, (list, np.ndarray)): + cur_offset += len(val) + else: + cur_offset += 1 + break + input_offset_dict[data_id] = [start, cur_offset] elif isinstance(one_input, list): # batch input input_offset = [0] @@ -796,8 +824,10 @@ class Op(object): for name in var_names: lod_offset_name = "{}.lod".format(name) if lod_offset_name in var_names: - _LOGGER.debug("(log_id={}) {} {} is LodTensor".format( - typical_logid, op_info_prefix, name)) + _LOGGER.debug( + "(log_id={}) {} {} is LodTensor. lod_offset_name:{}". + format(typical_logid, op_info_prefix, name, + lod_offset_name)) lod_var_names.add(name) lod_offset_names.add(lod_offset_name) @@ -812,8 +842,8 @@ class Op(object): lod_offset_name = "{}.lod".format(name) lod_offset = midped_batch[lod_offset_name] for idx, data_id in enumerate(data_ids): - data_offset_left = input_offset[idx] - data_offset_right = input_offset[idx + 1] + data_offset_left = input_offset_dict[data_id][0] + data_offset_right = input_offset_dict[data_id][1] lod_offset_left = lod_offset[data_offset_left] lod_offset_right = lod_offset[data_offset_right] midped_data_dict[data_id][name] = value[ @@ -823,8 +853,8 @@ class Op(object): else: # normal tensor for idx, data_id in enumerate(data_ids): - left = input_offset[idx] - right = input_offset[idx + 1] + left = input_offset_dict[data_id][0] + right = input_offset_dict[data_id][1] midped_data_dict[data_id][name] = value[left:right] else: midped_data_dict = preped_data_dict