# modeling_ernie_gen.py -- last committed by Meiyim.
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import sys
import logging
import numpy as np

import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D

from ernie.modeling_ernie import ErnieModel
from ernie.modeling_ernie import _build_linear, _build_ln, append_name

class ErnieModelForGeneration(ErnieModel):
    """ERNIE encoder extended with a token-prediction head for generation.

    On top of ``ErnieModel`` this class adds a dense transform, a layer norm
    and a vocabulary-sized bias.  Vocabulary logits are produced by
    multiplying the transformed hidden states with the transposed word
    embedding matrix (``self.word_emb.weight``), so the output projection
    shares its weights with the input embedding.
    """

    # Pretrained model name -> archive location, built from ``ErnieModel.bce``
    # (presumably a download URL prefix -- confirm in ``modeling_ernie``)
    # followed by the tarball file name.
    resource_map = {
        'ernie-gen-base-en': ErnieModel.bce + 'model-ernie-gen-base-en.1.tar.gz',
        'ernie-gen-large-en': ErnieModel.bce + 'model-ernie-gen-large-en.1.tar.gz',
        'ernie-gen-large-160g-en': ErnieModel.bce + 'model-ernie-gen-large-160g-en.1.tar.gz',
        'ernie-1.0': ErnieModel.bce + 'model-ernie1.0.1.tar.gz',
    }

    def __init__(self, cfg, name=None):
        """Build the encoder and the token-prediction head.

        Args:
            cfg: Mapping of model hyper-parameters.  This constructor reads
                ``initializer_range``, ``hidden_size``, ``vocab_size`` and
                ``hidden_act``; the remaining keys are consumed by
                ``ErnieModel``.  The mapping is copied, never modified.
            name: Optional prefix combined with each parameter name through
                ``append_name``.
        """
        # Work on a shallow copy: the two overrides below must not leak into
        # a config object the caller may reuse to build other models.
        cfg = dict(cfg)
        # ``forward`` unpacks a third ``info`` value from ``ErnieModel.forward``.
        cfg['return_additional_info'] = True
        # The pooled output is discarded by ``forward``, so skip the pooler.
        cfg['has_pooler'] = False
        super(ErnieModelForGeneration, self).__init__(cfg, name=name)

        initializer = F.initializer.TruncatedNormal(scale=cfg['initializer_range'])
        d_model = cfg['hidden_size']
        d_vocab = cfg['vocab_size']

        self.mlm = _build_linear(
            d_model,
            d_model,
            append_name(name, 'mask_lm_trans_fc'),
            initializer,
            act=cfg['hidden_act'],
        )
        self.mlm_ln = _build_ln(d_model, name=append_name(name, 'mask_lm_trans'))
        # Only a bias is created for the output layer; its weight is the
        # word embedding matrix (see ``_lm_logits``).
        self.mlm_bias = L.create_parameter(
            dtype='float32',
            shape=[d_vocab],
            attr=F.ParamAttr(
                name=append_name(name, 'mask_lm_out_fc.b_0'),
                initializer=F.initializer.Constant(value=0.0),
            ),
            is_bias=True,
        )

    def _lm_logits(self, hidden):
        """Project hidden states to vocabulary logits with the tied embedding."""
        hidden = self.mlm(hidden)
        hidden = self.mlm_ln(hidden)
        return L.matmul(hidden, self.word_emb.weight, transpose_y=True) + self.mlm_bias

    def forward(self, src_ids, *args, **kwargs):
        """Encode ``src_ids`` and optionally predict tokens or compute a loss.

        Args:
            src_ids: Input token ids, forwarded to ``ErnieModel.forward``.
            *args: Extra positional arguments for ``ErnieModel.forward``.
            **kwargs: Extra keyword arguments for ``ErnieModel.forward``.
                Three keys are consumed here and not forwarded:

                tgt_labels: Target labels.  When given, a loss is returned.
                    A 1-D tensor is reshaped to ``[-1, 1]``; a tensor whose
                    last dimension is not 1 is treated as soft labels.
                tgt_pos: Indices passed to ``gather_nd`` to select the
                    encoded positions that are scored against
                    ``tgt_labels``.  Required when ``tgt_labels`` is given.
                encode_only: When true, skip the prediction head entirely.

        Returns:
            A 3-tuple whose last element is always the ``info`` value
            returned by ``ErnieModel.forward``:

            * ``(None, None, info)`` when ``encode_only`` is true;
            * ``(output_ids, logits, info)`` when ``tgt_labels`` is ``None``,
              where ``output_ids`` is the argmax of ``logits`` over the last
              axis;
            * ``(loss, logits_2d, info)`` otherwise, where ``loss`` is the
              mean softmax cross entropy over the gathered positions.

        Raises:
            ValueError: If ``tgt_labels`` is given without ``tgt_pos``.
        """
        tgt_labels = kwargs.pop('tgt_labels', None)
        tgt_pos = kwargs.pop('tgt_pos', None)
        encode_only = kwargs.pop('encode_only', False)

        # The first return value (the pooled output slot) is unused here.
        _, encoded, info = ErnieModel.forward(self, src_ids, *args, **kwargs)
        if encode_only:
            return None, None, info

        if tgt_labels is None:
            # Inference: score every encoded position.
            logits = self._lm_logits(encoded)
            output_ids = L.argmax(logits, -1)
            return output_ids, logits, info

        if tgt_pos is None:
            # Fail with a clear message instead of an opaque gather_nd error.
            raise ValueError('`tgt_pos` is required when `tgt_labels` is given')

        # Training: only the positions selected by ``tgt_pos`` are scored.
        encoded_2d = L.gather_nd(encoded, tgt_pos)
        logits_2d = self._lm_logits(encoded_2d)
        if len(tgt_labels.shape) == 1:
            # Hard labels need a trailing dimension of size 1.
            tgt_labels = L.reshape(tgt_labels, [-1, 1])

        use_soft_label = tgt_labels.shape[-1] != 1
        loss = L.reduce_mean(
            L.softmax_with_cross_entropy(logits_2d, tgt_labels, soft_label=use_soft_label)
        )
        return loss, logits_2d, info