diff --git a/modules/text/text_generation/ernie_gen/encode.py b/modules/text/text_generation/ernie_gen/encode.py index c4fb9b07c2df62ff4a6a363cc6892c24a6cd8bf6..370ba0004d8158b2357159c9b373caca0c815acd 100644 --- a/modules/text/text_generation/ernie_gen/encode.py +++ b/modules/text/text_generation/ernie_gen/encode.py @@ -54,7 +54,8 @@ def convert_example(tokenizer, else: tgt_labels = tgt_ids - return (src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids, attn_ids, tgt_labels) + return [np.asarray(item, dtype=np.int64) for item \ + in [src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids, attn_ids, tgt_labels]] return warpper