From f6d19d8f09883d1caa02a23ce4895c5733441795 Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Thu, 10 Jun 2021 15:45:25 +0800
Subject: [PATCH] Convert dtype of input tensors to int64 explicitly in plato-mini

---
 modules/text/text_generation/plato-mini/module.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/modules/text/text_generation/plato-mini/module.py b/modules/text/text_generation/plato-mini/module.py
index c25f0990..4a3594ef 100644
--- a/modules/text/text_generation/plato-mini/module.py
+++ b/modules/text/text_generation/plato-mini/module.py
@@ -54,7 +54,7 @@ class UnifiedTransformer(nn.Layer):
         Generate input batches.
         """
         padding = False if batch_size == 1 else True
-        pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False)
+        pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False, dtype=np.int64)

         def pad_mask(batch_attention_mask):
             batch_size = len(batch_attention_mask)
@@ -75,9 +75,9 @@ class UnifiedTransformer(nn.Layer):
             position_ids = pad_func([example['position_ids'] for example in batch_examples])
             attention_mask = pad_mask([example['attention_mask'] for example in batch_examples])
         else:
-            input_ids = np.asarray([example['input_ids'] for example in batch_examples])
-            token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples])
-            position_ids = np.asarray([example['position_ids'] for example in batch_examples])
+            input_ids = np.asarray([example['input_ids'] for example in batch_examples], dtype=np.int64)
+            token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples], dtype=np.int64)
+            position_ids = np.asarray([example['position_ids'] for example in batch_examples], dtype=np.int64)
             attention_mask = np.asarray([example['attention_mask'] for example in batch_examples])
             attention_mask = np.expand_dims(attention_mask, 0)

--
GitLab
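
A minimal sketch of the dtype behavior this patch makes explicit, assuming the Pad collator comes from paddlenlp.data (the patch only shows the call site); the toy_examples sequences below are made up for illustration:

import numpy as np
from paddlenlp.data import Pad

# Hypothetical token-id sequences of unequal length, for illustration only.
toy_examples = [[1, 5, 7, 2], [1, 9, 2]]

# Without an explicit dtype, the padded array inherits numpy's platform
# default integer type (int32 on Windows builds of numpy); passing
# dtype=np.int64 pins the id dtype regardless of platform, which is what
# the patch does for input_ids, token_type_ids and position_ids.
pad_func = Pad(pad_val=0, pad_right=False, dtype=np.int64)
input_ids = pad_func(toy_examples)
print(input_ids.dtype)  # int64, sequences left-padded with pad_val=0

# The batch_size == 1 branch gets the same guarantee via np.asarray.
single_input_ids = np.asarray([[1, 5, 7, 2]], dtype=np.int64)
print(single_input_ids.dtype)  # int64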