modeling_ernie.py 32.1 KB
Newer Older
M
Meiyim 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import json
import logging
Z
zhanghan17 已提交
22
import math
M
Meiyim 已提交
23 24 25 26 27
import six
if six.PY2:
    from pathlib2 import Path
else:
    from pathlib import Path
Z
zhanghan17 已提交
28
import numpy as np
C
chenxuyi 已提交
29 30 31
import paddle as P
from paddle import nn
from paddle.nn import functional as F
M
Meiyim 已提交
32
from ernie.file_utils import _fetch_from_remote, add_docstring
M
Meiyim 已提交
33 34 35

log = logging.getLogger(__name__)

C
chenxuyi 已提交
36 37 38 39
ACT_DICT = {
    'relu': nn.ReLU,
    'gelu': nn.GELU,
}
Z
zhanghan17 已提交
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
def _get_rel_pos_bias(seq_len, max_len=128, num_buckets=32, bidirectional=True, reset=True):
    #max_len = 520
    pos = np.array(range(seq_len))
    rel_pos = pos[:, None] - pos[None, :]
    ret = 0
    n = -rel_pos
    if bidirectional:
        num_buckets //= 2
        ret += (n < 0).astype('int32') * num_buckets  # mtf.to_int32(mtf.less(n, 0)) * num_buckets
        n = np.abs(n)
    else:
        n = np.max(n, np.zeros_like(n))
    # now n is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = n < max_exact
    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    val_if_large = max_exact + (np.log(n.astype('float32') / max_exact) / math.log(max_len / max_exact) * (num_buckets - max_exact)).astype('int32')
    tmp = np.full_like(val_if_large, num_buckets-1)
    val_if_large = np.where(val_if_large < tmp, val_if_large, tmp)

    ret += np.where(is_small, n, val_if_large)
    if reset:
        num_buckets *= 2
        ret[:, 0] = num_buckets
        ret[0, :] = num_buckets // 2
C
chenxuyi 已提交
67

Z
zhanghan17 已提交
68
    return np.array(ret).reshape([seq_len, seq_len]).astype("int64")
C
chenxuyi 已提交
69 70 71 72 73 74 75 76 77

def _build_linear(n_in, n_out, name, init):
    return nn.Linear(
        n_in,
        n_out,
        weight_attr=P.ParamAttr(
            name='%s.w_0' % name if name is not None else None,
            initializer=init),
        bias_attr='%s.b_0' % name if name is not None else None, )
M
Meiyim 已提交
78 79 80


def _build_ln(n_in, name):
C
chenxuyi 已提交
81 82 83 84 85 86 87 88
    return nn.LayerNorm(
        normalized_shape=n_in,
        weight_attr=P.ParamAttr(
            name='%s_layer_norm_scale' % name if name is not None else None,
            initializer=nn.initializer.Constant(1.)),
        bias_attr=P.ParamAttr(
            name='%s_layer_norm_bias' % name if name is not None else None,
            initializer=nn.initializer.Constant(0.)), )
M
Meiyim 已提交
89 90 91 92


def append_name(name, postfix):
    if name is None:
C
chenxuyi 已提交
93
        ret = None
M
Meiyim 已提交
94
    elif name == '':
C
chenxuyi 已提交
95
        ret = postfix
M
Meiyim 已提交
96
    else:
C
chenxuyi 已提交
97 98
        ret = '%s_%s' % (name, postfix)
    return ret
M
Meiyim 已提交
99 100


C
chenxuyi 已提交
101
class AttentionLayer(nn.Layer):
M
Meiyim 已提交
102 103
    def __init__(self, cfg, name=None):
        super(AttentionLayer, self).__init__()
C
chenxuyi 已提交
104 105
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
M
Meiyim 已提交
106 107 108
        d_model = cfg['hidden_size']
        n_head = cfg['num_attention_heads']
        assert d_model % n_head == 0
