From df0fa786f8fa55a46693af54da6b12ef33e4d593 Mon Sep 17 00:00:00 2001 From: KP <109694228@qq.com> Date: Mon, 23 Aug 2021 16:25:08 +0800 Subject: [PATCH] Convert dtype of input tensors to int64 explicitly in unified_transformer (#1582) --- .../unified_transformer-12L-cn-luge/module.py | 8 ++++---- .../text_generation/unified_transformer-12L-cn/module.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/text/text_generation/unified_transformer-12L-cn-luge/module.py b/modules/text/text_generation/unified_transformer-12L-cn-luge/module.py index 52ef5532..115b1e0e 100644 --- a/modules/text/text_generation/unified_transformer-12L-cn-luge/module.py +++ b/modules/text/text_generation/unified_transformer-12L-cn-luge/module.py @@ -54,7 +54,7 @@ class UnifiedTransformer(nn.Layer): Generate input batches. """ padding = False if batch_size == 1 else True - pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False) + pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False, dtype=np.int64) def pad_mask(batch_attention_mask): batch_size = len(batch_attention_mask) @@ -75,9 +75,9 @@ class UnifiedTransformer(nn.Layer): position_ids = pad_func([example['position_ids'] for example in batch_examples]) attention_mask = pad_mask([example['attention_mask'] for example in batch_examples]) else: - input_ids = np.asarray([example['input_ids'] for example in batch_examples]) - token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples]) - position_ids = np.asarray([example['position_ids'] for example in batch_examples]) + input_ids = np.asarray([example['input_ids'] for example in batch_examples], dtype=np.int64) + token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples], dtype=np.int64) + position_ids = np.asarray([example['position_ids'] for example in batch_examples], dtype=np.int64) attention_mask = np.asarray([example['attention_mask'] for example in batch_examples]) attention_mask = np.expand_dims(attention_mask, 0) diff --git a/modules/text/text_generation/unified_transformer-12L-cn/module.py b/modules/text/text_generation/unified_transformer-12L-cn/module.py index ee09a55d..363d15d7 100644 --- a/modules/text/text_generation/unified_transformer-12L-cn/module.py +++ b/modules/text/text_generation/unified_transformer-12L-cn/module.py @@ -54,7 +54,7 @@ class UnifiedTransformer(nn.Layer): Generate input batches. """ padding = False if batch_size == 1 else True - pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False) + pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False, dtype=np.int64) def pad_mask(batch_attention_mask): batch_size = len(batch_attention_mask) @@ -75,9 +75,9 @@ class UnifiedTransformer(nn.Layer): position_ids = pad_func([example['position_ids'] for example in batch_examples]) attention_mask = pad_mask([example['attention_mask'] for example in batch_examples]) else: - input_ids = np.asarray([example['input_ids'] for example in batch_examples]) - token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples]) - position_ids = np.asarray([example['position_ids'] for example in batch_examples]) + input_ids = np.asarray([example['input_ids'] for example in batch_examples], dtype=np.int64) + token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples], dtype=np.int64) + position_ids = np.asarray([example['position_ids'] for example in batch_examples], dtype=np.int64) attention_mask = np.asarray([example['attention_mask'] for example in batch_examples]) attention_mask = np.expand_dims(attention_mask, 0) -- GitLab