diff --git a/python/paddle/fluid/tests/unittests/prim/model/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/model/CMakeLists.txt index 55747ae5731bc46866457cd47fe3a93b22c213d8..82a0c92c516cc8ffe25a421601462ee02674131d 100644 --- a/python/paddle/fluid/tests/unittests/prim/model/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/prim/model/CMakeLists.txt @@ -9,7 +9,11 @@ foreach(TEST_OP ${TEST_OPS}) endforeach() set_tests_properties(test_resnet_prim_cinn PROPERTIES TIMEOUT 400) +set_tests_properties(test_bert_prim_cinn PROPERTIES TIMEOUT 500) if(WITH_CINN) set_tests_properties(test_resnet_prim_cinn PROPERTIES LABELS "RUN_TYPE=CINN") + set_tests_properties( + test_bert_prim_cinn PROPERTIES LABELS "RUN_TYPE=CINN" ENVIRONMENT + "FLAGS_deny_cinn_ops=dropout") endif() diff --git a/python/paddle/fluid/tests/unittests/prim/model/bert.py b/python/paddle/fluid/tests/unittests/prim/model/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0eadf8698650fb7a7d30f39b8295261424472d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/model/bert.py @@ -0,0 +1,908 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import warnings +from typing import Optional, Tuple + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed.fleet.utils import recompute +from paddle.fluid.data_feeder import convert_dtype +from paddle.io import DataLoader, Dataset +from paddle.nn import MultiHeadAttention + +try: + from paddle.incubate.nn import FusedTransformerEncoderLayer +except ImportError: + FusedTransformerEncoderLayer = None + + +VOCAB_SIZE = 30522 + + +class Stack(object): + def __init__(self, axis=0, dtype=None): + self._axis = axis + self._dtype = dtype + + def __call__(self, data): + data = ( + np.stack(data, axis=self._axis).astype(self._dtype) + if self._dtype + else np.stack(data, axis=self._axis) + ) + return data + + +def is_tensor(x): + if isinstance(x, paddle.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class BertConfig: + def __init__(self): + self.attention_probs_dropout_prob = 0.1 + self.fuse = False + self.hidden_act = 'gelu' + self.hidden_dropout_prob = 0.1 + # Decrease config to speed up unittest + # self.hidden_size = 768 + self.hidden_size = 60 + self.initializer_range = 0.02 + self.intermediate_size = 3072 + self.layer_norm_eps = 1e-12 + self.max_position_embeddings = 512 + self.model_type = 'bert' + # self.num_attention_heads = 12 + self.num_attention_heads = 6 + # self.num_hidden_layers = 12 + self.num_hidden_layers = 6 + self.pad_token_id = 0 + self.paddlenlp_version = None + self.pool_act = 'tanh' + self.type_vocab_size = 2 + self.vocab_size = VOCAB_SIZE + self.use_return_dict = False + self.output_hidden_states = False + self.output_attentions = False + self.use_cache = False + + +class BertLMPredictionHead(nn.Layer): + def __init__(self, config: BertConfig, embedding_weights=None): + super(BertLMPredictionHead, self).__init__() + + self.transform = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = getattr(nn.functional, config.hidden_act) + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.decoder_weight = ( + self.create_parameter( + shape=[config.vocab_size, config.hidden_size], + dtype=self.transform.weight.dtype, + is_bias=False, + ) + if embedding_weights is None + else embedding_weights + ) + + self.decoder_bias = self.create_parameter( + shape=[config.vocab_size], + dtype=self.decoder_weight.dtype, + is_bias=True, + ) + + def forward(self, hidden_states, masked_positions=None): + if masked_positions is not None: + hidden_states = paddle.reshape( + hidden_states, [-1, hidden_states.shape[-1]] + ) + hidden_states = paddle.tensor.gather( + hidden_states, masked_positions + ) + # gather masked tokens might be more quick + hidden_states = self.transform(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = ( + paddle.tensor.matmul( + hidden_states, self.decoder_weight, transpose_y=True + ) + + self.decoder_bias + ) + return hidden_states + + +class BertPretrainingHeads(nn.Layer): + def __init__(self, config: BertConfig, embedding_weights=None): + super(BertPretrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output, masked_positions=None): + prediction_scores = self.predictions(sequence_output, masked_positions) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertEmbeddings(nn.Layer): + def __init__(self, config: BertConfig): + super(BertEmbeddings, self).__init__() + + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + self.token_type_embeddings = nn.Embedding( + config.type_vocab_size, config.hidden_size + ) + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_values_length: Optional[int] = None, + ): + + if position_ids is None: + ones = paddle.ones_like(input_ids, dtype="int64") + seq_length = paddle.cumsum(ones, axis=-1) + + position_ids = seq_length - ones + if past_key_values_length is not None: + position_ids += past_key_values_length + position_ids.stop_gradient = True + if token_type_ids is None: + token_type_ids = paddle.zeros_like(input_ids, dtype="int64") + + input_embedings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = ( + input_embedings + position_embeddings + token_type_embeddings + ) + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertPooler(nn.Layer): + def __init__(self, config: BertConfig): + super(BertPooler, self).__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + self.pool_act = config.pool_act + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + if self.pool_act == "tanh": + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Layer): + def __init__(self, config: BertConfig): + super(BertModel, self).__init__() + self.config = config + self.pad_token_id = config.pad_token_id + self.initializer_range = config.initializer_range + self.embeddings = BertEmbeddings(config) + if config.fuse and FusedTransformerEncoderLayer is None: + warnings.warn( + "FusedTransformerEncoderLayer is not supported by the running Paddle. " + "The flag fuse_transformer will be ignored. Try Paddle >= 2.3.0" + ) + self.fuse = config.fuse and FusedTransformerEncoderLayer is not None + if self.fuse: + self.encoder = nn.LayerList( + [ + FusedTransformerEncoderLayer( + config.hidden_size, + config.num_attention_heads, + config.intermediate_size, + dropout_rate=config.hidden_dropout_prob, + activation=config.hidden_act, + attn_dropout_rate=config.attention_probs_dropout_prob, + act_dropout_rate=0.0, + ) + for _ in range(config.num_hidden_layers) + ] + ) + else: + encoder_layer = nn.TransformerEncoderLayer( + config.hidden_size, + config.num_attention_heads, + config.intermediate_size, + dropout=config.hidden_dropout_prob, + activation=config.hidden_act, + attn_dropout=config.attention_probs_dropout_prob, + act_dropout=0, + ) + self.encoder = nn.TransformerEncoder( + encoder_layer, config.num_hidden_layers + ) + self.pooler = BertPooler(config) + # self.apply(self.init_weights) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + past_key_values: Optional[Tuple[Tuple[Tensor]]] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache + ) + + past_key_values_length = None + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + if attention_mask is None: + attention_mask = paddle.unsqueeze( + (input_ids == self.pad_token_id).astype( + self.pooler.dense.weight.dtype + ) + * -1e4, + axis=[1, 2], + ) + if past_key_values is not None: + batch_size = past_key_values[0][0].shape[0] + past_mask = paddle.zeros( + [batch_size, 1, 1, past_key_values_length], + dtype=attention_mask.dtype, + ) + attention_mask = paddle.concat( + [past_mask, attention_mask], axis=-1 + ) + + else: + if attention_mask.ndim == 2: + # attention_mask [batch_size, sequence_length] -> [batch_size, 1, 1, sequence_length] + attention_mask = attention_mask.unsqueeze(axis=[1, 2]).astype( + paddle.get_default_dtype() + ) + attention_mask = (1.0 - attention_mask) * -1e4 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + past_key_values_length=past_key_values_length, + ) + if self.fuse: + assert ( + not output_attentions + ), "Not support attentions output currently." + assert ( + past_key_values is None + ), "Not support past_key_values currently." + hidden_states = embedding_output + all_hidden_states = [] if output_hidden_states else None + for layer in self.encoder: + hidden_states = layer(hidden_states, attention_mask) + if output_hidden_states: + all_hidden_states.append(hidden_states) + pooled_output = self.pooler(hidden_states) + + return ( + (hidden_states, pooled_output, all_hidden_states) + if output_hidden_states + else (hidden_states, pooled_output) + ) + else: + self.encoder._use_cache = use_cache # To be consistent with HF + encoder_outputs = self.encoder( + embedding_output, + src_mask=attention_mask, + cache=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if isinstance(encoder_outputs, type(embedding_output)): + sequence_output = encoder_outputs + pooled_output = self.pooler(sequence_output) + return (sequence_output, pooled_output) + else: + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) + return (sequence_output, pooled_output) + encoder_outputs[1:] + + +class Bert(nn.Layer): + def __init__(self): + super(Bert, self).__init__() + config = BertConfig() + self.bert = BertModel(config) + self.cls = BertPretrainingHeads( + config, + embedding_weights=self.bert.embeddings.word_embeddings.weight, + ) + + # self.apply(self.init_weights) + + def forward( + self, + input_ids: Tensor, + token_type_ids: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + masked_positions: Optional[Tensor] = None, + labels: Optional[Tensor] = None, + next_sentence_label: Optional[Tensor] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + with paddle.static.amp.fp16_guard(): + outputs = self.bert( + input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls( + sequence_output, pooled_output, masked_positions + ) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.reshape( + (-1, prediction_scores.shape[-1]) + ), + labels.reshape((-1,)), + ) + next_sentence_loss = loss_fct( + seq_relationship_score.reshape((-1, 2)), + next_sentence_label.reshape((-1,)), + ) + total_loss = masked_lm_loss + next_sentence_loss + + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ( + ((total_loss,) + output) if total_loss is not None else output + ) + + +class BertPretrainingCriterion(paddle.nn.Layer): + def __init__(self, vocab_size=VOCAB_SIZE): + super(BertPretrainingCriterion, self).__init__() + # CrossEntropyLoss is expensive since the inner reshape (copy) + self.loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=-1) + self.vocab_size = vocab_size + + def forward( + self, + prediction_scores, + seq_relationship_score, + masked_lm_labels, + next_sentence_labels, + masked_lm_scale, + ): + with paddle.static.amp.fp16_guard(): + masked_lm_loss = F.cross_entropy( + prediction_scores, + masked_lm_labels, + reduction="none", + ignore_index=-1, + ) + masked_lm_loss = masked_lm_loss / masked_lm_scale + next_sentence_loss = F.cross_entropy( + seq_relationship_score, next_sentence_labels, reduction="none" + ) + return paddle.sum(masked_lm_loss) + paddle.mean(next_sentence_loss) + + +def layer_init_wrapper(func): + @functools.wraps(func) + def _impl(self, *args, **kwargs): + enable_recompute = kwargs.pop("enable_recompute", False) + func(self, *args, **kwargs) + if paddle.in_dynamic_mode(): + self.enable_recompute = enable_recompute + else: + self.enable_recompute = False + + return _impl + + +def _convert_attention_mask(attn_mask, dtype): + if attn_mask is not None and attn_mask.dtype != dtype: + attn_mask_dtype = convert_dtype(attn_mask.dtype) + if attn_mask_dtype == 'bool' or 'int' in attn_mask_dtype: + attn_mask = (paddle.cast(attn_mask, dtype) - 1.0) * 1e9 + else: + attn_mask = paddle.cast(attn_mask, dtype) + return attn_mask + + +def _transformer_encoder_layer_fwd( + self, src, src_mask=None, cache=None, output_attentions=False +): + self.self_attn.need_weights = output_attentions + src_mask = _convert_attention_mask(src_mask, src.dtype) + + residual = src + if self.normalize_before: + src = self.norm1(src) + + attn_outputs = self.self_attn(src, src, src, src_mask, cache) + if isinstance(attn_outputs, tuple): + src = attn_outputs[0] + outputs = attn_outputs[1:] + else: + src = attn_outputs + outputs = None + + src = residual + self.dropout1(src) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src) + if not self.normalize_before: + src = self.norm2(src) + + return ( + src if outputs is None else ((src,) + outputs[::-1]) + ) # hidden_states, cache, attentions + + +def _transformer_decoder_layer_fwd( + self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + cache=None, + output_attentions=False, +): + residual = tgt + + # self attention + self.self_attn.need_weights = output_attentions + tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype) + + if self.normalize_before: + tgt = self.norm1(tgt) + + self_attn_outputs = self.self_attn( + tgt, tgt, tgt, tgt_mask, cache[0] if cache else None + ) + # self_attn_outputs = (tgt, attn_weights, incremental_cache) or only tgt + if isinstance(self_attn_outputs, type(tgt)): + tgt = self_attn_outputs + else: + tgt = self_attn_outputs[0] + if output_attentions: + self_attn_weights = self_attn_outputs[1] + if cache: + incremental_cache = self_attn_outputs[-1] + + tgt = residual + self.dropout1(tgt) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + + # cross attention + if memory is not None: + self.cross_attn.need_weights = output_attentions + memory_mask = _convert_attention_mask(memory_mask, memory.dtype) + + if self.normalize_before: + tgt = self.norm2(tgt) + + cross_attn_outputs = self.cross_attn( + tgt, memory, memory, memory_mask, cache[1] if cache else None + ) + if isinstance(cross_attn_outputs, type(tgt)): + tgt = cross_attn_outputs + else: + tgt = cross_attn_outputs[0] + if output_attentions: + cross_attn_weights = cross_attn_outputs[1] + if cache: + static_cache = cross_attn_outputs[-1] + + tgt = residual + self.dropout2(tgt) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + + if self.normalize_before: + tgt = self.norm3(tgt) + tgt = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt) + if not self.normalize_before: + tgt = self.norm3(tgt) + + if not output_attentions and cache is None: + return tgt + else: + outputs = (tgt,) + if output_attentions: + outputs += ( + self_attn_weights, + cross_attn_weights if memory is not None else None, + ) + if cache: + outputs += ( + ( + incremental_cache, + static_cache if memory is not None else None, + ), + ) + return outputs + + +def _transformer_encoder_fwd( + self, + src, + src_mask=None, + cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, +): + src_mask = _convert_attention_mask(src_mask, src.dtype) + + output = src + # To get cache from None when use_cache is True, which is compatible with HF + # while HF requires decoder. The implementation here uses cache update in the + # MultiHeadAttention not so efficiently, and maybe optimize it later. + if cache is None and getattr(self, "_use_cache", False): + cache = [tuple(self.layers[0].gen_cache(src))] * len(self.layers) + # To be compatible with `TransformerEncoder.forward`, `_use_cache` defualts + # to True when cache is not None. + new_caches = ( + [] if cache is not None and getattr(self, "_use_cache", True) else None + ) + all_attentions = [] if output_attentions else None + # NOTE: Also includes embeding output which is same as HF. + all_hidden_states = [output] if output_hidden_states else None + for i, mod in enumerate(self.layers): + if self.enable_recompute: + # Note: recompute do not support pass as **kwargs yet. + layer_outputs = recompute( + mod, + output, + src_mask, + None + if cache is None + else cache[i] + if isinstance(cache[i], MultiHeadAttention.Cache) + else MultiHeadAttention.Cache(*cache[i]), + output_attentions, + ) + else: + layer_outputs = mod( + output, + src_mask=src_mask, + cache=None + if cache is None + else cache[i] + if isinstance(cache[i], MultiHeadAttention.Cache) + else MultiHeadAttention.Cache(*cache[i]), + output_attentions=output_attentions, + ) + + if isinstance(layer_outputs, tuple): + output = layer_outputs[0] + outputs = layer_outputs[1:] + else: + output = layer_outputs + outputs = None + + if output_hidden_states: + all_hidden_states.append(output) + if output_attentions: + all_attentions.append(outputs[-1]) + if new_caches is not None: + new_caches.append( + outputs[0] + if isinstance(cache[i], MultiHeadAttention.Cache) + else (tuple(outputs[0])) + ) + + if self.norm is not None: + output = self.norm(output) + + if output_hidden_states: + all_hidden_states[-1] = output + + outputs = tuple( + tuple(v) if isinstance(v, list) else v + for v in [ + output, + new_caches, + all_hidden_states, + all_attentions, + ] + if v is not None + ) + if len(outputs) == 1: + return output + else: + return outputs + + +def _transformer_decoder_fwd( + self, + tgt, + memory=None, + tgt_mask=None, + memory_mask=None, + cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, +): + tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype) + if memory is not None: + memory_mask = _convert_attention_mask(memory_mask, memory.dtype) + + new_caches = [] if cache else None + all_hidden_states = [tgt] if output_hidden_states else None + all_self_attns = [] if output_attentions else None + all_cross_attns = [] if output_attentions else None + + for i, mod in enumerate(self.layers): + if cache is None: + if self.enable_recompute: + outputs = recompute( + mod, + tgt, + memory, + tgt_mask, + memory_mask, + None, + output_attentions, + ) + else: + outputs = mod( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + cache=None, + output_attentions=output_attentions, + ) + else: + outputs = mod( + tgt, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + cache=cache[i] if cache else None, + output_attentions=output_attentions, + ) + if isinstance(outputs, type(tgt)): + tgt = outputs + else: + tgt = outputs[0] + if cache: + new_caches.append(outputs[-1]) + if output_attentions: + all_self_attns.append(outputs[1]) + all_cross_attns.append(outputs[2]) + if output_hidden_states: + all_hidden_states.append(tgt) + + if self.norm is not None: + tgt = self.norm(tgt) + if output_hidden_states: + all_hidden_states[-1] = tgt + + if isinstance(outputs, type(tgt)): + return tgt + + temp_list = [ + tgt, + new_caches if cache else None, + all_hidden_states, + all_self_attns, + all_cross_attns, + ] + return tuple(v for v in temp_list if v is not None) + + +# patches of paddle.nn.Transformer to get all hidden_states and attentions +paddle.nn.TransformerEncoderLayer.forward = _transformer_encoder_layer_fwd +paddle.nn.TransformerDecoderLayer.forward = _transformer_decoder_layer_fwd +paddle.nn.TransformerEncoder.forward = _transformer_encoder_fwd +paddle.nn.TransformerDecoder.forward = _transformer_decoder_fwd + +_encoder_init = paddle.nn.TransformerEncoder.__init__ +_decoder_init = paddle.nn.TransformerDecoder.__init__ +paddle.nn.TransformerEncoder.__init__ = layer_init_wrapper(_encoder_init) +paddle.nn.TransformerDecoder.__init__ = layer_init_wrapper(_decoder_init) + + +class PretrainingDataset(Dataset): + def __init__(self, input_file, max_pred_length): + self.input_file = input_file + self.max_pred_length = max_pred_length + keys = [ + "input_ids", + "input_mask", + "segment_ids", + "masked_lm_positions", + "masked_lm_ids", + "next_sentence_labels", + ] + self.inputs = np.load(input_file) + self.inputs = [self.inputs[key] for key in keys] + + def __len__(self): + "Denotes the total number of samples" + return len(self.inputs[0]) + + def __getitem__(self, index): + + [ + input_ids, + input_mask, + segment_ids, + masked_lm_positions, + masked_lm_ids, + next_sentence_labels, + ] = [ + input[index].astype(np.int64) + if indice < 5 + else np.asarray(input[index].astype(np.int64)) + for indice, input in enumerate(self.inputs) + ] + # TODO: whether to use reversed mask by changing 1s and 0s to be + # consistent with nv bert + input_mask = ( + 1 + - np.reshape( + input_mask.astype(np.float32), [1, 1, input_mask.shape[0]] + ) + ) * -1e9 + + index = self.max_pred_length + # store number of masked tokens in index + # outputs of torch.nonzero diff with that of numpy.nonzero by zip + padded_mask_indices = (masked_lm_positions == 0).nonzero()[0] + if len(padded_mask_indices) != 0: + index = padded_mask_indices[0].item() + else: + index = self.max_pred_length + # masked_lm_labels = np.full(input_ids.shape, -1, dtype=np.int64) + # masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] + masked_lm_labels = masked_lm_ids[:index] + masked_lm_positions = masked_lm_positions[:index] + # softmax_with_cross_entropy enforce last dim size equal 1 + masked_lm_labels = np.expand_dims(masked_lm_labels, axis=-1) + next_sentence_labels = np.expand_dims(next_sentence_labels, axis=-1) + return [ + input_ids, + segment_ids, + input_mask, + masked_lm_positions, + masked_lm_labels, + next_sentence_labels, + ] + + +def create_pretraining_dataset( + input_file, max_pred_length, shared_list, batch_size, worker_init +): + train_data = PretrainingDataset( + input_file=input_file, max_pred_length=max_pred_length + ) + # files have been sharded, no need to dispatch again + train_batch_sampler = paddle.io.BatchSampler( + train_data, batch_size=batch_size, shuffle=True + ) + + # DataLoader cannot be pickled because of its place. + # If it can be pickled, use global function instead of lambda and use + # ProcessPoolExecutor instead of ThreadPoolExecutor to prefetch. + def _collate_data(data, stack_fn=Stack()): + num_fields = len(data[0]) + out = [None] * num_fields + # input_ids, segment_ids, input_mask, masked_lm_positions, + # masked_lm_labels, next_sentence_labels, mask_token_num + for i in (0, 1, 2, 5): + out[i] = stack_fn([x[i] for x in data]) + _, seq_length = out[0].shape + size = sum(len(x[3]) for x in data) + # Padding for divisibility by 8 for fp16 or int8 usage + if size % 8 != 0: + size += 8 - (size % 8) + # masked_lm_positions + # Organize as a 1D tensor for gather or use gather_nd + out[3] = np.full(size, 0, dtype=np.int32) + # masked_lm_labels + out[4] = np.full([size, 1], -1, dtype=np.int64) + mask_token_num = 0 + for i, x in enumerate(data): + for j, pos in enumerate(x[3]): + out[3][mask_token_num] = i * seq_length + pos + out[4][mask_token_num] = x[4][j] + mask_token_num += 1 + # mask_token_num + out.append(np.asarray([mask_token_num], dtype=np.float32)) + return out + + train_data_loader = DataLoader( + dataset=train_data, + batch_sampler=train_batch_sampler, + collate_fn=_collate_data, + num_workers=0, + worker_init_fn=worker_init, + return_list=True, + ) + return train_data_loader + + +if __name__ == '__main__': + bert = Bert() diff --git a/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py b/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py new file mode 100644 index 0000000000000000000000000000000000000000..99a2404d76372a1c60741f046a3e5e23919f2f0e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py @@ -0,0 +1,154 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import platform +import time +import unittest + +import numpy as np +from bert import Bert, BertPretrainingCriterion, create_pretraining_dataset + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.dataset.common import DATA_HOME, download + +SEED = 2023 +BATCH_SIZE = 2 + +URL = 'https://paddle-ci.gz.bcebos.com/prim_cinn/bert_training_data.npz' +MODULE_NAME = 'test_bert_prim_cinn' +MD5SUM = '71e730ee8d7aa77a215b7e898aa089af' +SAVE_NAME = 'bert_training_data.npz' + +if core.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': True}) + + +def train(to_static, enable_prim, enable_cinn): + if core.is_compiled_with_cuda(): + paddle.set_device('gpu') + else: + paddle.set_device('cpu') + fluid.core._set_prim_all_enabled( + enable_prim and platform.system() == 'Linux' + ) + + np.random.seed(SEED) + paddle.seed(SEED) + # paddle.framework.random._manual_program_seed(SEED) + + train_data_loader = create_pretraining_dataset( + os.path.join(DATA_HOME, MODULE_NAME, SAVE_NAME), + 20, + {}, + batch_size=BATCH_SIZE, + worker_init=None, + ) + + bert = Bert() + criterion = BertPretrainingCriterion() + if to_static: + # input_sepc = [ + # InputSpec(shape=(-1, -1), dtype=paddle.int64, name='input_ids'), + # InputSpec(shape=(-1, -1), dtype=paddle.int64, name='segment_ids'), + # None, + # InputSpec(shape=(-1, 1, 1, -1), dtype=paddle.float32, name='input_mask'), + # InputSpec(shape=(-1,), dtype=paddle.int32, name='masked_lm_positions'), + # ] + input_sepc = None + build_strategy = paddle.static.BuildStrategy() + if enable_cinn: + build_strategy.build_cinn_pass = True + bert = paddle.jit.to_static( + bert, input_sepc, build_strategy=build_strategy + ) + + optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters()) + + losses = [] + for step, batch in enumerate(train_data_loader): + start_time = time.time() + ( + input_ids, + segment_ids, + input_mask, + masked_lm_positions, + masked_lm_labels, + next_sentence_labels, + masked_lm_scale, + ) = batch + + prediction_scores, seq_relationship_score = bert( + input_ids=input_ids, + token_type_ids=segment_ids, + attention_mask=input_mask, + masked_positions=masked_lm_positions, + ) + + loss = criterion( + prediction_scores, + seq_relationship_score, + masked_lm_labels, + next_sentence_labels, + masked_lm_scale, + ) + + loss.backward() + optimizer.minimize(loss) + bert.clear_gradients() + losses.append(loss) + + print( + "step: {}, loss: {}, batch_cost: {:.5}".format( + step, + loss.numpy(), + time.time() - start_time, + ) + ) + if step >= 9: + break + return losses + + +class TestBert(unittest.TestCase): + @classmethod + def setUpClass(cls): + download(URL, MODULE_NAME, MD5SUM, SAVE_NAME) + cls.dy2st = train(to_static=True, enable_prim=False, enable_cinn=False) + + def test_prim(self): + dy2st_prim = train(to_static=True, enable_prim=True, enable_cinn=False) + np.testing.assert_allclose(self.dy2st, dy2st_prim, rtol=1e-1) + + @unittest.skipIf( + not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN" + ) + def test_cinn(self): + dy2st_cinn = train(to_static=True, enable_prim=False, enable_cinn=True) + np.testing.assert_allclose(self.dy2st, dy2st_cinn, rtol=1e-6) + + @unittest.skipIf( + not paddle.is_compiled_with_cinn(), "padle is not compiled with CINN" + ) + def test_prim_cinn(self): + dy2st_prim_cinn = train( + to_static=True, enable_prim=True, enable_cinn=True + ) + np.testing.assert_allclose(self.dy2st, dy2st_prim_cinn, rtol=1e-1) + + +if __name__ == '__main__': + unittest.main()