Unverified commit bfa217e4 authored by W WangZhen, committed by GitHub

Add bert prim and cinn test (#50545)

* Add bert prim and cinn test
Parent f6dea800
@@ -9,7 +9,11 @@ foreach(TEST_OP ${TEST_OPS})
endforeach()
set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 400)
set_tests_properties(test_bert_prim_cinn PROPERTIES TIMEOUT 500)
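# When built with CINN, label these tests RUN_TYPE=CINN so CI can select them,
# and keep the dropout op out of CINN compilation for the bert test via
# FLAGS_deny_cinn_ops.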
if(WITH_CINN)
set_tests_properties(test_resnet_prim_cinn PROPERTIES LABELS "RUN_TYPE=CINN")
set_tests_properties(
test_bert_prim_cinn PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT
"FLAGS_deny_cinn_ops=dropout")
endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings
from typing import Optional, Tuple
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle.distributed.fleet.utils import recompute
from paddle.fluid.data_feeder import convert_dtype
from paddle.io import DataLoader, Dataset
from paddle.nn import MultiHeadAttention
try:
from paddle.incubate.nn import FusedTransformerEncoderLayer
except ImportError:
FusedTransformerEncoderLayer = None
VOCAB_SIZE = 30522
class Stack(object):
def __init__(self, axis=0, dtype=None):
self._axis = axis
self._dtype = dtype
def __call__(self, data):
data = (
np.stack(data, axis=self._axis).astype(self._dtype)
if self._dtype
else np.stack(data, axis=self._axis)
)
return data
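# Stack collates a list of per-sample arrays into one batched array, e.g.
# Stack()([np.array([1, 2]), np.array([3, 4])]) -> array of shape (2, 2).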
def is_tensor(x):
if isinstance(x, paddle.Tensor):
return True
return isinstance(x, np.ndarray)
class BertConfig:
def __init__(self):
self.attention_probs_dropout_prob = 0.1
self.fuse = False
self.hidden_act = 'gelu'
self.hidden_dropout_prob = 0.1
# Decrease config to speed up unittest
# self.hidden_size = 768
self.hidden_size = 60
self.initializer_range = 0.02
self.intermediate_size = 3072
self.layer_norm_eps = 1e-12
self.max_position_embeddings = 512
self.model_type = 'bert'
# self.num_attention_heads = 12
self.num_attention_heads = 6
# self.num_hidden_layers = 12
self.num_hidden_layers = 6
self.pad_token_id = 0
self.paddlenlp_version = None
self.pool_act = 'tanh'
self.type_vocab_size = 2
self.vocab_size = VOCAB_SIZE
self.use_return_dict = False
self.output_hidden_states = False
self.output_attentions = False
self.use_cache = False
class BertLMPredictionHead(nn.Layer):
def __init__(self, config: BertConfig, embedding_weights=None):
super(BertLMPredictionHead, self).__init__()
self.transform = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = getattr(nn.functional, config.hidden_act)
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.decoder_weight = (
self.create_parameter(
shape=[config.vocab_size, config.hidden_size],
dtype=self.transform.weight.dtype,
is_bias=False,
)
if embedding_weights is None
else embedding_weights
)
self.decoder_bias = self.create_parameter(
shape=[config.vocab_size],
dtype=self.decoder_weight.dtype,
is_bias=True,
)
def forward(self, hidden_states, masked_positions=None):
if masked_positions is not None:
hidden_states = paddle.reshape(
hidden_states, [-1, hidden_states.shape[-1]]
)
hidden_states = paddle.tensor.gather(
hidden_states, masked_positions
)
# gathering only the masked tokens might be quicker
hidden_states = self.transform(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = (
paddle.tensor.matmul(
hidden_states, self.decoder_weight, transpose_y=True
)
+ self.decoder_bias
)
return hidden_states
class BertPretrainingHeads(nn.Layer):
def __init__(self, config: BertConfig, embedding_weights=None):
super(BertPretrainingHeads, self).__init__()
self.predictions = BertLMPredictionHead(config, embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output, masked_positions=None):
prediction_scores = self.predictions(sequence_output, masked_positions)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class BertEmbeddings(nn.Layer):
def __init__(self, config: BertConfig):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size
)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size
)
self.token_type_embeddings = nn.Embedding(
config.type_vocab_size, config.hidden_size
)
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
input_ids: Tensor,
token_type_ids: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
past_key_values_length: Optional[int] = None,
):
if position_ids is None:
ones = paddle.ones_like(input_ids, dtype="int64")
seq_length = paddle.cumsum(ones, axis=-1)
position_ids = seq_length - ones
if past_key_values_length is not None:
position_ids += past_key_values_length
position_ids.stop_gradient = True
if token_type_ids is None:
token_type_ids = paddle.zeros_like(input_ids, dtype="int64")
input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = (
input_embeddings + position_embeddings + token_type_embeddings
)
embeddings = self.layer_norm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertPooler(nn.Layer):
def __init__(self, config: BertConfig):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
self.pool_act = config.pool_act
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
if self.pool_act == "tanh":
pooled_output = self.activation(pooled_output)
return pooled_output
class BertModel(nn.Layer):
def __init__(self, config: BertConfig):
super(BertModel, self).__init__()
self.config = config
self.pad_token_id = config.pad_token_id
self.initializer_range = config.initializer_range
self.embeddings = BertEmbeddings(config)
if config.fuse and FusedTransformerEncoderLayer is None:
warnings.warn(
"FusedTransformerEncoderLayer is not supported by the running Paddle. "
"The flag fuse_transformer will be ignored. Try Paddle >= 2.3.0"
)
self.fuse = config.fuse and FusedTransformerEncoderLayer is not None
if self.fuse:
self.encoder = nn.LayerList(
[
FusedTransformerEncoderLayer(
config.hidden_size,
config.num_attention_heads,
config.intermediate_size,
dropout_rate=config.hidden_dropout_prob,
activation=config.hidden_act,
attn_dropout_rate=config.attention_probs_dropout_prob,
act_dropout_rate=0.0,
)
for _ in range(config.num_hidden_layers)
]
)
else:
encoder_layer = nn.TransformerEncoderLayer(
config.hidden_size,
config.num_attention_heads,
config.intermediate_size,
dropout=config.hidden_dropout_prob,
activation=config.hidden_act,
attn_dropout=config.attention_probs_dropout_prob,
act_dropout=0,
)
self.encoder = nn.TransformerEncoder(
encoder_layer, config.num_hidden_layers
)
self.pooler = BertPooler(config)
# self.apply(self.init_weights)
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def forward(
self,
input_ids: Tensor,
token_type_ids: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
past_key_values: Optional[Tuple[Tuple[Tensor]]] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
use_cache = (
use_cache if use_cache is not None else self.config.use_cache
)
past_key_values_length = None
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if attention_mask is None:
attention_mask = paddle.unsqueeze(
(input_ids == self.pad_token_id).astype(
self.pooler.dense.weight.dtype
)
* -1e4,
axis=[1, 2],
)
if past_key_values is not None:
batch_size = past_key_values[0][0].shape[0]
past_mask = paddle.zeros(
[batch_size, 1, 1, past_key_values_length],
dtype=attention_mask.dtype,
)
attention_mask = paddle.concat(
[past_mask, attention_mask], axis=-1
)
else:
if attention_mask.ndim == 2:
# attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length]
attention_mask = attention_mask.unsqueeze(axis=[1, 2]).astype(
paddle.get_default_dtype()
)
attention_mask = (1.0 - attention_mask) * -1e4
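# After conversion, attended positions add 0.0 and masked positions add
# -1e4 to the attention logits.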
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
past_key_values_length=past_key_values_length,
)
if self.fuse:
assert (
not output_attentions
), "Attention outputs are not supported currently."
assert (
past_key_values is None
), "past_key_values is not supported currently."
hidden_states = embedding_output
all_hidden_states = [] if output_hidden_states else None
for layer in self.encoder:
hidden_states = layer(hidden_states, attention_mask)
if output_hidden_states:
all_hidden_states.append(hidden_states)
pooled_output = self.pooler(hidden_states)
return (
(hidden_states, pooled_output, all_hidden_states)
if output_hidden_states
else (hidden_states, pooled_output)
)
else:
self.encoder._use_cache = use_cache # To be consistent with HF
encoder_outputs = self.encoder(
embedding_output,
src_mask=attention_mask,
cache=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if isinstance(encoder_outputs, type(embedding_output)):
sequence_output = encoder_outputs
pooled_output = self.pooler(sequence_output)
return (sequence_output, pooled_output)
else:
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
return (sequence_output, pooled_output) + encoder_outputs[1:]
class Bert(nn.Layer):
def __init__(self):
super(Bert, self).__init__()
config = BertConfig()
self.bert = BertModel(config)
self.cls = BertPretrainingHeads(
config,
embedding_weights=self.bert.embeddings.word_embeddings.weight,
)
# self.apply(self.init_weights)
def forward(
self,
input_ids: Tensor,
token_type_ids: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
masked_positions: Optional[Tensor] = None,
labels: Optional[Tensor] = None,
next_sentence_label: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
with paddle.static.amp.fp16_guard():
outputs = self.bert(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(
sequence_output, pooled_output, masked_positions
)
total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = paddle.nn.CrossEntropyLoss()
masked_lm_loss = loss_fct(
prediction_scores.reshape(
(-1, prediction_scores.shape[-1])
),
labels.reshape((-1,)),
)
next_sentence_loss = loss_fct(
seq_relationship_score.reshape((-1, 2)),
next_sentence_label.reshape((-1,)),
)
total_loss = masked_lm_loss + next_sentence_loss
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return (
((total_loss,) + output) if total_loss is not None else output
)
class BertPretrainingCriterion(paddle.nn.Layer):
def __init__(self, vocab_size=VOCAB_SIZE):
super(BertPretrainingCriterion, self).__init__()
# CrossEntropyLoss is expensive because of its internal reshape (copy)
self.loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=-1)
self.vocab_size = vocab_size
def forward(
self,
prediction_scores,
seq_relationship_score,
masked_lm_labels,
next_sentence_labels,
masked_lm_scale,
):
with paddle.static.amp.fp16_guard():
masked_lm_loss = F.cross_entropy(
prediction_scores,
masked_lm_labels,
reduction="none",
ignore_index=-1,
)
masked_lm_loss = masked_lm_loss / masked_lm_scale
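# masked_lm_scale is the number of valid masked tokens (see _collate_data),
# so summing the scaled per-token losses yields a mean over them; padded
# label slots are -1 and are skipped via ignore_index=-1.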
next_sentence_loss = F.cross_entropy(
seq_relationship_score, next_sentence_labels, reduction="none"
)
return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss)
def layer_init_wrapper(func):
@functools.wraps(func)
def _impl(self, *args, **kwargs):
enable_recompute = kwargs.pop("enable_recompute", False)
func(self, *args, **kwargs)
if paddle.in_dynamic_mode():
self.enable_recompute = enable_recompute
else:
self.enable_recompute = False
return _impl
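# _convert_attention_mask turns a bool/int mask (1 = attend, 0 = mask) into an
# additive float mask: 0.0 for attended positions and -1e9 for masked ones, so
# it can be added directly to the attention logits.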
def _convert_attention_mask(attn_mask, dtype):
if attn_mask is not None and attn_mask.dtype != dtype:
attn_mask_dtype = convert_dtype(attn_mask.dtype)
if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype:
attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9
else:
attn_mask = paddle.cast(attn_mask, dtype)
return attn_mask
def _transformer_encoder_layer_fwd(
self, src, src_mask=None, cache=None, output_attentions=False
):
self.self_attn.need_weights = output_attentions
src_mask = _convert_attention_mask(src_mask, src.dtype)
residual = src
if self.normalize_before:
src = self.norm1(src)
attn_outputs = self.self_attn(src, src, src, src_mask, cache)
if isinstance(attn_outputs, tuple):
src = attn_outputs[0]
outputs = attn_outputs[1:]
else:
src = attn_outputs
outputs = None
src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
return (
src if outputs is None else ((src,) + outputs[::-1])
) # hidden_states, cache, attentions
def _transformer_decoder_layer_fwd(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
cache=None,
output_attentions=False,
):
residual = tgt
# self attention
self.self_attn.need_weights = output_attentions
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
if self.normalize_before:
tgt = self.norm1(tgt)
self_attn_outputs = self.self_attn(
tgt, tgt, tgt, tgt_mask, cache[0] if cache else None
)
# self_attn_outputs = (tgt, attn_weights, incremental_cache) or only tgt
if isinstance(self_attn_outputs, type(tgt)):
tgt = self_attn_outputs
else:
tgt = self_attn_outputs[0]
if output_attentions:
self_attn_weights = self_attn_outputs[1]
if cache:
incremental_cache = self_attn_outputs[-1]
tgt = residual + self.dropout1(tgt)
if not self.normalize_before:
tgt = self.norm1(tgt)
residual = tgt
# cross attention
if memory is not None:
self.cross_attn.need_weights = output_attentions
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
if self.normalize_before:
tgt = self.norm2(tgt)
cross_attn_outputs = self.cross_attn(
tgt, memory, memory, memory_mask, cache[1] if cache else None
)
if isinstance(cross_attn_outputs, type(tgt)):
tgt = cross_attn_outputs
else:
tgt = cross_attn_outputs[0]
if output_attentions:
cross_attn_weights = cross_attn_outputs[1]
if cache:
static_cache = cross_attn_outputs[-1]
tgt = residual + self.dropout2(tgt)
if not self.normalize_before:
tgt = self.norm2(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm3(tgt)
tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = residual + self.dropout3(tgt)
if not self.normalize_before:
tgt = self.norm3(tgt)
if not output_attentions and cache is None:
return tgt
else:
outputs = (tgt,)
if output_attentions:
outputs += (
self_attn_weights,
cross_attn_weights if memory is not None else None,
)
if cache:
outputs += (
(
incremental_cache,
static_cache if memory is not None else None,
),
)
return outputs
def _transformer_encoder_fwd(
self,
src,
src_mask=None,
cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
):
src_mask = _convert_attention_mask(src_mask, src.dtype)
output = src
# Build the cache from None when use_cache is True, to stay compatible with HF
# (which requires a decoder for this). The cache update in MultiHeadAttention
# used here is not very efficient and may be optimized later.
if cache is None and getattr(self, "_use_cache", False):
cache = [tuple(self.layers[0].gen_cache(src))] * len(self.layers)
# To be compatible with `TransformerEncoder.forward`, `_use_cache` defaults
# to True when cache is not None.
new_caches = (
[] if cache is not None and getattr(self, "_use_cache", True) else None
)
all_attentions = [] if output_attentions else None
# NOTE: Also includes the embedding output, which is the same as HF.
all_hidden_states = [output] if output_hidden_states else None
for i, mod in enumerate(self.layers):
if self.enable_recompute:
# Note: recompute does not support passing arguments as **kwargs yet.
layer_outputs = recompute(
mod,
output,
src_mask,
None
if cache is None
else cache[i]
if isinstance(cache[i], MultiHeadAttention.Cache)
else MultiHeadAttention.Cache(*cache[i]),
output_attentions,
)
else:
layer_outputs = mod(
output,
src_mask=src_mask,
cache=None
if cache is None
else cache[i]
if isinstance(cache[i], MultiHeadAttention.Cache)
else MultiHeadAttention.Cache(*cache[i]),
output_attentions=output_attentions,
)
if isinstance(layer_outputs, tuple):
output = layer_outputs[0]
outputs = layer_outputs[1:]
else:
output = layer_outputs
outputs = None
if output_hidden_states:
all_hidden_states.append(output)
if output_attentions:
all_attentions.append(outputs[-1])
if new_caches is not None:
new_caches.append(
outputs[0]
if isinstance(cache[i], MultiHeadAttention.Cache)
else (tuple(outputs[0]))
)
if self.norm is not None:
output = self.norm(output)
if output_hidden_states:
all_hidden_states[-1] = output
outputs = tuple(
tuple(v) if isinstance(v, list) else v
for v in [
output,
new_caches,
all_hidden_states,
all_attentions,
]
if v is not None
)
if len(outputs) == 1:
return output
else:
return outputs
def _transformer_decoder_fwd(
self,
tgt,
memory=None,
tgt_mask=None,
memory_mask=None,
cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
):
tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
if memory is not None:
memory_mask = _convert_attention_mask(memory_mask, memory.dtype)
new_caches = [] if cache else None
all_hidden_states = [tgt] if output_hidden_states else None
all_self_attns = [] if output_attentions else None
all_cross_attns = [] if output_attentions else None
for i, mod in enumerate(self.layers):
if cache is None:
if self.enable_recompute:
outputs = recompute(
mod,
tgt,
memory,
tgt_mask,
memory_mask,
None,
output_attentions,
)
else:
outputs = mod(
tgt,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
cache=None,
output_attentions=output_attentions,
)
else:
outputs = mod(
tgt,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
cache=cache[i] if cache else None,
output_attentions=output_attentions,
)
if isinstance(outputs, type(tgt)):
tgt = outputs
else:
tgt = outputs[0]
if cache:
new_caches.append(outputs[-1])
if output_attentions:
all_self_attns.append(outputs[1])
all_cross_attns.append(outputs[2])
if output_hidden_states:
all_hidden_states.append(tgt)
if self.norm is not None:
tgt = self.norm(tgt)
if output_hidden_states:
all_hidden_states[-1] = tgt
if isinstance(outputs, type(tgt)):
return tgt
temp_list = [
tgt,
new_caches if cache else None,
all_hidden_states,
all_self_attns,
all_cross_attns,
]
return tuple(v for v in temp_list if v is not None)
# patches of paddle.nn.Transformer to get all hidden_states and attentions
paddle.nn.TransformerEncoderLayer.forward = _transformer_encoder_layer_fwd
paddle.nn.TransformerDecoderLayer.forward = _transformer_decoder_layer_fwd
paddle.nn.TransformerEncoder.forward = _transformer_encoder_fwd
paddle.nn.TransformerDecoder.forward = _transformer_decoder_fwd
_encoder_init = paddle.nn.TransformerEncoder.__init__
_decoder_init = paddle.nn.TransformerDecoder.__init__
paddle.nn.TransformerEncoder.__init__ = layer_init_wrapper(_encoder_init)
paddle.nn.TransformerDecoder.__init__ = layer_init_wrapper(_decoder_init)
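# With these patches applied, TransformerEncoder/TransformerDecoder forward can
# also return caches, all hidden states and attention weights (mirroring HF),
# and both accept an `enable_recompute` keyword at construction time.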
class PretrainingDataset(Dataset):
def __init__(self, input_file, max_pred_length):
self.input_file = input_file
self.max_pred_length = max_pred_length
keys = [
"input_ids",
"input_mask",
"segment_ids",
"masked_lm_positions",
"masked_lm_ids",
"next_sentence_labels",
]
self.inputs = np.load(input_file)
self.inputs = [self.inputs[key] for key in keys]
def __len__(self):
"Denotes the total number of samples"
return len(self.inputs[0])
def __getitem__(self, index):
[
input_ids,
input_mask,
segment_ids,
masked_lm_positions,
masked_lm_ids,
next_sentence_labels,
] = [
input[index].astype(np.int64)
if indice < 5
else np.asarray(input[index].astype(np.int64))
for indice, input in enumerate(self.inputs)
]
# TODO: decide whether to reverse the mask (swap 1s and 0s) to be
# consistent with NV BERT.
input_mask = (
1
- np.reshape(
input_mask.astype(np.float32), [1, 1, input_mask.shape[0]]
)
) * -1e9
index = self.max_pred_length
# Store the number of masked tokens in index.
# The output of torch.nonzero differs from numpy.nonzero, which returns a
# per-axis tuple of index arrays.
padded_mask_indices = (masked_lm_positions == 0).nonzero()[0]
if len(padded_mask_indices) != 0:
index = padded_mask_indices[0].item()
else:
index = self.max_pred_length
# masked_lm_labels = np.full(input_ids.shape, -1, dtype=np.int64)
# masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index]
masked_lm_labels = masked_lm_ids[:index]
masked_lm_positions = masked_lm_positions[:index]
# softmax_with_cross_entropy requires the last dim size to be 1
masked_lm_labels = np.expand_dims(masked_lm_labels, axis=-1)
next_sentence_labels = np.expand_dims(next_sentence_labels, axis=-1)
return [
input_ids,
segment_ids,
input_mask,
masked_lm_positions,
masked_lm_labels,
next_sentence_labels,
]
def create_pretraining_dataset(
input_file, max_pred_length, shared_list, batch_size, worker_init
):
train_data = PretrainingDataset(
input_file=input_file, max_pred_length=max_pred_length
)
# files have been sharded, no need to dispatch again
train_batch_sampler = paddle.io.BatchSampler(
train_data, batch_size=batch_size, shuffle=True
)
# DataLoader cannot be pickled because it holds a device place.
# If it could be pickled, a global function (instead of a lambda) and
# ProcessPoolExecutor (instead of ThreadPoolExecutor) could be used to prefetch.
def _collate_data(data, stack_fn=Stack()):
num_fields = len(data[0])
out = [None] * num_fields
# input_ids, segment_ids, input_mask, masked_lm_positions,
# masked_lm_labels, next_sentence_labels, mask_token_num
for i in (0, 1, 2, 5):
out[i] = stack_fn([x[i] for x in data])
_, seq_length = out[0].shape
size = sum(len(x[3]) for x in data)
# Padding for divisibility by 8 for fp16 or int8 usage
if size % 8 != 0:
size += 8 - (size % 8)
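# e.g. 13 masked positions in the batch are padded up to 16.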
# masked_lm_positions
# Organize as a 1D tensor for gather or use gather_nd
out[3] = np.full(size, 0, dtype=np.int32)
# masked_lm_labels
out[4] = np.full([size, 1], -1, dtype=np.int64)
mask_token_num = 0
for i, x in enumerate(data):
for j, pos in enumerate(x[3]):
out[3][mask_token_num] = i * seq_length + pos
out[4][mask_token_num] = x[4][j]
mask_token_num += 1
# mask_token_num
out.append(np.asarray([mask_token_num], dtype=np.float32))
return out
train_data_loader = DataLoader(
dataset=train_data,
batch_sampler=train_batch_sampler,
collate_fn=_collate_data,
num_workers=0,
worker_init_fn=worker_init,
return_list=True,
)
return train_data_loader
if __name__ == '__main__':
bert = Bert()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
import time
import unittest
import numpy as np
from bert import Bert, BertPretrainingCriterion, create_pretraining_dataset
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.dataset.common import DATA_HOME, download
SEED = 2023
BATCH_SIZE = 2
URL = 'https://paddle-ci.gz.bcebos.com/prim_cinn/bert_training_data.npz'
MODULE_NAME = 'test_bert_prim_cinn'
MD5SUM = '71e730ee8d7aa77a215b7e898aa089af'
SAVE_NAME = 'bert_training_data.npz'
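# cuDNN kernels are made deterministic below so that losses are comparable
# across the dy2st / prim / CINN runs.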
if core.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': True})
def train(to_static, enable_prim, enable_cinn):
if core.is_compiled_with_cuda():
paddle.set_device('gpu')
else:
paddle.set_device('cpu')
fluid.core._set_prim_all_enabled(
enable_prim and platform.system() == 'Linux'
)
np.random.seed(SEED)
paddle.seed(SEED)
# paddle.framework.random._manual_program_seed(SEED)
train_data_loader = create_pretraining_dataset(
os.path.join(DATA_HOME, MODULE_NAME, SAVE_NAME),
20,
{},
batch_size=BATCH_SIZE,
worker_init=None,
)
bert = Bert()
criterion = BertPretrainingCriterion()
if to_static:
# input_spec = [
# InputSpec(shape=(-1, -1), dtype=paddle.int64, name='input_ids'),
# InputSpec(shape=(-1, -1), dtype=paddle.int64, name='segment_ids'),
# None,
# InputSpec(shape=(-1, 1, 1, -1), dtype=paddle.float32, name='input_mask'),
# InputSpec(shape=(-1,), dtype=paddle.int32, name='masked_lm_positions'),
# ]
input_spec = None
build_strategy = paddle.static.BuildStrategy()
if enable_cinn:
build_strategy.build_cinn_pass = True
bert = paddle.jit.to_static(
bert, input_spec, build_strategy=build_strategy
)
optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters())
losses = []
for step, batch in enumerate(train_data_loader):
start_time = time.time()
(
input_ids,
segment_ids,
input_mask,
masked_lm_positions,
masked_lm_labels,
next_sentence_labels,
masked_lm_scale,
) = batch
prediction_scores, seq_relationship_score = bert(
input_ids=input_ids,
token_type_ids=segment_ids,
attention_mask=input_mask,
masked_positions=masked_lm_positions,
)
loss = criterion(
prediction_scores,
seq_relationship_score,
masked_lm_labels,
next_sentence_labels,
masked_lm_scale,
)
loss.backward()
optimizer.minimize(loss)
bert.clear_gradients()
losses.append(loss)
print(
"step: {}, loss: {}, batch_cost: {:.5}".format(
step,
loss.numpy(),
time.time() - start_time,
)
)
if step >= 9:
break
return losses
class TestBert(unittest.TestCase):
@classmethod
def setUpClass(cls):
download(URL, MODULE_NAME, MD5SUM, SAVE_NAME)
cls.dy2st = train(to_static=True, enable_prim=False, enable_cinn=False)
def test_prim(self):
dy2st_prim = train(to_static=True, enable_prim=True, enable_cinn=False)
np.testing.assert_allclose(self.dy2st, dy2st_prim, rtol=1e-1)
@unittest.skipIf(
not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN"
)
def test_cinn(self):
dy2st_cinn = train(to_static=True, enable_prim=False, enable_cinn=True)
np.testing.assert_allclose(self.dy2st, dy2st_cinn, rtol=1e-6)
@unittest.skipIf(
not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN"
)
def test_prim_cinn(self):
dy2st_prim_cinn = train(
to_static=True, enable_prim=True, enable_cinn=True
)
np.testing.assert_allclose(self.dy2st, dy2st_prim_cinn, rtol=1e-1)
if __name__ == '__main__':
unittest.main()
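# Note: under CTest with WITH_CINN (see the CMakeLists.txt change above), this
# test runs with FLAGS_deny_cinn_ops=dropout in its environment; the same flag
# is presumably needed when running the CINN cases by hand.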