# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import io import copy import logging import six import json import paddle from paddle import nn from paddle.nn import functional as F from paddlenlp.utils.env import MODEL_HOME from paddle.utils.download import get_path_from_url from paddlenlp.utils.log import logger from paddlenlp.transformers import BertPretrainedModel, ElectraPretrainedModel, RobertaPretrainedModel, ErniePretrainedModel from ..utils import InitTrackerMeta, fn_args_to_dict __all__ = ["ErnieGenPretrainedModel", "ErnieForGeneration"] def _build_linear(n_in, n_out, name, init): return nn.Linear( n_in, n_out, weight_attr=paddle.ParamAttr( name='%s.w_0' % name if name is not None else None, initializer=init), bias_attr='%s.b_0' % name if name is not None else None, ) def _build_ln(n_in, name): return nn.LayerNorm( normalized_shape=n_in, weight_attr=paddle.ParamAttr( name='%s_layer_norm_scale' % name if name is not None else None, initializer=nn.initializer.Constant(1.)), bias_attr=paddle.ParamAttr( name='%s_layer_norm_bias' % name if name is not None else None, initializer=nn.initializer.Constant(1.)), ) def append_name(name, postfix): if name is None: ret = None elif name == '': ret = postfix else: ret = '%s_%s' % (name, postfix) return ret class AttentionLayer(nn.Layer): def __init__(self, cfg, name=None): super(AttentionLayer, self).__init__() initializer = nn.initializer.TruncatedNormal( std=cfg['initializer_range']) d_model = cfg['hidden_size'] n_head = cfg['num_attention_heads'] assert d_model % n_head == 0 d_model_q = cfg.get('query_hidden_size_per_head', d_model // n_head) * n_head d_model_v = cfg.get('value_hidden_size_per_head', d_model // n_head) * n_head self.n_head = n_head self.d_key = d_model_q // n_head self.q = _build_linear(d_model, d_model_q, append_name(name, 'query_fc'), initializer) self.k = _build_linear(d_model, d_model_q, append_name(name, 'key_fc'), initializer) self.v = _build_linear(d_model, d_model_v, append_name(name, 'value_fc'), initializer) self.o = _build_linear(d_model_v, d_model, append_name(name, 'output_fc'), initializer) self.dropout = nn.Dropout(p=cfg['attention_probs_dropout_prob']) def forward(self, queries, keys, values, attn_bias, past_cache): assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3 #bsz, q_len, q_dim = queries.shape #bsz, k_len, k_dim = keys.shape #bsz, v_len, v_dim = values.shape #assert k_len == v_len q = self.q(queries) k = self.k(keys) v = self.v(values) cache = (k, v) if past_cache is not None: cached_k, cached_v = past_cache k = paddle.concat([cached_k, k], 1) v = paddle.concat([cached_v, v], 1) q = q.reshape( [0, 0, self.n_head, q.shape[-1] // self.n_head]).transpose( [0, 2, 1, 3]) #[batch, head, seq, dim] k = k.reshape( [0, 0, self.n_head, k.shape[-1] // self.n_head]).transpose( [0, 2, 1, 3]) #[batch, head, seq, dim] v = v.reshape( [0, 0, self.n_head, v.shape[-1] // self.n_head]).transpose( [0, 2, 1, 3]) #[batch, head, seq, dim] q = q.scale(self.d_key**-0.5) score = q.matmul(k, transpose_y=True) if attn_bias is not None: score += attn_bias score = F.softmax(score) score = self.dropout(score) out = score.matmul(v).transpose([0, 2, 1, 3]) out = out.reshape([0, 0, out.shape[2] * out.shape[3]]) out = self.o(out) return out, cache class PositionwiseFeedForwardLayer(nn.Layer): def __init__(self, cfg, name=None): super(PositionwiseFeedForwardLayer, self).__init__() initializer = nn.initializer.TruncatedNormal( std=cfg['initializer_range']) d_model = cfg['hidden_size'] d_ffn = cfg.get('intermediate_size', 4 * d_model) self.act = getattr(paddle.nn.functional, cfg['hidden_act']) self.i = _build_linear( d_model, d_ffn, append_name(name, 'fc_0'), initializer, ) self.o = _build_linear(d_ffn, d_model, append_name(name, 'fc_1'), initializer) prob = cfg.get('intermediate_dropout_prob', 0.) self.dropout = nn.Dropout(p=prob) def forward(self, inputs): hidden = self.act(self.i(inputs)) hidden = self.dropout(hidden) out = self.o(hidden) return out class ErnieEncoderLayer(nn.Layer): def __init__(self, cfg, name=None): super(ErnieEncoderLayer, self).__init__() d_model = cfg['hidden_size'] self.attn = AttentionLayer( cfg, name=append_name(name, 'multi_head_att')) self.ln1 = _build_ln(d_model, name=append_name(name, 'post_att')) self.ffn = PositionwiseFeedForwardLayer( cfg, name=append_name(name, 'ffn')) self.ln2 = _build_ln(d_model, name=append_name(name, 'post_ffn')) prob = cfg.get('intermediate_dropout_prob', cfg['hidden_dropout_prob']) self.dropout = nn.Dropout(p=prob) def forward(self, inputs, attn_bias=None, past_cache=None): attn_out, cache = self.attn( inputs, inputs, inputs, attn_bias, past_cache=past_cache) #self attn attn_out = self.dropout(attn_out) hidden = attn_out + inputs hidden = self.ln1(hidden) # dropout/ add/ norm ffn_out = self.ffn(hidden) ffn_out = self.dropout(ffn_out) hidden = ffn_out + hidden hidden = self.ln2(hidden) return hidden, cache class ErnieEncoderStack(nn.Layer): def __init__(self, cfg, name=None): super(ErnieEncoderStack, self).__init__() n_layers = cfg['num_hidden_layers'] self.block = nn.LayerList([ ErnieEncoderLayer(cfg, append_name(name, 'layer_%d' % i)) for i in range(n_layers) ]) def forward(self, inputs, attn_bias=None, past_cache=None): if past_cache is not None: assert isinstance( past_cache, tuple ), 'unknown type of `past_cache`, expect tuple or list. got %s' % repr( type(past_cache)) past_cache = list(zip(*past_cache)) else: past_cache = [None] * len(self.block) cache_list_k, cache_list_v, hidden_list = [], [], [inputs] for b, p in zip(self.block, past_cache): inputs, cache = b(inputs, attn_bias=attn_bias, past_cache=p) cache_k, cache_v = cache cache_list_k.append(cache_k) cache_list_v.append(cache_v) hidden_list.append(inputs) return inputs, hidden_list, (cache_list_k, cache_list_v) @six.add_metaclass(InitTrackerMeta) class ErnieGenPretrainedModel(object): model_config_file = "model_config.json" ernie_gen_pretrained_init_configuration = { "ernie-gen-base-en": { "attention_probs_dropout_prob": 0.1, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "max_position_embeddings": 1024, "num_attention_heads": 12, "num_hidden_layers": 12, "type_vocab_size": 4, "vocab_size": 30522, "pad_token_id": 0, }, "ernie-gen-large-en": { "attention_probs_dropout_prob": 0.1, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 1024, "initializer_range": 0.02, "intermediate_size": 4096, "max_position_embeddings": 1024, "num_attention_heads": 16, "num_hidden_layers": 24, "type_vocab_size": 4, "vocab_size": 30522, "pad_token_id": 0, }, "ernie-gen-large-en-430g": { "attention_probs_dropout_prob": 0.1, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 1024, "initializer_range": 0.02, "intermediate_size": 4096, "max_position_embeddings": 1024, "num_attention_heads": 16, "num_hidden_layers": 24, "type_vocab_size": 4, "vocab_size": 30522, "pad_token_id": 0, }, } resource_files_names = {"model_state": "model_state.pdparams"} ernie_gen_pretrained_resource_files_map = { "model_state": { "ernie-gen-base-en": "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-base/ernie_gen_base.pdparams", "ernie-gen-large-en": "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-large/ernie_gen_large.pdparams", "ernie-gen-large-430g-en": "https://paddlenlp.bj.bcebos.com/models/transformers/ernie-gen-large-430g/ernie_gen_large_430g.pdparams", } } # Support more model to warm start. pretrained_init_configuration = { ** ernie_gen_pretrained_init_configuration, ** BertPretrainedModel.pretrained_init_configuration, ** ElectraPretrainedModel.pretrained_init_configuration, ** RobertaPretrainedModel.pretrained_init_configuration, ** ErniePretrainedModel.pretrained_init_configuration } pretrained_resource_files_map = { "model_state": { ** ernie_gen_pretrained_resource_files_map["model_state"], ** BertPretrainedModel.pretrained_resource_files_map["model_state"], ** ElectraPretrainedModel.pretrained_resource_files_map["model_state"], ** RobertaPretrainedModel.pretrained_resource_files_map["model_state"], ** ErniePretrainedModel.pretrained_resource_files_map["model_state"] } } @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): pretrained_models = list(cls.pretrained_init_configuration.keys()) resource_files = {} init_configuration = {} if pretrained_model_name_or_path in pretrained_models: for file_id, map_list in cls.pretrained_resource_files_map.items(): resource_files[file_id] = map_list[ pretrained_model_name_or_path] init_configuration = copy.deepcopy( cls.pretrained_init_configuration[ pretrained_model_name_or_path]) else: if os.path.isdir(pretrained_model_name_or_path): for file_id, file_name in cls.resource_files_names.items(): full_file_name = os.path.join(pretrained_model_name_or_path, file_name) resource_files[file_id] = full_file_name resource_files["model_config_file"] = os.path.join( pretrained_model_name_or_path, cls.model_config_file) else: raise ValueError( "Calling {}.from_pretrained() with a model identifier or the " "path to a directory instead. The supported model " "identifiers are as follows: {}".format( cls.__name__, cls.pretrained_init_configuration.keys())) default_root = os.path.join(MODEL_HOME, pretrained_model_name_or_path) resolved_resource_files = {} for file_id, file_path in resource_files.items(): path = os.path.join(default_root, file_path.split('/')[-1]) if file_path is None or os.path.isfile(file_path): resolved_resource_files[file_id] = file_path elif os.path.exists(path): logger.info("Already cached %s" % path) resolved_resource_files[file_id] = path else: logger.info("Downloading %s and saved to %s" % (file_path, default_root)) resolved_resource_files[file_id] = get_path_from_url( file_path, default_root) # Prepare model initialization kwargs # Did we saved some inputs and kwargs to reload ? model_config_file = resolved_resource_files.pop("model_config_file", None) if model_config_file is not None: with io.open(model_config_file, encoding="utf-8") as f: init_kwargs = json.load(f) else: init_kwargs = init_configuration # import pdb; pdb.set_trace() if not os.path.exists(resolved_resource_files[file_id]): raise ValueError('pretrain dir not found: %s' % resolved_resource_files[file_id]) name_prefix = kwargs.pop('name', None) model = cls(init_kwargs, name=name_prefix) weight_path = list(resolved_resource_files.values())[0] logger.info('loading pretrained model from %s' % weight_path) if os.path.exists(weight_path): m = paddle.load(weight_path) params_name = list(m.keys()) if 'mlm.weight' not in params_name: # ernie_gen is not implemented with paddle.transformer. # So, when loading the params saved by paddle.transformer, we should convert the params name. # We will update ernie_gen with paddle.transformer in the future. name_index_begin = params_name[0].index('.') + 1 for old_name in params_name: new_name = old_name[name_index_begin:].replace("embeddings.word_embeddings","word_emb").replace("embeddings.position_embeddings","pos_emb")\ .replace("embeddings.token_type_embeddings","sent_emb").replace("embeddings.layer_norm","ln").replace("encoder.layers","encoder_stack.block")\ .replace("self_attn","attn").replace("k_proj","k").replace("q_proj","q").replace("v_proj","v").replace("out_proj","o")\ .replace("linear1","ffn.i").replace("linear2","ffn.o").replace("norm1","ln1").replace("norm2","ln2").replace("pooler.dense","pooler") m[new_name] = m.pop(old_name) for k, v in model.state_dict().items(): if k not in m: logger.info('param:%s not set in pretrained model, skip' % k) m[k] = v # FIXME: no need to do this in the future model.set_state_dict(m) else: raise ValueError('weight file not found in pretrain dir: %s' % weight_path) return model def save_pretrained(self, save_directory): """ Save model configuration and related resources (model state) to files under `save_directory`. Args: save_directory (str): Directory to save files into. """ assert os.path.isdir( save_directory ), "Saving directory ({}) should be a directory".format(save_directory) # save model config model_config_file = os.path.join(save_directory, self.model_config_file) model_config = self.init_config # If init_config contains a Layer, use the layer's init_config to save for key, value in model_config.items(): if key == "init_args": args = [] for arg in value: args.append( arg.init_config if isinstance(arg, ErnieGenPretrainedModel) else arg) model_config[key] = tuple(args) elif isinstance(value, ErnieGenPretrainedModel): model_config[key] = value.init_config with io.open(model_config_file, "w", encoding="utf-8") as f: f.write(json.dumps(model_config, ensure_ascii=False)) # save model file_name = os.path.join(save_directory, list(self.resource_files_names.values())[0]) paddle.save(self.state_dict(), file_name) def _wrap_init(self, original_init, *args, **kwargs): """ It would be hooked after `__init__` to add a dict including arguments of `__init__` as a attribute named `config` of the prtrained model instance. """ init_dict = fn_args_to_dict(original_init, *args, **kwargs) self.config = init_dict class ErnieModel(nn.Layer, ErnieGenPretrainedModel): def __init__(self, cfg, name=None): """ Fundamental pretrained Ernie model """ logger.debug('init ErnieModel with config: %s' % repr(cfg)) nn.Layer.__init__(self) d_model = cfg['hidden_size'] d_emb = cfg.get('emb_size', cfg['hidden_size']) d_vocab = cfg['vocab_size'] d_pos = cfg['max_position_embeddings'] d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size'] self.n_head = cfg['num_attention_heads'] self.return_additional_info = cfg.get('return_additional_info', False) initializer = nn.initializer.TruncatedNormal( std=cfg['initializer_range']) self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder')) self.word_emb = nn.Embedding( d_vocab, d_emb, weight_attr=paddle.ParamAttr( name=append_name(name, 'word_embedding'), initializer=initializer)) self.pos_emb = nn.Embedding( d_pos, d_emb, weight_attr=paddle.ParamAttr( name=append_name(name, 'pos_embedding'), initializer=initializer)) self.sent_emb = nn.Embedding( d_sent, d_emb, weight_attr=paddle.ParamAttr( name=append_name(name, 'sent_embedding'), initializer=initializer)) prob = cfg['hidden_dropout_prob'] self.dropout = nn.Dropout(p=prob) self.encoder_stack = ErnieEncoderStack(cfg, append_name(name, 'encoder')) def forward(self, src_ids, sent_ids=None, pos_ids=None, input_mask=None, attn_bias=None, past_cache=None, use_causal_mask=False): """ Args: src_ids (`Variable` of shape `[batch_size, seq_len]`): Indices of input sequence tokens in the vocabulary. sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`): aka token_type_ids, Segment token indices to indicate first and second portions of the inputs. if None, assume all tokens come from `segment_a` pos_ids(optional, `Variable` of shape `[batch_size, seq_len]`): Indices of positions of each input sequence tokens in the position embeddings. input_mask(optional `Variable` of shape `[batch_size, seq_len]`): Mask to avoid performing attention on the padding token indices of the encoder input. attn_bias(optional, `Variable` of shape `[batch_size, seq_len, seq_len] or False`): 3D version of `input_mask`, if set, overrides `input_mask`; if set not False, will not apply attention mask past_cache(optional, tuple of two lists: cached key and cached value, each is a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`): cached key/value tensor that will be concated to generated key/value when performing self attention. if set, `attn_bias` should not be None. Returns: pooled (`Variable` of shape `[batch_size, hidden_size]`): output logits of pooler classifier encoded(`Variable` of shape `[batch_size, seq_len, hidden_size]`): output logits of transformer stack info (Dictionary): addtional middle level info, inclues: all hidden stats, k/v caches. """ assert len( src_ids. shape) == 2, 'expect src_ids.shape = [batch, sequecen], got %s' % ( repr(src_ids.shape)) assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified; attn_bias should not be None' d_seqlen = paddle.shape(src_ids)[1] if pos_ids is None: pos_ids = paddle.arange( 0, d_seqlen, 1, dtype='int32').reshape([1, -1]).cast('int64') if attn_bias is None: if input_mask is None: input_mask = paddle.cast(src_ids != 0, 'float32') assert len(input_mask.shape) == 2 input_mask = input_mask.unsqueeze(-1) attn_bias = input_mask.matmul(input_mask, transpose_y=True) if use_causal_mask: sequence = paddle.reshape( paddle.arange( 0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1]) causal_mask = (sequence.matmul( 1. / sequence, transpose_y=True) >= 1.).cast('float32') attn_bias *= causal_mask else: assert len( attn_bias.shape ) == 3, 'expect attn_bias tobe rank 3, got %r' % attn_bias.shape attn_bias = (1. - attn_bias) * -10000.0 attn_bias = attn_bias.unsqueeze(1).tile( [1, self.n_head, 1, 1]) # avoid broadcast =_= if sent_ids is None: sent_ids = paddle.zeros_like(src_ids) src_embedded = self.word_emb(src_ids) pos_embedded = self.pos_emb(pos_ids) sent_embedded = self.sent_emb(sent_ids) embedded = src_embedded + pos_embedded + sent_embedded embedded = self.dropout(self.ln(embedded)) encoded, hidden_list, cache_list = self.encoder_stack( embedded, attn_bias, past_cache=past_cache) additional_info = { 'hiddens': hidden_list, 'caches': cache_list, } return encoded, additional_info class ErnieForGeneration(ErnieModel): """ Ernie Model for sequence to sequence generation. """ def __init__(self, cfg, name=None): super(ErnieForGeneration, self).__init__(cfg, name=name) initializer = nn.initializer.TruncatedNormal( std=cfg['initializer_range']) d_model = cfg['hidden_size'] d_vocab = cfg['vocab_size'] self.mlm = _build_linear( d_model, d_model, append_name(name, 'mask_lm_trans_fc'), initializer, ) self.act = getattr(paddle.nn.functional, cfg['hidden_act']) self.mlm_ln = _build_ln( d_model, name=append_name(name, 'mask_lm_trans')) self.mlm_bias = paddle.create_parameter( dtype='float32', shape=[d_vocab], attr=paddle.ParamAttr( name=append_name(name, 'mask_lm_out_fc.b_0'), initializer=nn.initializer.Constant(value=0.0)), is_bias=True, ) def forward(self, *args, **kwargs): """ Args tgt_labels(`Variable` of shape [batch_size, seqlen] or [batch, seqlen, vocab_size]): ground trouth target sequence id (hard label) or distribution (soft label) tgt_pos(`Variable` of shape [n_targets, 2]): index of tgt_labels in `src_ids`, can be obtained from `fluid.layers.where(src_ids==mask_id)` encoder_only(Bool): if set, will not return loss, logits_2d Returns: loss(`Variable` of shape []): cross entropy loss mean over every target label. if `encode_only`, returns None. logits(`Variable` of shape [n_targets, vocab_size]): logits for every targets. if `encode_only`, returns None. info(Dictionary): see `ErnieModel` """ tgt_labels = kwargs.pop('tgt_labels', None) tgt_pos = kwargs.pop('tgt_pos', None) encode_only = kwargs.pop('encode_only', False) encoded, info = ErnieModel.forward(self, *args, **kwargs) if encode_only: return None, None, info if tgt_labels is None or tgt_pos is None: encoded = self.act(self.mlm(encoded)) encoded = self.mlm_ln(encoded) logits = encoded.matmul( self.word_emb.weight, transpose_y=True) + self.mlm_bias output_ids = logits.argmax(-1) return output_ids, logits, info else: encoded_2d = encoded.gather_nd(tgt_pos) encoded_2d = self.act(self.mlm(encoded_2d)) encoded_2d = self.mlm_ln(encoded_2d) logits_2d = encoded_2d.matmul( self.word_emb.weight, transpose_y=True) + self.mlm_bias if len(tgt_labels.shape) == 1: tgt_labels = paddle.reshape(tgt_labels, [-1, 1]) loss = paddle.nn.functional.cross_entropy( logits_2d, tgt_labels, soft_label=(tgt_labels.shape[-1] != 1)) return loss, logits_2d, info