# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ernie model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from paddle import fluid
from paddle.fluid import layers

from paddlepalm.backbone.utils.transformer import pre_process_layer, encoder
from paddlepalm.interface import backbone


class Model(backbone):

    def __init__(self,
                 config,
                 phase):

        # self._is_training = phase == 'train'  # the backbone generally does not need to care about the running phase, since its outputs are essentially the same in every phase

        self._emb_size = config['hidden_size']
        self._n_layer = config['num_hidden_layers']
        self._n_head = config['num_attention_heads']
        self._voc_size = config['vocab_size']
        self._max_position_seq_len = config['max_position_embeddings']
        if config['sent_type_vocab_size']:
            self._sent_types = config['sent_type_vocab_size']
        else:
            self._sent_types = config['type_vocab_size']

        self._task_types = config['task_type_vocab_size']

        self._hidden_act = config['hidden_act']
        self._prepostprocess_dropout = config['hidden_dropout_prob']
        self._attention_dropout = config['attention_probs_dropout_prob']

        self._word_emb_name = "word_embedding"
        self._pos_emb_name = "pos_embedding"
        self._sent_emb_name = "sent_embedding"
        self._task_emb_name = "task_embedding"
        self._emb_dtype = "float32"

        self._param_initializer = fluid.initializer.TruncatedNormal(
            scale=config['initializer_range'])

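    # Input/output signatures are declared as name -> [shape, dtype];
    # a -1 dimension is resolved at runtime (batch size / sequence length).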
    @property
    def inputs_attr(self):
        return {"token_ids": [[-1, -1, 1], 'int64'],
                "position_ids": [[-1, -1, 1], 'int64'],
                "segment_ids": [[-1, -1, 1], 'int64'],
                "input_mask": [[-1, -1, 1], 'float32'],
                "task_ids": [[-1,-1, 1], 'int64']}

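    # Tensors this backbone exposes to downstream task heads after build() runs.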
    @property
    def outputs_attr(self):
        return {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
                "embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'],
                "encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
                "sentence_embedding": [[-1, self._emb_size], 'float32'],
                "sentence_pair_embedding": [[-1, self._emb_size], 'float32']}

    def build(self, inputs, scope_name=""):

        src_ids = inputs['token_ids']
        pos_ids = inputs['position_ids']
        sent_ids = inputs['segment_ids']
        input_mask = inputs['input_mask']
        task_ids = inputs['task_ids']

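        # token embedding lookup; ids of shape [batch, seq, 1] are mapped to hidden-size vectors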
        # padding id in vocabulary must be set to 0
        emb_out = fluid.layers.embedding(
            input=src_ids,
            size=[self._voc_size, self._emb_size],
            dtype=self._emb_dtype,
            param_attr=fluid.ParamAttr(
                name=scope_name+self._word_emb_name, initializer=self._param_initializer),
            is_sparse=False)

        # fluid.global_scope().find_var('backbone-word_embedding').get_tensor()
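        # fetch the word-embedding parameter itself so it can be returned in the
        # outputs (e.g. for reuse/weight tying by a downstream MLM head)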
        embedding_table = fluid.default_main_program().global_block().var(scope_name+self._word_emb_name)
        
        position_emb_out = fluid.layers.embedding(
            input=pos_ids,
            size=[self._max_position_seq_len, self._emb_size],
            dtype=self._emb_dtype,
            param_attr=fluid.ParamAttr(
                name=scope_name+self._pos_emb_name, initializer=self._param_initializer))

        sent_emb_out = fluid.layers.embedding(
            input=sent_ids,
            size=[self._sent_types, self._emb_size],
            dtype=self._emb_dtype,
            param_attr=fluid.ParamAttr(
                name=scope_name+self._sent_emb_name, initializer=self._param_initializer))

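        # sum token, position and segment embeddings element-wise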
        emb_out = emb_out + position_emb_out
        emb_out = emb_out + sent_emb_out

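        # ERNIE additionally adds a task-type embedding on top of the
        # token + position + segment embeddings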
        task_emb_out = fluid.layers.embedding(
            input=task_ids,
            size=[self._task_types, self._emb_size],
            dtype=self._emb_dtype,
            param_attr=fluid.ParamAttr(
                name=scope_name+self._task_emb_name,
                initializer=self._param_initializer))

        emb_out = emb_out + task_emb_out

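        # 'nd' = layer normalization followed by dropout, applied to the summed
        # embeddings before they enter the encoder stack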
        emb_out = pre_process_layer(
            emb_out, 'nd', self._prepostprocess_dropout, name=scope_name+'pre_encoder')

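        # turn the [batch, seq, 1] input mask into a per-head attention bias:
        # matmul yields a [batch, seq, seq] pairwise mask, scale/bias maps
        # padded positions to -10000 (so softmax gives them ~0 weight), and the
        # result is stacked once per attention head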
        self_attn_mask = fluid.layers.matmul(
            x=input_mask, y=input_mask, transpose_y=True)

        self_attn_mask = fluid.layers.scale(
            x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
        n_head_self_attn_mask = fluid.layers.stack(
            x=[self_attn_mask] * self._n_head, axis=1)
        n_head_self_attn_mask.stop_gradient = True

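        # stacked transformer encoder; keys/values split the hidden size evenly
        # across heads and the feed-forward inner size is 4x the hidden size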
        enc_out = encoder(
            enc_input=emb_out,
            attn_bias=n_head_self_attn_mask,
            n_layer=self._n_layer,
            n_head=self._n_head,
            d_key=self._emb_size // self._n_head,
            d_value=self._emb_size // self._n_head,
            d_model=self._emb_size,
            d_inner_hid=self._emb_size * 4,
            prepostprocess_dropout=self._prepostprocess_dropout,
            attention_dropout=self._attention_dropout,
            relu_dropout=0,
            hidden_act=self._hidden_act,
            preprocess_cmd="",
            postprocess_cmd="dan",
            param_initializer=self._param_initializer,
            name=scope_name+'encoder')

        
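        # pooled output: slice out the first token's ([CLS]) hidden state and
        # project it through a tanh fully-connected layer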
        next_sent_feat = fluid.layers.slice(
            input=enc_out, axes=[1], starts=[0], ends=[1])
        next_sent_feat = fluid.layers.reshape(next_sent_feat, [-1, next_sent_feat.shape[-1]])
        next_sent_feat = fluid.layers.fc(
            input=next_sent_feat,
            size=self._emb_size,
            act="tanh",
            param_attr=fluid.ParamAttr(
                name=scope_name+"pooled_fc.w_0", initializer=self._param_initializer),
            bias_attr=scope_name+"pooled_fc.b_0")

        return {'embedding_table': embedding_table,
                'word_embedding': emb_out,
                'encoder_outputs': enc_out,
                'sentence_embedding': next_sent_feat,
                'sentence_pair_embedding': next_sent_feat}

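    # this backbone needs no postprocessing of runtime outputs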
    def postprocess(self, rt_outputs):
        pass