未验证 提交 df0fa786 编写于 作者: K KP 提交者: GitHub

Convert dtype of input tensors to int64 explicitly in unified_transformer (#1582)

上级 71d3583a
@@ -54,7 +54,7 @@ class UnifiedTransformer(nn.Layer):
 Generate input batches.
 """
 padding = False if batch_size == 1 else True
-pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False)
+pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False, dtype=np.int64)
 def pad_mask(batch_attention_mask):
 batch_size = len(batch_attention_mask)
@@ -75,9 +75,9 @@ class UnifiedTransformer(nn.Layer):
 position_ids = pad_func([example['position_ids'] for example in batch_examples])
 attention_mask = pad_mask([example['attention_mask'] for example in batch_examples])
 else:
-input_ids = np.asarray([example['input_ids'] for example in batch_examples])
-token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples])
-position_ids = np.asarray([example['position_ids'] for example in batch_examples])
+input_ids = np.asarray([example['input_ids'] for example in batch_examples], dtype=np.int64)
+token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples], dtype=np.int64)
+position_ids = np.asarray([example['position_ids'] for example in batch_examples], dtype=np.int64)
 attention_mask = np.asarray([example['attention_mask'] for example in batch_examples])
 attention_mask = np.expand_dims(attention_mask, 0)
......
@@ -54,7 +54,7 @@ class UnifiedTransformer(nn.Layer):
 Generate input batches.
 """
 padding = False if batch_size == 1 else True
-pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False)
+pad_func = Pad(pad_val=self.tokenizer.pad_token_id, pad_right=False, dtype=np.int64)
 def pad_mask(batch_attention_mask):
 batch_size = len(batch_attention_mask)
@@ -75,9 +75,9 @@ class UnifiedTransformer(nn.Layer):
 position_ids = pad_func([example['position_ids'] for example in batch_examples])
 attention_mask = pad_mask([example['attention_mask'] for example in batch_examples])
 else:
-input_ids = np.asarray([example['input_ids'] for example in batch_examples])
-token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples])
-position_ids = np.asarray([example['position_ids'] for example in batch_examples])
+input_ids = np.asarray([example['input_ids'] for example in batch_examples], dtype=np.int64)
+token_type_ids = np.asarray([example['token_type_ids'] for example in batch_examples], dtype=np.int64)
+position_ids = np.asarray([example['position_ids'] for example in batch_examples], dtype=np.int64)
 attention_mask = np.asarray([example['attention_mask'] for example in batch_examples])
 attention_mask = np.expand_dims(attention_mask, 0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册