modeling_ernie.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import json
import logging
import six
if six.PY2:
    from pathlib2 import Path
else:
    from pathlib import Path

import paddle as P
from paddle import nn
from paddle.nn import functional as F
from ernie.file_utils import _fetch_from_remote, add_docstring

log = logging.getLogger(__name__)

ACT_DICT = {
    'relu': nn.ReLU,
    'gelu': nn.GELU,
}


def _build_linear(n_in, n_out, name, init):
    return nn.Linear(
        n_in,
        n_out,
        weight_attr=P.ParamAttr(
            name='%s.w_0' % name if name is not None else None,
            initializer=init),
        bias_attr='%s.b_0' % name if name is not None else None, )


def _build_ln(n_in, name):
    return nn.LayerNorm(
        normalized_shape=n_in,
        weight_attr=P.ParamAttr(
            name='%s_layer_norm_scale' % name if name is not None else None,
            initializer=nn.initializer.Constant(1.)),
        bias_attr=P.ParamAttr(
            name='%s_layer_norm_bias' % name if name is not None else None,
            initializer=nn.initializer.Constant(0.)), )


def append_name(name, postfix):
    if name is None:
        ret = None
    elif name == '':
        ret = postfix
    else:
        ret = '%s_%s' % (name, postfix)
    return ret


class AttentionLayer(nn.Layer):
    def __init__(self, cfg, name=None):
        super(AttentionLayer, self).__init__()
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        d_model = cfg['hidden_size']
        n_head = cfg['num_attention_heads']
        assert d_model % n_head == 0
        d_model_q = cfg.get('query_hidden_size_per_head',
                            d_model // n_head) * n_head
        d_model_v = cfg.get('value_hidden_size_per_head',
                            d_model // n_head) * n_head
        self.n_head = n_head
        self.d_key = d_model_q // n_head
        self.q = _build_linear(d_model, d_model_q,
                               append_name(name, 'query_fc'), initializer)
        self.k = _build_linear(d_model, d_model_q,
                               append_name(name, 'key_fc'), initializer)
        self.v = _build_linear(d_model, d_model_v,
                               append_name(name, 'value_fc'), initializer)
        self.o = _build_linear(d_model_v, d_model,
                               append_name(name, 'output_fc'), initializer)
        self.dropout = nn.Dropout(p=cfg['attention_probs_dropout_prob'])

    def forward(self, queries, keys, values, attn_bias, past_cache):
        assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3
        #bsz, q_len, q_dim = queries.shape
        #bsz, k_len, k_dim = keys.shape
        #bsz, v_len, v_dim = values.shape
        #assert k_len == v_len

        q = self.q(queries)
        k = self.k(keys)
        v = self.v(values)

        cache = (k, v)
        if past_cache is not None:
            cached_k, cached_v = past_cache
            k = P.concat([cached_k, k], 1)
            v = P.concat([cached_v, v], 1)

        q = q.reshape(
            [0, 0, self.n_head, q.shape[-1] // self.n_head]).transpose(
                [0, 2, 1, 3])  #[batch, head, seq, dim]
        k = k.reshape(
            [0, 0, self.n_head, k.shape[-1] // self.n_head]).transpose(
                [0, 2, 1, 3])  #[batch, head, seq, dim]
        v = v.reshape(
            [0, 0, self.n_head, v.shape[-1] // self.n_head]).transpose(
                [0, 2, 1, 3])  #[batch, head, seq, dim]

        q = q.scale(self.d_key**-0.5)
        score = q.matmul(k, transpose_y=True)
        if attn_bias is not None:
            score += attn_bias
        score = F.softmax(score)
        score = self.dropout(score)

        out = score.matmul(v).transpose([0, 2, 1, 3])
        out = out.reshape([0, 0, out.shape[2] * out.shape[3]])
        out = self.o(out)
        return out, cache


class PositionwiseFeedForwardLayer(nn.Layer):
    def __init__(self, cfg, name=None):
        super(PositionwiseFeedForwardLayer, self).__init__()
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        d_model = cfg['hidden_size']
        d_ffn = cfg.get('intermediate_size', 4 * d_model)
        self.act = ACT_DICT[cfg['hidden_act']]()
        self.i = _build_linear(
            d_model,
            d_ffn,
            append_name(name, 'fc_0'),
            initializer, )
        self.o = _build_linear(d_ffn, d_model,
                               append_name(name, 'fc_1'), initializer)
        prob = cfg.get('intermediate_dropout_prob', 0.)
        self.dropout = nn.Dropout(p=prob)

    def forward(self, inputs):
        hidden = self.act(self.i(inputs))
        hidden = self.dropout(hidden)
        out = self.o(hidden)
        return out


class ErnieBlock(nn.Layer):
    def __init__(self, cfg, name=None):
        super(ErnieBlock, self).__init__()
        d_model = cfg['hidden_size']
        self.attn = AttentionLayer(
            cfg, name=append_name(name, 'multi_head_att'))
        self.ln1 = _build_ln(d_model, name=append_name(name, 'post_att'))
        self.ffn = PositionwiseFeedForwardLayer(
            cfg, name=append_name(name, 'ffn'))
        self.ln2 = _build_ln(d_model, name=append_name(name, 'post_ffn'))
        prob = cfg.get('intermediate_dropout_prob', cfg['hidden_dropout_prob'])
        self.dropout = nn.Dropout(p=prob)

    def forward(self, inputs, attn_bias=None, past_cache=None):
        attn_out, cache = self.attn(
            inputs, inputs, inputs, attn_bias,
            past_cache=past_cache)  #self attn
        attn_out = self.dropout(attn_out)
        hidden = attn_out + inputs
        hidden = self.ln1(hidden)  # dropout/ add/ norm

        ffn_out = self.ffn(hidden)
        ffn_out = self.dropout(ffn_out)
        hidden = ffn_out + hidden
        hidden = self.ln2(hidden)
        return hidden, cache


class ErnieEncoderStack(nn.Layer):
    def __init__(self, cfg, name=None):
        super(ErnieEncoderStack, self).__init__()
        n_layers = cfg['num_hidden_layers']
        self.block = nn.LayerList([
            ErnieBlock(cfg, append_name(name, 'layer_%d' % i))
            for i in range(n_layers)
        ])

    def forward(self, inputs, attn_bias=None, past_cache=None):
        if past_cache is not None:
            assert isinstance(
                past_cache, tuple
            ), 'unknown type of `past_cache`, expect tuple or list. got %s' % repr(
                type(past_cache))
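            # `past_cache` is a pair (list of cached keys, list of cached values);
            # zip(*past_cache) regroups it into one (cached_k, cached_v) tuple per layer.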
            past_cache = list(zip(*past_cache))
        else:
            past_cache = [None] * len(self.block)
        cache_list_k, cache_list_v, hidden_list = [], [], [inputs]

        for b, p in zip(self.block, past_cache):
            inputs, cache = b(inputs, attn_bias=attn_bias, past_cache=p)
            cache_k, cache_v = cache
            cache_list_k.append(cache_k)
            cache_list_v.append(cache_v)
            hidden_list.append(inputs)

        return inputs, hidden_list, (cache_list_k, cache_list_v)


class PretrainedModel(object):
    bce = 'https://ernie-github.cdn.bcebos.com/'
    resource_map = {
        'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz',
        'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz',
        'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
        'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
    }

    @classmethod
    def from_pretrained(cls,
                        pretrain_dir_or_url,
                        force_download=False,
                        **kwargs):
        if not Path(pretrain_dir_or_url).exists() and pretrain_dir_or_url in cls.resource_map:
            url = cls.resource_map[pretrain_dir_or_url]
            log.info('get pretrain dir from %s' % url)
            pretrain_dir = _fetch_from_remote(url, force_download)
        else:
            log.info('pretrain dir %s not in %s, read from local' % (pretrain_dir_or_url, repr(cls.resource_map)))
            pretrain_dir = Path(pretrain_dir_or_url)

        if not pretrain_dir.exists():
            raise ValueError('pretrain dir not found: %s' % pretrain_dir)
        state_dict_path = pretrain_dir / 'saved_weights.pdparams'
        config_path = pretrain_dir / 'ernie_config.json'

        if not config_path.exists():
            raise ValueError('config path not found: %s' % config_path)
        name_prefix = kwargs.pop('name', None)
        cfg_dict = dict(json.loads(config_path.open().read()), **kwargs)
        model = cls(cfg_dict, name=name_prefix)

        log.info('loading pretrained model from %s' % pretrain_dir)

        #param_path = pretrain_dir / 'params'
        #if os.path.exists(param_path):
        #    raise NotImplementedError()
        #    log.debug('load pretrained weight from program state')
        #    F.io.load_program_state(param_path) #buggy in dygraph.guard, push paddle to fix
        if state_dict_path.exists():
            m = P.load(str(state_dict_path))
            for k, v in model.state_dict().items():
                if k not in m:
                    log.warn('param:%s not set in pretrained model, skip' % k)
                    m[k] = v  # FIXME: no need to do this in the future
            model.set_state_dict(m)
        else:
            raise ValueError('weight file not found in pretrain dir: %s' %
                             pretrain_dir)
        return model
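
# Usage sketch (not part of the original library; assumes this module is importable
# as `ernie.modeling_ernie`): loading pretrained weights with `from_pretrained`.
# The argument must be a key of `resource_map` or a local directory containing
# `saved_weights.pdparams` and `ernie_config.json`; every model class below
# inherits this classmethod.
#
#   from ernie.modeling_ernie import ErnieModel
#   model = ErnieModel.from_pretrained('ernie-1.0')
#   model.eval()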


class ErnieModel(nn.Layer, PretrainedModel):
    def __init__(self, cfg, name=None):
        """
        Fundamental pretrained Ernie model
        """
        log.debug('init ErnieModel with config: %s' % repr(cfg))
        nn.Layer.__init__(self)
        d_model = cfg['hidden_size']
        d_emb = cfg.get('emb_size', cfg['hidden_size'])
        d_vocab = cfg['vocab_size']
        d_pos = cfg['max_position_embeddings']
        d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size']
        self.n_head = cfg['num_attention_heads']
        self.return_additional_info = cfg.get('return_additional_info', False)
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])

        self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder'))
        self.word_emb = nn.Embedding(
            d_vocab,
            d_emb,
            weight_attr=P.ParamAttr(
                name=append_name(name, 'word_embedding'),
                initializer=initializer))
        self.pos_emb = nn.Embedding(
            d_pos,
            d_emb,
            weight_attr=P.ParamAttr(
                name=append_name(name, 'pos_embedding'),
                initializer=initializer))
        self.sent_emb = nn.Embedding(
            d_sent,
            d_emb,
            weight_attr=P.ParamAttr(
                name=append_name(name, 'sent_embedding'),
                initializer=initializer))
        prob = cfg['hidden_dropout_prob']
        self.dropout = nn.Dropout(p=prob)

        self.encoder_stack = ErnieEncoderStack(cfg,
                                               append_name(name, 'encoder'))
        if cfg.get('has_pooler', True):
            self.pooler = _build_linear(
                cfg['hidden_size'],
                cfg['hidden_size'],
                append_name(name, 'pooled_fc'),
                initializer, )
        else:
            self.pooler = None
        self.train()

    #FIXME:remove this
    def eval(self):
        if P.in_dynamic_mode():
            super(ErnieModel, self).eval()
        self.training = False
        for l in self.sublayers():
            l.training = False
        return self

    def train(self):
        if P.in_dynamic_mode():
            super(ErnieModel, self).train()
        self.training = True
        for l in self.sublayers():
            l.training = True
        return self

    def forward(self,
                src_ids,
                sent_ids=None,
                pos_ids=None,
                input_mask=None,
                attn_bias=None,
                past_cache=None,
                use_causal_mask=False):
        """
        Args:
            src_ids (`Variable` of shape `[batch_size, seq_len]`):
                Indices of input sequence tokens in the vocabulary.
            sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`):
                aka token_type_ids, segment token indices to indicate first and second portions of the inputs.
                if None, assume all tokens come from `segment_a`
            pos_ids (optional, `Variable` of shape `[batch_size, seq_len]`):
                Indices of positions of each input sequence token in the position embeddings.
            input_mask (optional, `Variable` of shape `[batch_size, seq_len]`):
                Mask to avoid performing attention on the padding token indices of the encoder input.
            attn_bias (optional, `Variable` of shape `[batch_size, seq_len, seq_len]` or False):
                3D version of `input_mask`; if set, overrides `input_mask` and no padding-based attention mask is built.
            past_cache (optional, tuple of two lists: cached key and cached value,
                each a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`):
                cached key/value tensors that are concatenated to the generated key/value when performing self attention.
                if set, `attn_bias` should not be None.

        Returns:
            pooled (`Variable` of shape `[batch_size, hidden_size]`):
                output logits of pooler classifier
            encoded (`Variable` of shape `[batch_size, seq_len, hidden_size]`):
                output logits of transformer stack
            info (Dictionary):
                additional middle-level info, includes: all hidden states, k/v caches.
        """
        assert len(src_ids.shape) == 2, \
            'expect src_ids.shape = [batch, sequence], got %s' % repr(src_ids.shape)
        assert attn_bias is not None if past_cache else True, \
            'if `past_cache` is specified, attn_bias should not be None'
        d_seqlen = P.shape(src_ids)[1]
        if pos_ids is None:
            pos_ids = P.arange(
                0, d_seqlen, 1, dtype='int32').reshape([1, -1]).cast('int64')
        if attn_bias is None:
            if input_mask is None:
                input_mask = P.cast(src_ids != 0, 'float32')
            assert len(input_mask.shape) == 2
            input_mask = input_mask.unsqueeze(-1)
            attn_bias = input_mask.matmul(input_mask, transpose_y=True)
            if use_causal_mask:
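                # lower-triangular causal mask: sequence[i] = i + 1, so
                # sequence.matmul(1. / sequence)[i][j] = (i + 1) / (j + 1), which is
                # >= 1 exactly when j <= i, i.e. each token may attend only to
                # itself and to earlier positions.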
                sequence = P.reshape(
                    P.arange(
                        0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1])
                causal_mask = (sequence.matmul(
                    1. / sequence, transpose_y=True) >= 1.).cast('float32')
                attn_bias *= causal_mask
        else:
            assert len(
                attn_bias.shape
            ) == 3, 'expect attn_bias to be rank 3, got %r' % attn_bias.shape
        attn_bias = (1. - attn_bias) * -10000.0
        attn_bias = attn_bias.unsqueeze(1).tile(
            [1, self.n_head, 1, 1])  # avoid broadcast =_=

        if sent_ids is None:
            sent_ids = P.zeros_like(src_ids)

        src_embedded = self.word_emb(src_ids)
        pos_embedded = self.pos_emb(pos_ids)
        sent_embedded = self.sent_emb(sent_ids)
        embedded = src_embedded + pos_embedded + sent_embedded

        embedded = self.dropout(self.ln(embedded))

        encoded, hidden_list, cache_list = self.encoder_stack(
            embedded, attn_bias, past_cache=past_cache)
        if self.pooler is not None:
            pooled = F.tanh(self.pooler(encoded[:, 0, :]))
        else:
            pooled = None

        additional_info = {
            'hiddens': hidden_list,
            'caches': cache_list,
        }

        if self.return_additional_info:
            return pooled, encoded, additional_info
        return pooled, encoded
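
# Usage sketch (not part of the original library; the tokenizer import path and
# model name are assumptions): a single forward pass through the bare encoder.
#
#   import numpy as np
#   import paddle as P
#   from ernie.tokenizing_ernie import ErnieTokenizer
#
#   tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
#   model = ErnieModel.from_pretrained('ernie-1.0')
#   model.eval()
#   ids, _ = tokenizer.encode('hello paddle')
#   src_ids = P.to_tensor(np.expand_dims(ids, 0))
#   pooled, encoded = model(src_ids)  # [1, hidden_size], [1, seq_len, hidden_size]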


class ErnieModelForSequenceClassification(ErnieModel):
    """
    Ernie Model for text classification or pointwise ranking tasks
    """

    def __init__(self, cfg, name=None):
        super(ErnieModelForSequenceClassification, self).__init__(
            cfg, name=name)

        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        self.classifier = _build_linear(cfg['hidden_size'], cfg['num_labels'],
                                        append_name(name, 'cls'), initializer)

        prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob'])
        self.dropout = nn.Dropout(p=prob)
        self.train()

    @add_docstring(ErnieModel.forward.__doc__)
    def forward(self, *args, **kwargs):
        """
        Args:
M
Meiyim 已提交
453
            labels (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
454 455 456 457 458 459 460 461 462
                ground truth label id for each sentence
        Returns:
            loss (`Variable` of shape []):
                Cross entropy loss mean over batch
                if labels not set, returns None
            logits (`Variable` of shape [batch_size, hidden_size]):
                output logits of classifier
        """
        labels = kwargs.pop('labels', None)
        pooled, encoded = super(ErnieModelForSequenceClassification,
                                self).forward(*args, **kwargs)
        hidden = self.dropout(pooled)
        logits = self.classifier(hidden)

        if labels is not None:
            if len(labels.shape) != 1:
                labels = labels.squeeze()
            loss = F.cross_entropy(logits, labels)
        else:
            loss = None
        return loss, logits
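
# Usage sketch (model name and label values are assumptions; `num_labels` is merged
# into the config by `from_pretrained`):
#
#   model = ErnieModelForSequenceClassification.from_pretrained(
#       'ernie-1.0', num_labels=2)
#   loss, logits = model(src_ids, labels=P.to_tensor([1]))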


class ErnieModelForTokenClassification(ErnieModel):
    """
    Ernie Model for named entity recognition (NER) tasks
    """

    def __init__(self, cfg, name=None):
        super(ErnieModelForTokenClassification, self).__init__(cfg, name=name)

        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        self.classifier = _build_linear(cfg['hidden_size'], cfg['num_labels'],
                                        append_name(name, 'cls'), initializer)

        prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob'])
        self.dropout = nn.Dropout(p=prob)
        self.train()

    @add_docstring(ErnieModel.forward.__doc__)
    def forward(self, *args, **kwargs):
        """
        Args:
M
Meiyim 已提交
498
            labels (optional, `Variable` of shape [batch_size, seq_len]):
M
Meiyim 已提交
499 500 501 502 503 504 505
                ground truth label id for each token
        Returns:
            loss (`Variable` of shape []):
                Cross entropy loss mean over batch and time, ignore positions where label == -100
                if labels not set, returns None
            logits (`Variable` of shape [batch_size, seq_len, hidden_size]):
                output logits of classifier
506 507 508 509
            loss_weights (`Variable` of shape [batch_size, seq_len]):
                weigths of loss for each tokens.
            ignore_index (int):
                when label == `ignore_index`, this token will not contribute to loss
M
Meiyim 已提交
510
        """
511 512 513
        ignore_index = kwargs.pop('ignore_index', -100)
        labels = kwargs.pop('labels', None)
        loss_weights = kwargs.pop('loss_weights', None)
        pooled, encoded = super(ErnieModelForTokenClassification,
                                self).forward(*args, **kwargs)
        hidden = self.dropout(encoded)  # maybe not?
        logits = self.classifier(hidden)

        if labels is not None:
            if len(labels.shape) != 2:
                labels = labels.squeeze()
            loss = F.cross_entropy(
                logits, labels, ignore_index=ignore_index, reduction='none')
            if loss_weights is not None:
                loss = loss * loss_weights
            loss = loss.mean()
        else:
            loss = None
        return loss, logits
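
# Usage sketch (model name is an assumption; `token_labels` has shape
# [batch_size, seq_len] with `-100` marking tokens excluded from the loss):
#
#   model = ErnieModelForTokenClassification.from_pretrained(
#       'ernie-1.0', num_labels=7)
#   loss, logits = model(src_ids, labels=token_labels, ignore_index=-100)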


class ErnieModelForQuestionAnswering(ErnieModel):
    """
    Ernie model for reading comprehension tasks (SQuAD)
    """

    def __init__(self, cfg, name=None):
        super(ErnieModelForQuestionAnswering, self).__init__(cfg, name=name)

        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        self.classifier = _build_linear(cfg['hidden_size'], 2,
                                        append_name(name, 'cls_mrc'),
                                        initializer)

        prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob'])
        self.dropout = nn.Dropout(p=prob)
        self.train()

    @add_docstring(ErnieModel.forward.__doc__)
    def forward(self, *args, **kwargs):
        """
        Args:
            start_pos (optional, `Variable` of shape [batch_size]):
                token index of start of answer span in `context`
            end_pos (optional, `Variable` of shape [batch_size]):
                token index of end of answer span in `context`
        Returns:
            loss (`Variable` of shape []):
                Cross entropy loss averaged over start and end positions, mean over batch
                if start_pos/end_pos not set, returns None
            start_logits (`Variable` of shape [batch_size, seq_len]):
                output logits of start position, use argmax(start_logits) to get start index
            end_logits (`Variable` of shape [batch_size, seq_len]):
                output logits of end position, use argmax(end_logits) to get end index
        """

        start_pos = kwargs.pop('start_pos', None)
        end_pos = kwargs.pop('end_pos', None)
        pooled, encoded = super(ErnieModelForQuestionAnswering, self).forward(
            *args, **kwargs)
        encoded = self.dropout(encoded)
        encoded = self.classifier(encoded)
        start_logit, end_logits = P.unstack(encoded, axis=-1)
        if start_pos is not None and end_pos is not None:
            if len(start_pos.shape) != 1:
                start_pos = start_pos.squeeze()
            if len(end_pos.shape) != 1:
                end_pos = end_pos.squeeze()
            start_loss = F.cross_entropy(start_logit, start_pos)
            end_loss = F.cross_entropy(end_logits, end_pos)
            loss = (start_loss.mean() + end_loss.mean()) / 2.
        else:
            loss = None
        return loss, start_logit, end_logits
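
# Usage sketch (model name is an assumption; `start_pos`/`end_pos` have shape [batch_size]):
#
#   model = ErnieModelForQuestionAnswering.from_pretrained('ernie-1.0')
#   loss, start_logits, end_logits = model(
#       src_ids, start_pos=start_pos, end_pos=end_pos)
#   # at inference time, omit start_pos/end_pos and decode the span with
#   # start_logits.argmax(-1) and end_logits.argmax(-1)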


class NSPHead(nn.Layer):
    def __init__(self, cfg, name=None):
        super(NSPHead, self).__init__()
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        self.nsp = _build_linear(cfg['hidden_size'], 2,
                                 append_name(name, 'nsp_fc'), initializer)

    def forward(self, inputs, labels):
        """
        Args:
M
Meiyim 已提交
599
            start_pos (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
600
                token index of start of answer span in `context`
M
Meiyim 已提交
601
            end_pos (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
602 603 604 605 606 607 608 609 610 611 612 613
                token index of end of answer span in `context`
        Returns:
            loss (`Variable` of shape []):
                Cross entropy loss mean over batch and time, ignore positions where label == -100
                if labels not set, returns None
            start_logits (`Variable` of shape [batch_size, hidden_size]):
                output logits of start position
            end_logits (`Variable` of shape [batch_size, hidden_size]):
                output logits of end position
        """

        logits = self.nsp(inputs)
        loss = F.cross_entropy(logits, labels)
        return loss


class ErnieModelForPretraining(ErnieModel):
    """
    Ernie Model for Masked Language Model pretraining
    """

    def __init__(self, cfg, name=None):
        super(ErnieModelForPretraining, self).__init__(cfg, name=name)
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        d_model = cfg['hidden_size']
        d_vocab = cfg['vocab_size']

        self.pooler_heads = nn.LayerList([NSPHead(cfg, name=name)])
        self.mlm = _build_linear(
            d_model,
            d_model,
            append_name(name, 'mask_lm_trans_fc'),
            initializer, )
        self.act = ACT_DICT[cfg['hidden_act']]()
        self.mlm_ln = _build_ln(
            d_model, name=append_name(name, 'mask_lm_trans'))
        self.mlm_bias = P.create_parameter(
            dtype='float32',
            shape=[d_vocab],
            attr=P.ParamAttr(
                name=append_name(name, 'mask_lm_out_fc.b_0'),
                initializer=nn.initializer.Constant(value=0.0)),
            is_bias=True, )
        self.train()

    @add_docstring(ErnieModel.forward.__doc__)
    def forward(self, *args, **kwargs):
        """
        Args:
            nsp_labels (optional, `Variable` of shape [batch_size]):
                labels for `next sentence prediction` tasks
            mlm_pos (optional, `Variable` of shape [n_mask, 2]):
                index of mask_id in `src_ids`, can be obtained from `fluid.layers.where(src_ids==mask_id)`
            labels (optional, `Variable` of shape [n_mask]):
                labels for `mask language model` tasks, the original token indices in masked position in `src_ids`
        Returns:
            loss (`Variable` of shape []):
                total_loss of `next sentence prediction` and `masked language model`
            mlm_loss (`Variable` of shape []):
                loss for `masked language model` task
            nsp_loss (`Variable` of shape []):
                loss for `next sentence prediction` task
        """

        mlm_labels = kwargs.pop('labels')
        mlm_pos = kwargs.pop('mlm_pos')
        nsp_labels = kwargs.pop('nsp_labels')
        pooled, encoded = super(ErnieModelForPretraining, self).forward(
            *args, **kwargs)
        if len(mlm_labels.shape) != 1:
            mlm_labels = mlm_labels.squeeze()
        if len(nsp_labels.shape) == 1:
            nsp_labels = nsp_labels.squeeze()

        nsp_loss = self.pooler_heads[0](pooled, nsp_labels)

        encoded_2d = encoded.gather_nd(mlm_pos)
        encoded_2d = self.act(self.mlm(encoded_2d))
        encoded_2d = self.mlm_ln(encoded_2d)
        logits_2d = encoded_2d.matmul(
            self.word_emb.weight, transpose_y=True) + self.mlm_bias
        mlm_loss = F.cross_entropy(logits_2d, mlm_labels)
        total_loss = mlm_loss + nsp_loss
        return total_loss, mlm_loss, nsp_loss


class ErnieModelForGeneration(ErnieModel):
    """
    Ernie Model for sequence to sequence generation.
    """
    resource_map = {
        'ernie-gen-base-en':
        ErnieModel.bce + 'model-ernie-gen-base-en.1.tar.gz',
        'ernie-gen-large-en':
        ErnieModel.bce + 'model-ernie-gen-large-en.1.tar.gz',
        'ernie-gen-large-430g-en':
        ErnieModel.bce + 'model-ernie-gen-large-430g-en.1.tar.gz',
        'ernie-1.0': ErnieModel.bce + 'model-ernie1.0.1.tar.gz',
    }

    def __init__(self, cfg, name=None):
        cfg['return_additional_info'] = True
        cfg['has_pooler'] = False
        super(ErnieModelForGeneration, self).__init__(cfg, name=name)
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        d_model = cfg['hidden_size']
        d_vocab = cfg['vocab_size']

        self.mlm = _build_linear(
            d_model,
            d_model,
            append_name(name, 'mask_lm_trans_fc'),
            initializer, )
        self.act = ACT_DICT[cfg['hidden_act']]()
        self.mlm_ln = _build_ln(
            d_model, name=append_name(name, 'mask_lm_trans'))
        self.mlm_bias = P.create_parameter(
            dtype='float32',
            shape=[d_vocab],
            attr=P.ParamAttr(
                name=append_name(name, 'mask_lm_out_fc.b_0'),
                initializer=nn.initializer.Constant(value=0.0)),
            is_bias=True, )
        self.train()

    @add_docstring(ErnieModel.forward.__doc__)
    def forward(self, *args, **kwargs):
        """
        Args
            tgt_labels(`Variable` of shape [batch_size, seqlen] or [batch, seqlen, vocab_size]):
                ground trouth target sequence id (hard label) or distribution (soft label)
            tgt_pos(`Variable` of shape [n_targets, 2]):
                index of tgt_labels in `src_ids`, can be obtained from `fluid.layers.where(src_ids==mask_id)`
            encoder_only(Bool):
                if set, will not return loss, logits_2d
        Returns:
            loss(`Variable` of shape []):
                cross entropy loss mean over every target label. if `encode_only`, returns None.
            logits(`Variable` of shape [n_targets, vocab_size]):
                logits for every targets. if `encode_only`, returns None.
            info(Dictionary): see `ErnieModel`
        """
        tgt_labels = kwargs.pop('tgt_labels', None)
        tgt_pos = kwargs.pop('tgt_pos', None)
        encode_only = kwargs.pop('encode_only', False)
        _, encoded, info = ErnieModel.forward(self, *args, **kwargs)
        if encode_only:
            return None, None, info
        if tgt_labels is None or tgt_pos is None:
            encoded = self.act(self.mlm(encoded))
            encoded = self.mlm_ln(encoded)
            logits = encoded.matmul(
                self.word_emb.weight, transpose_y=True) + self.mlm_bias
            output_ids = logits.cast('float32').argmax(-1)
            return output_ids, logits, info
        else:
            encoded_2d = encoded.gather_nd(tgt_pos)
            encoded_2d = self.act(self.mlm(encoded_2d))
            encoded_2d = self.mlm_ln(encoded_2d)
            logits_2d = encoded_2d.matmul(
                self.word_emb.weight, transpose_y=True) + self.mlm_bias
            assert len(
                tgt_labels.shape) == 2, 'expect 2d label, got %r' % tgt_labels.shape

            loss = F.cross_entropy(logits_2d, tgt_labels, soft_label=True)
            return loss, logits_2d, info
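
# Usage sketch (model name is an assumption): without `tgt_labels`/`tgt_pos` the
# forward pass returns greedy (argmax) token ids for every position instead of a loss.
#
#   model = ErnieModelForGeneration.from_pretrained('ernie-gen-base-en')
#   output_ids, logits, info = model(src_ids)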