未验证 提交 f6d19d8f 编写于 作者: K KP 提交者: GitHub

Convert dtype of input tensors to int64 explicitly in plato-mini

上级 5b5a1ea2
......@@ -54,7 +54,7 @@ class UnifiedTransformer(nn.Layer):
Generate input batches.
"""
padding = False if batch_size == 1 else True
- pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False)
+ pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False, dtype=np.int64)
def pad_mask(batch_attention_mask):
batch_size = len(batch_attention_mask)
......@@ -75,9 +75,9 @@ class UnifiedTransformer(nn.Layer):
position_ids = pad_func([example['position_ids'] for example in batch_examples])
attention_mask = pad_mask([example['attention_mask'] for example in batch_examples])
else:
- input_ids = np.asarray([example['input_ids'] for example in batch_examples])
- token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples])
- position_ids = np.asarray([example['position_ids'] for example in batch_examples])
+ input_ids = np.asarray([example['input_ids'] for example in batch_examples], dtype=np.int64)
+ token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples], dtype=np.int64)
+ position_ids = np.asarray([example['position_ids'] for example in batch_examples], dtype=np.int64)
attention_mask = np.asarray([example['attention_mask'] for example in batch_examples])
attention_mask = np.expand_dims(attention_mask, 0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册