diff --git a/ppocr/data/imaug/vqa/token/vqa_token_pad.py b/ppocr/data/imaug/vqa/token/vqa_token_pad.py index 6d3cf5e39193dc7cc5cebe2f817d93bef4a50593..8e5a20f95f0159e5c57072dd86eff0f25cf49eac 100644 --- a/ppocr/data/imaug/vqa/token/vqa_token_pad.py +++ b/ppocr/data/imaug/vqa/token/vqa_token_pad.py @@ -94,8 +94,11 @@ class VQATokenPad(object): 'input_ids', 'labels', 'token_type_ids', 'bbox', 'attention_mask' ]: - if self.infer_mode and key == 'labels': - continue - length = min(len(data[key]), self.max_seq_len) - data[key] = np.array(data[key][:length], dtype='int64') + if self.infer_mode: + if key != 'labels': + length = min(len(data[key]), self.max_seq_len) + data[key] = data[key][:length] + else: + continue + data[key] = np.array(data[key], dtype='int64') return data