C
chenxuyi 已提交
109 110 111 112
        d_model_q = cfg.get('query_hidden_size_per_head',
                            d_model // n_head) * n_head
        d_model_v = cfg.get('value_hidden_size_per_head',
                            d_model // n_head) * n_head
M
Meiyim 已提交
113 114
        self.n_head = n_head
        self.d_key = d_model_q // n_head
C
chenxuyi 已提交
115 116 117 118 119 120 121 122 123
        self.q = _build_linear(d_model, d_model_q,
                               append_name(name, 'query_fc'), initializer)
        self.k = _build_linear(d_model, d_model_q,
                               append_name(name, 'key_fc'), initializer)
        self.v = _build_linear(d_model, d_model_v,
                               append_name(name, 'value_fc'), initializer)
        self.o = _build_linear(d_model_v, d_model,
                               append_name(name, 'output_fc'), initializer)
        self.dropout = nn.Dropout(p=cfg['attention_probs_dropout_prob'])
M
Meiyim 已提交
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138

    def forward(self, queries, keys, values, attn_bias, past_cache):
        assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3
        #bsz, q_len, q_dim = queries.shape
        #bsz, k_len, k_dim = keys.shape
        #bsz, v_len, v_dim = values.shape
        #assert k_len == v_len

        q = self.q(queries)
        k = self.k(keys)
        v = self.v(values)

        cache = (k, v)
        if past_cache is not None:
            cached_k, cached_v = past_cache
C
chenxuyi 已提交
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
            k = P.concat([cached_k, k], 1)
            v = P.concat([cached_v, v], 1)

        q = q.reshape(
            [0, 0, self.n_head, q.shape[-1] // self.n_head]).transpose(
                [0, 2, 1, 3])  #[batch, head, seq, dim]
        k = k.reshape(
            [0, 0, self.n_head, k.shape[-1] // self.n_head]).transpose(
                [0, 2, 1, 3])  #[batch, head, seq, dim]
        v = v.reshape(
            [0, 0, self.n_head, v.shape[-1] // self.n_head]).transpose(
                [0, 2, 1, 3])  #[batch, head, seq, dim]

        q = q.scale(self.d_key**-0.5)
        score = q.matmul(k, transpose_y=True)
M
Meiyim 已提交
154 155
        if attn_bias is not None:
            score += attn_bias
C
chenxuyi 已提交
156
        score = F.softmax(score)
M
Meiyim 已提交
157 158
        score = self.dropout(score)

C
chenxuyi 已提交
159 160
        out = score.matmul(v).transpose([0, 2, 1, 3])
        out = out.reshape([0, 0, out.shape[2] * out.shape[3]])
M
Meiyim 已提交
161 162 163 164
        out = self.o(out)
        return out, cache


C
chenxuyi 已提交
165
class PositionwiseFeedForwardLayer(nn.Layer):
M
Meiyim 已提交
166 167
    def __init__(self, cfg, name=None):
        super(PositionwiseFeedForwardLayer, self).__init__()
C
chenxuyi 已提交
168 169
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
M
Meiyim 已提交
170 171
        d_model = cfg['hidden_size']
        d_ffn = cfg.get('intermediate_size', 4 * d_model)
C
chenxuyi 已提交
172 173 174 175 176 177 178 179
        self.act = ACT_DICT[cfg['hidden_act']]()
        self.i = _build_linear(
            d_model,
            d_ffn,
            append_name(name, 'fc_0'),
            initializer, )
        self.o = _build_linear(d_ffn, d_model,
                               append_name(name, 'fc_1'), initializer)
M
Meiyim 已提交
180
        prob = cfg.get('intermediate_dropout_prob', 0.)
C
chenxuyi 已提交
181
        self.dropout = nn.Dropout(p=prob)
M
Meiyim 已提交
182 183

    def forward(self, inputs):
C
chenxuyi 已提交
184
        hidden = self.act(self.i(inputs))
M
Meiyim 已提交
185 186 187 188 189
        hidden = self.dropout(hidden)
        out = self.o(hidden)
        return out


C
chenxuyi 已提交
190
class ErnieBlock(nn.Layer):
M
Meiyim 已提交
191 192 193
    def __init__(self, cfg, name=None):
        super(ErnieBlock, self).__init__()
        d_model = cfg['hidden_size']
C
chenxuyi 已提交
194 195 196 197 198 199
        self.attn = AttentionLayer(
            cfg, name=append_name(name, 'multi_head_att'))
        self.ln1 = _build_ln(d_model, name=append_name(name, 'post_att'))
        self.ffn = PositionwiseFeedForwardLayer(
            cfg, name=append_name(name, 'ffn'))
        self.ln2 = _build_ln(d_model, name=append_name(name, 'post_ffn'))
M
Meiyim 已提交
200
        prob = cfg.get('intermediate_dropout_prob', cfg['hidden_dropout_prob'])
C
chenxuyi 已提交
201
        self.dropout = nn.Dropout(p=prob)
M
Meiyim 已提交
202 203

    def forward(self, inputs, attn_bias=None, past_cache=None):
C
chenxuyi 已提交
204 205 206
        attn_out, cache = self.attn(
            inputs, inputs, inputs, attn_bias,
            past_cache=past_cache)  #self attn
M
Meiyim 已提交
207
        attn_out = self.dropout(attn_out)
C
chenxuyi 已提交
208 209
        hidden = attn_out + inputs
        hidden = self.ln1(hidden)  # dropout/ add/ norm
M
Meiyim 已提交
210 211 212 213 214 215 216

        ffn_out = self.ffn(hidden)
        ffn_out = self.dropout(ffn_out)
        hidden = ffn_out + hidden
        hidden = self.ln2(hidden)
        return hidden, cache

C
chenxuyi 已提交
217 218

class ErnieEncoderStack(nn.Layer):
M
Meiyim 已提交
219 220 221
    def __init__(self, cfg, name=None):
        super(ErnieEncoderStack, self).__init__()
        n_layers = cfg['num_hidden_layers']
C
chenxuyi 已提交
222 223 224 225
        self.block = nn.LayerList([
            ErnieBlock(cfg, append_name(name, 'layer_%d' % i))
            for i in range(n_layers)
        ])
M
Meiyim 已提交
226 227 228

    def forward(self, inputs, attn_bias=None, past_cache=None):
        if past_cache is not None:
C
chenxuyi 已提交
229 230 231 232
            assert isinstance(
                past_cache, tuple
            ), 'unknown type of `past_cache`, expect tuple or list. got %s' % repr(
                type(past_cache))
M
Meiyim 已提交
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
            past_cache = list(zip(*past_cache))
        else:
            past_cache = [None] * len(self.block)
        cache_list_k, cache_list_v, hidden_list = [], [], [inputs]

        for b, p in zip(self.block, past_cache):
            inputs, cache = b(inputs, attn_bias=attn_bias, past_cache=p)
            cache_k, cache_v = cache
            cache_list_k.append(cache_k)
            cache_list_v.append(cache_v)
            hidden_list.append(inputs)

        return inputs, hidden_list, (cache_list_k, cache_list_v)


class PretrainedModel(object):
    bce = 'https://ernie-github.cdn.bcebos.com/'
    resource_map = {
        'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz',
        'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz',
C
chenxuyi 已提交
253
        'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
M
Meiyim 已提交
254
        'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
Z
zhanghan17 已提交
255 256
        'ernie-gram-zh': bce + 'model-ernie-gram-zh.1.tar.gz',
        'ernie-gram-en': bce + 'model-ernie-gram-en.1.tar.gz',
M
Meiyim 已提交
257
    }
C
chenxuyi 已提交
258

M
Meiyim 已提交
259
    @classmethod
C
chenxuyi 已提交
260 261 262 263 264 265 266
    def from_pretrained(cls,
                        pretrain_dir_or_url,
                        force_download=False,
                        **kwargs):
        if not Path(pretrain_dir_or_url).exists() and str(
                pretrain_dir_or_url) in cls.resource_map:
            url = cls.resource_map[str(pretrain_dir_or_url)]
M
Meiyim 已提交
267 268 269
            log.info('get pretrain dir from %s' % url)
            pretrain_dir = _fetch_from_remote(url, force_download)
        else:
C
chenxuyi 已提交
270 271
            log.info('pretrain dir %s not in %s, read from local' %
                     (pretrain_dir_or_url, repr(cls.resource_map)))
W
Weiyue Su 已提交
272
            pretrain_dir = Path(pretrain_dir_or_url)
M
Meiyim 已提交
273

M
Meiyim 已提交
274
        if not pretrain_dir.exists():
M
Meiyim 已提交
275
            raise ValueError('pretrain dir not found: %s' % pretrain_dir)
C
chenxuyi 已提交
276
        state_dict_path = pretrain_dir / 'saved_weights.pdparams'
M
Meiyim 已提交
277
        config_path = pretrain_dir / 'ernie_config.json'
M
Meiyim 已提交
278

M
Meiyim 已提交
279
        if not config_path.exists():
M
Meiyim 已提交
280
            raise ValueError('config path not found: %s' % config_path)
C
chenxuyi 已提交
281
        name_prefix = kwargs.pop('name', None)
M
Meiyim 已提交
282
        cfg_dict = dict(json.loads(config_path.open().read()), **kwargs)
M
Meiyim 已提交
283
        model = cls(cfg_dict, name=name_prefix)
C
chenxuyi 已提交
284

M
Meiyim 已提交
285 286
        log.info('loading pretrained model from %s' % pretrain_dir)

C
chenxuyi 已提交
287
        #param_path = pretrain_dir / 'params'
M
Meiyim 已提交
288 289 290 291
        #if os.path.exists(param_path):
        #    raise NotImplementedError()
        #    log.debug('load pretrained weight from program state')
        #    F.io.load_program_state(param_path) #buggy in dygraph.gurad, push paddle to fix
C
chenxuyi 已提交
292 293
        if state_dict_path.exists():
            m = P.load(state_dict_path)
M
Meiyim 已提交
294 295 296
            for k, v in model.state_dict().items():
                if k not in m:
                    log.warn('param:%s not set in pretrained model, skip' % k)
C
chenxuyi 已提交
297 298
                    m[k] = v  # FIXME: no need to do this in the future
            model.set_state_dict(m)
M
Meiyim 已提交
299
        else:
C
chenxuyi 已提交
300 301
            raise ValueError('weight file not found in pretrain dir: %s' %
                             pretrain_dir)
M
Meiyim 已提交
302 303 304
        return model


C
chenxuyi 已提交
305
class ErnieModel(nn.Layer, PretrainedModel):
M
Meiyim 已提交
306 307 308 309 310
    def __init__(self, cfg, name=None):
        """
        Fundamental pretrained Ernie model
        """
        log.debug('init ErnieModel with config: %s' % repr(cfg))
C
chenxuyi 已提交
311
        nn.Layer.__init__(self)
M
Meiyim 已提交
312 313 314 315 316
        d_model = cfg['hidden_size']
        d_emb = cfg.get('emb_size', cfg['hidden_size'])
        d_vocab = cfg['vocab_size']
        d_pos = cfg['max_position_embeddings']
        d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size']
Z
zhanghan17 已提交
317 318
        self.d_rel_pos = cfg.get('rel_pos_size', None)
        max_seq_len = cfg.get("max_seq_len", 512)
M
Meiyim 已提交
319 320
        self.n_head = cfg['num_attention_heads']
        self.return_additional_info = cfg.get('return_additional_info', False)
C
chenxuyi 已提交
321 322
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
Z
zhanghan17 已提交
323 324
        if self.d_rel_pos:
            self.rel_pos_bias = _get_rel_pos_bias(max_seq_len) 
M
Meiyim 已提交
325 326

        self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder'))
C
chenxuyi 已提交
327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344
        self.word_emb = nn.Embedding(
            d_vocab,
            d_emb,
            weight_attr=P.ParamAttr(
                name=append_name(name, 'word_embedding'),
                initializer=initializer))
        self.pos_emb = nn.Embedding(
            d_pos,
            d_emb,
            weight_attr=P.ParamAttr(
                name=append_name(name, 'pos_embedding'),
                initializer=initializer))
        self.sent_emb = nn.Embedding(
            d_sent,
            d_emb,
            weight_attr=P.ParamAttr(
                name=append_name(name, 'sent_embedding'),
                initializer=initializer))
Z
zhanghan17 已提交
345 346 347 348 349 350 351
        if self.d_rel_pos:
            self.rel_pos_bias_emb = nn.Embedding(
                self.d_rel_pos,
                self.n_head,
                weight_attr=P.ParamAttr(
                    name=append_name(name, 'rel_pos_embedding'),
                    initializer=initializer))
M
Meiyim 已提交
352
        prob = cfg['hidden_dropout_prob']
C
chenxuyi 已提交
353
        self.dropout = nn.Dropout(p=prob)
M
Meiyim 已提交
354

C
chenxuyi 已提交
355 356
        self.encoder_stack = ErnieEncoderStack(cfg,
                                               append_name(name, 'encoder'))
M
Meiyim 已提交
357
        if cfg.get('has_pooler', True):
C
chenxuyi 已提交
358 359 360 361 362
            self.pooler = _build_linear(
                cfg['hidden_size'],
                cfg['hidden_size'],
                append_name(name, 'pooled_fc'),
                initializer, )
M
Meiyim 已提交
363 364 365 366
        else:
            self.pooler = None
        self.train()

C
chenxuyi 已提交
367
    #FIXME:remove this
M
Meiyim 已提交
368
    def eval(self):
C
chenxuyi 已提交
369
        if P.in_dynamic_mode():
M
Meiyim 已提交
370 371 372 373
            super(ErnieModel, self).eval()
        self.training = False
        for l in self.sublayers():
            l.training = False
C
chenxuyi 已提交
374
        return self
M
Meiyim 已提交
375 376

    def train(self):
C
chenxuyi 已提交
377
        if P.in_dynamic_mode():
M
Meiyim 已提交
378 379 380 381
            super(ErnieModel, self).train()
        self.training = True
        for l in self.sublayers():
            l.training = True
C
chenxuyi 已提交
382 383 384 385 386 387 388 389 390 391
        return self

    def forward(self,
                src_ids,
                sent_ids=None,
                pos_ids=None,
                input_mask=None,
                attn_bias=None,
                past_cache=None,
                use_causal_mask=False):
Z
zhanghan17 已提交
392
                
M
Meiyim 已提交
393 394
        """
        Args:
C
chenxuyi 已提交
395
            src_ids (`Variable` of shape `[batch_size, seq_len]`):
M
Meiyim 已提交
396
                Indices of input sequence tokens in the vocabulary.
C
chenxuyi 已提交
397
            sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`):
M
Meiyim 已提交
398 399
                aka token_type_ids, Segment token indices to indicate first and second portions of the inputs.
                if None, assume all tokens come from `segment_a`
C
chenxuyi 已提交
400
            pos_ids(optional, `Variable` of shape `[batch_size, seq_len]`):
M
Meiyim 已提交
401
                Indices of positions of each input sequence tokens in the position embeddings.
C
chenxuyi 已提交
402
            input_mask(optional `Variable` of shape `[batch_size, seq_len]`):
M
Meiyim 已提交
403
                Mask to avoid performing attention on the padding token indices of the encoder input.
C
chenxuyi 已提交
404
            attn_bias(optional, `Variable` of shape `[batch_size, seq_len, seq_len] or False`):
M
Meiyim 已提交
405
                3D version of `input_mask`, if set, overrides `input_mask`; if set not False, will not apply attention mask
M
Meiyim 已提交
406 407
            past_cache(optional, tuple of two lists: cached key and cached value,
                each is a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`):
C
chenxuyi 已提交
408
                cached key/value tensor that will be concated to generated key/value when performing self attention.
M
Meiyim 已提交
409 410
                if set, `attn_bias` should not be None.

M
Meiyim 已提交
411 412 413 414 415
        Returns:
            pooled (`Variable` of shape `[batch_size, hidden_size]`):
                output logits of pooler classifier
            encoded(`Variable` of shape `[batch_size, seq_len, hidden_size]`):
                output logits of transformer stack
M
Meiyim 已提交
416 417
            info (Dictionary):
                addtional middle level info, inclues: all hidden stats, k/v caches.
M
Meiyim 已提交
418
        """
C
chenxuyi 已提交
419 420 421 422
        assert len(
            src_ids.
            shape) == 2, 'expect src_ids.shape = [batch, sequecen], got %s' % (
                repr(src_ids.shape))
M
Meiyim 已提交
423
        assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified; attn_bias should not be None'
C
chenxuyi 已提交
424
        d_seqlen = P.shape(src_ids)[1]
M
Meiyim 已提交
425
        if pos_ids is None:
C
chenxuyi 已提交
426 427
            pos_ids = P.arange(
                0, d_seqlen, 1, dtype='int32').reshape([1, -1]).cast('int64')
M
Meiyim 已提交
428 429
        if attn_bias is None:
            if input_mask is None:
C
chenxuyi 已提交
430
                input_mask = P.cast(src_ids != 0, 'float32')
M
Meiyim 已提交
431
            assert len(input_mask.shape) == 2
C
chenxuyi 已提交
432 433
            input_mask = input_mask.unsqueeze(-1)
            attn_bias = input_mask.matmul(input_mask, transpose_y=True)
M
Meiyim 已提交
434
            if use_causal_mask:
C
chenxuyi 已提交
435 436 437 438 439
                sequence = P.reshape(
                    P.arange(
                        0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1])
                causal_mask = (sequence.matmul(
                    1. / sequence, transpose_y=True) >= 1.).cast('float32')
M
Meiyim 已提交
440 441
                attn_bias *= causal_mask
        else:
C
chenxuyi 已提交
442 443 444
            assert len(
                attn_bias.shape
            ) == 3, 'expect attn_bias tobe rank 3, got %r' % attn_bias.shape
M
Meiyim 已提交
445
        attn_bias = (1. - attn_bias) * -10000.0
C
chenxuyi 已提交
446 447
        attn_bias = attn_bias.unsqueeze(1).tile(
            [1, self.n_head, 1, 1])  # avoid broadcast =_=
Z
zhanghan17 已提交
448
        attn_bias.stop_gradient=True
M
Meiyim 已提交
449
        if sent_ids is None:
C
chenxuyi 已提交
450
            sent_ids = P.zeros_like(src_ids)
Z
zhanghan17 已提交
451 452 453 454 455
        if self.d_rel_pos:
            rel_pos_ids = self.rel_pos_bias[:d_seqlen, :d_seqlen]
            rel_pos_ids = P.to_tensor(rel_pos_ids, dtype='int64')
            rel_pos_bias = self.rel_pos_bias_emb(rel_pos_ids).transpose([2, 0, 1])
            attn_bias += rel_pos_bias
M
Meiyim 已提交
456 457 458 459 460
        src_embedded = self.word_emb(src_ids)
        pos_embedded = self.pos_emb(pos_ids)
        sent_embedded = self.sent_emb(sent_ids)
        embedded = src_embedded + pos_embedded + sent_embedded

Z
zhanghan17 已提交
461

M
Meiyim 已提交
462 463
        embedded = self.dropout(self.ln(embedded))

C
chenxuyi 已提交
464 465
        encoded, hidden_list, cache_list = self.encoder_stack(
            embedded, attn_bias, past_cache=past_cache)
M
Meiyim 已提交
466
        if self.pooler is not None:
C
chenxuyi 已提交
467
            pooled = F.tanh(self.pooler(encoded[:, 0, :]))
M
Meiyim 已提交
468 469 470 471 472 473 474 475 476 477
        else:
            pooled = None

        additional_info = {
            'hiddens': hidden_list,
            'caches': cache_list,
        }

        if self.return_additional_info:
            return pooled, encoded, additional_info
C
chenxuyi 已提交
478 479
        return pooled, encoded

M
Meiyim 已提交
480 481 482

class ErnieModelForSequenceClassification(ErnieModel):
    """
C
chenxuyi 已提交
483
    Ernie Model for text classfication or pointwise ranking tasks
M
Meiyim 已提交
484 485 486
    """

    def __init__(self, cfg, name=None):
C
chenxuyi 已提交
487 488
        super(ErnieModelForSequenceClassification, self).__init__(
            cfg, name=name)
M
Meiyim 已提交
489

C
chenxuyi 已提交
490 491 492 493
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        self.classifier = _build_linear(cfg['hidden_size'], cfg['num_labels'],
                                        append_name(name, 'cls'), initializer)
M
Meiyim 已提交
494 495

        prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob'])
C
chenxuyi 已提交
496 497
        self.dropout = nn.Dropout(p=prob)
        self.train()
M
Meiyim 已提交
498

M
Meiyim 已提交
499
    @add_docstring(ErnieModel.forward.__doc__)
M
Meiyim 已提交
500 501 502
    def forward(self, *args, **kwargs):
        """
        Args:
C
chenxuyi 已提交
503
            labels (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
504 505 506 507 508 509 510 511 512
                ground truth label id for each sentence
        Returns:
            loss (`Variable` of shape []):
                Cross entropy loss mean over batch
                if labels not set, returns None
            logits (`Variable` of shape [batch_size, hidden_size]):
                output logits of classifier
        """
        labels = kwargs.pop('labels', None)
C
chenxuyi 已提交
513 514
        pooled, encoded = super(ErnieModelForSequenceClassification,
                                self).forward(*args, **kwargs)
M
Meiyim 已提交
515 516 517 518
        hidden = self.dropout(pooled)
        logits = self.classifier(hidden)

        if labels is not None:
C
chenxuyi 已提交
519 520 521
            if len(labels.shape) != 1:
                labels = labels.squeeze()
            loss = F.cross_entropy(logits, labels)
M
Meiyim 已提交
522 523 524 525 526 527 528 529 530
        else:
            loss = None
        return loss, logits


class ErnieModelForTokenClassification(ErnieModel):
    """
    Ernie Model for Named entity tasks(NER)
    """
C
chenxuyi 已提交
531

M
Meiyim 已提交
532 533 534
    def __init__(self, cfg, name=None):
        super(ErnieModelForTokenClassification, self).__init__(cfg, name=name)

C
chenxuyi 已提交
535 536 537 538
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        self.classifier = _build_linear(cfg['hidden_size'], cfg['num_labels'],
                                        append_name(name, 'cls'), initializer)
M
Meiyim 已提交
539 540

        prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob'])
C
chenxuyi 已提交
541 542
        self.dropout = nn.Dropout(p=prob)
        self.train()
M
Meiyim 已提交
543

M
Meiyim 已提交
544
    @add_docstring(ErnieModel.forward.__doc__)
545
    def forward(self, *args, **kwargs):
M
Meiyim 已提交
546 547
        """
        Args:
C
chenxuyi 已提交
548
            labels (optional, `Variable` of shape [batch_size, seq_len]):
M
Meiyim 已提交
549 550 551 552 553 554 555
                ground truth label id for each token
        Returns:
            loss (`Variable` of shape []):
                Cross entropy loss mean over batch and time, ignore positions where label == -100
                if labels not set, returns None
            logits (`Variable` of shape [batch_size, seq_len, hidden_size]):
                output logits of classifier
556 557 558 559
            loss_weights (`Variable` of shape [batch_size, seq_len]):
                weigths of loss for each tokens.
            ignore_index (int):
                when label == `ignore_index`, this token will not contribute to loss
M
Meiyim 已提交
560
        """
561 562 563
        ignore_index = kwargs.pop('ignore_index', -100)
        labels = kwargs.pop('labels', None)
        loss_weights = kwargs.pop('loss_weights', None)
C
chenxuyi 已提交
564 565 566
        pooled, encoded = super(ErnieModelForTokenClassification,
                                self).forward(*args, **kwargs)
        hidden = self.dropout(encoded)  # maybe not?
M
Meiyim 已提交
567 568 569
        logits = self.classifier(hidden)

        if labels is not None:
C
chenxuyi 已提交
570 571 572 573
            if len(labels.shape) != 2:
                labels = labels.squeeze()
            loss = F.cross_entropy(
                logits, labels, ignore_index=ignore_index, reduction='none')
M
Meiyim 已提交
574
            if loss_weights is not None:
C
chenxuyi 已提交
575 576
                loss = loss * loss_weights
            loss = loss.mean()
M
Meiyim 已提交
577 578 579 580 581 582 583 584 585
        else:
            loss = None
        return loss, logits


class ErnieModelForQuestionAnswering(ErnieModel):
    """
    Ernie model for reading comprehension tasks (SQuAD)
    """
C
chenxuyi 已提交
586

M
Meiyim 已提交
587 588 589
    def __init__(self, cfg, name=None):
        super(ErnieModelForQuestionAnswering, self).__init__(cfg, name=name)

C
chenxuyi 已提交
590 591 592 593 594
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        self.classifier = _build_linear(cfg['hidden_size'], 2,
                                        append_name(name, 'cls_mrc'),
                                        initializer)
M
Meiyim 已提交
595 596

        prob = cfg.get('classifier_dropout_prob', cfg['hidden_dropout_prob'])
C
chenxuyi 已提交
597 598
        self.dropout = nn.Dropout(p=prob)
        self.train()
M
Meiyim 已提交
599

M
Meiyim 已提交
600
    @add_docstring(ErnieModel.forward.__doc__)
M
Meiyim 已提交
601 602 603
    def forward(self, *args, **kwargs):
        """
        Args:
C
chenxuyi 已提交
604
            start_pos (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
605
                token index of start of answer span in `context`
C
chenxuyi 已提交
606
            end_pos (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
607 608 609 610 611 612 613 614 615 616 617 618 619
                token index of end of answer span in `context`
        Returns:
            loss (`Variable` of shape []):
                Cross entropy loss mean over batch and time, ignore positions where label == -100
                if labels not set, returns None
            start_logits (`Variable` of shape [batch_size, hidden_size]):
                output logits of start position, use argmax(start_logit) to get start index
            end_logits (`Variable` of shape [batch_size, hidden_size]):
                output logits of end position, use argmax(end_logit) to get end index
        """

        start_pos = kwargs.pop('start_pos', None)
        end_pos = kwargs.pop('end_pos', None)
C
chenxuyi 已提交
620 621
        pooled, encoded = super(ErnieModelForQuestionAnswering, self).forward(
            *args, **kwargs)
M
Meiyim 已提交
622 623
        encoded = self.dropout(encoded)
        encoded = self.classifier(encoded)
C
chenxuyi 已提交
624
        start_logit, end_logits = P.unstack(encoded, axis=-1)
M
Meiyim 已提交
625
        if start_pos is not None and end_pos is not None:
C
chenxuyi 已提交
626 627 628 629 630 631 632
            if len(start_pos.shape) != 1:
                start_pos = start_pos.squeeze()
            if len(end_pos.shape) != 1:
                end_pos = end_pos.squeeze()
            start_loss = F.cross_entropy(start_logit, start_pos)
            end_loss = F.cross_entropy(end_logits, end_pos)
            loss = (start_loss.mean() + end_loss.mean()) / 2.
M
Meiyim 已提交
633 634 635 636 637
        else:
            loss = None
        return loss, start_logit, end_logits


C
chenxuyi 已提交
638
class NSPHead(nn.Layer):
M
Meiyim 已提交
639 640
    def __init__(self, cfg, name=None):
        super(NSPHead, self).__init__()
C
chenxuyi 已提交
641 642 643 644
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
        self.nsp = _build_linear(cfg['hidden_size'], 2,
                                 append_name(name, 'nsp_fc'), initializer)
M
Meiyim 已提交
645 646 647 648

    def forward(self, inputs, labels):
        """
        Args:
C
chenxuyi 已提交
649
            start_pos (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
650
                token index of start of answer span in `context`
C
chenxuyi 已提交
651
            end_pos (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
652 653 654 655 656 657 658 659 660 661 662 663
                token index of end of answer span in `context`
        Returns:
            loss (`Variable` of shape []):
                Cross entropy loss mean over batch and time, ignore positions where label == -100
                if labels not set, returns None
            start_logits (`Variable` of shape [batch_size, hidden_size]):
                output logits of start position
            end_logits (`Variable` of shape [batch_size, hidden_size]):
                output logits of end position
        """

        logits = self.nsp(inputs)
C
chenxuyi 已提交
664
        loss = F.cross_entropy(logits, labels)
M
Meiyim 已提交
665 666 667 668 669 670 671
        return loss


class ErnieModelForPretraining(ErnieModel):
    """
    Ernie Model for Masked Languate Model pretrain
    """
C
chenxuyi 已提交
672

M
Meiyim 已提交
673 674
    def __init__(self, cfg, name=None):
        super(ErnieModelForPretraining, self).__init__(cfg, name=name)
C
chenxuyi 已提交
675 676
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
M
Meiyim 已提交
677 678 679
        d_model = cfg['hidden_size']
        d_vocab = cfg['vocab_size']

C
chenxuyi 已提交
680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696
        self.pooler_heads = nn.LayerList([NSPHead(cfg, name=name)])
        self.mlm = _build_linear(
            d_model,
            d_model,
            append_name(name, 'mask_lm_trans_fc'),
            initializer, )
        self.act = ACT_DICT[cfg['hidden_act']]()
        self.mlm_ln = _build_ln(
            d_model, name=append_name(name, 'mask_lm_trans'))
        self.mlm_bias = P.create_parameter(
            dtype='float32',
            shape=[d_vocab],
            attr=P.ParamAttr(
                name=append_name(name, 'mask_lm_out_fc.b_0'),
                initializer=nn.initializer.Constant(value=0.0)),
            is_bias=True, )
        self.train()
M
Meiyim 已提交
697

M
Meiyim 已提交
698
    @add_docstring(ErnieModel.forward.__doc__)
M
Meiyim 已提交
699 700 701
    def forward(self, *args, **kwargs):
        """
        Args:
C
chenxuyi 已提交
702
            nsp_labels (optional, `Variable` of shape [batch_size]):
M
Meiyim 已提交
703
                labels for `next sentence prediction` tasks
C
chenxuyi 已提交
704
            mlm_pos (optional, `Variable` of shape [n_mask, 2]):
M
Meiyim 已提交
705
                index of mask_id in `src_ids`, can be obtained from `fluid.layers.where(src_ids==mask_id)`
C
chenxuyi 已提交
706
            labels (optional, `Variable` of shape [n_mask]):
M
Meiyim 已提交
707 708 709 710 711 712 713 714 715 716 717 718 719
                labels for `mask language model` tasks, the original token indices in masked position in `src_ids`
        Returns:
            loss (`Variable` of shape []):
                total_loss of `next sentence prediction` and `masked language model`
            mlm_loss (`Variable` of shape []):
                loss for `masked language model` task
            nsp_loss (`Variable` of shape []):
                loss for `next sentence prediction` task
        """

        mlm_labels = kwargs.pop('labels')
        mlm_pos = kwargs.pop('mlm_pos')
        nsp_labels = kwargs.pop('nsp_labels')
C
chenxuyi 已提交
720 721 722 723
        pooled, encoded = super(ErnieModelForPretraining, self).forward(
            *args, **kwargs)
        if len(mlm_labels.shape) != 1:
            mlm_labels = mlm_labels.squeeze()
M
Meiyim 已提交
724
        if len(nsp_labels.shape) == 1:
C
chenxuyi 已提交
725
            nsp_labels = nsp_labels.squeeze()
M
Meiyim 已提交
726 727 728

        nsp_loss = self.pooler_heads[0](pooled, nsp_labels)

C
chenxuyi 已提交
729 730
        encoded_2d = encoded.gather_nd(mlm_pos)
        encoded_2d = self.act(self.mlm(encoded_2d))
M
Meiyim 已提交
731
        encoded_2d = self.mlm_ln(encoded_2d)
C
chenxuyi 已提交
732 733 734
        logits_2d = encoded_2d.matmul(
            self.word_emb.weight, transpose_y=True) + self.mlm_bias
        mlm_loss = F.cross_entropy(logits_2d, mlm_labels)
M
Meiyim 已提交
735 736 737
        total_loss = mlm_loss + nsp_loss
        return total_loss, mlm_loss, nsp_loss

M
Meiyim 已提交
738 739 740 741 742 743

class ErnieModelForGeneration(ErnieModel):
    """
    Ernie Model for sequence to sequence generation.
    """
    resource_map = {
C
chenxuyi 已提交
744 745 746 747 748 749
        'ernie-gen-base-en':
        ErnieModel.bce + 'model-ernie-gen-base-en.1.tar.gz',
        'ernie-gen-large-en':
        ErnieModel.bce + 'model-ernie-gen-large-en.1.tar.gz',
        'ernie-gen-large-430g-en':
        ErnieModel.bce + 'model-ernie-gen-large-430g-en.1.tar.gz',
M
Meiyim 已提交
750 751
        'ernie-1.0': ErnieModel.bce + 'model-ernie1.0.1.tar.gz',
    }
C
chenxuyi 已提交
752

M
Meiyim 已提交
753 754 755 756
    def __init__(self, cfg, name=None):
        cfg['return_additional_info'] = True
        cfg['has_pooler'] = False
        super(ErnieModelForGeneration, self).__init__(cfg, name=name)
C
chenxuyi 已提交
757 758
        initializer = nn.initializer.TruncatedNormal(
            std=cfg['initializer_range'])
M
Meiyim 已提交
759 760 761
        d_model = cfg['hidden_size']
        d_vocab = cfg['vocab_size']

C
chenxuyi 已提交
762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777
        self.mlm = _build_linear(
            d_model,
            d_model,
            append_name(name, 'mask_lm_trans_fc'),
            initializer, )
        self.act = ACT_DICT[cfg['hidden_act']]()
        self.mlm_ln = _build_ln(
            d_model, name=append_name(name, 'mask_lm_trans'))
        self.mlm_bias = P.create_parameter(
            dtype='float32',
            shape=[d_vocab],
            attr=P.ParamAttr(
                name=append_name(name, 'mask_lm_out_fc.b_0'),
                initializer=nn.initializer.Constant(value=0.0)),
            is_bias=True, )
        self.train()
M
Meiyim 已提交
778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801

    @add_docstring(ErnieModel.forward.__doc__)
    def forward(self, *args, **kwargs):
        """
        Args
            tgt_labels(`Variable` of shape [batch_size, seqlen] or [batch, seqlen, vocab_size]):
                ground trouth target sequence id (hard label) or distribution (soft label)
            tgt_pos(`Variable` of shape [n_targets, 2]):
                index of tgt_labels in `src_ids`, can be obtained from `fluid.layers.where(src_ids==mask_id)`
            encoder_only(Bool):
                if set, will not return loss, logits_2d
        Returns:
            loss(`Variable` of shape []):
                cross entropy loss mean over every target label. if `encode_only`, returns None.
            logits(`Variable` of shape [n_targets, vocab_size]):
                logits for every targets. if `encode_only`, returns None.
            info(Dictionary): see `ErnieModel`
        """
        tgt_labels = kwargs.pop('tgt_labels', None)
        tgt_pos = kwargs.pop('tgt_pos', None)
        encode_only = kwargs.pop('encode_only', False)
        _, encoded, info = ErnieModel.forward(self, *args, **kwargs)
        if encode_only:
            return None, None, info
C
chenxuyi 已提交
802 803
        if tgt_labels is None or tgt_pos is None:
            encoded = self.act(self.mlm(encoded))
M
Meiyim 已提交
804
            encoded = self.mlm_ln(encoded)
C
chenxuyi 已提交
805 806 807
            logits = encoded.matmul(
                self.word_emb.weight, transpose_y=True) + self.mlm_bias
            output_ids = logits.cast('float32').argmax(-1)
M
Meiyim 已提交
808 809
            return output_ids, logits, info
        else:
C
chenxuyi 已提交
810 811
            encoded_2d = encoded.gather_nd(tgt_pos)
            encoded_2d = self.act(self.mlm(encoded_2d))
M
Meiyim 已提交
812
            encoded_2d = self.mlm_ln(encoded_2d)
C
chenxuyi 已提交
813 814 815 816
            logits_2d = encoded_2d.matmul(
                self.word_emb.weight, transpose_y=True) + self.mlm_bias
            assert len(
                tgt_labels.shape) == 2, 'expect 2d label, got %r' % tgt_labels
M
Meiyim 已提交
817

C
chenxuyi 已提交
818 819
            loss = F.cross_entropy(logits_2d, tgt_labels, soft_label=True)
            return loss, logits_2d, info