From dad22cfabfcecf6d7adb0d7f7a6504969ad13388 Mon Sep 17 00:00:00 2001
From: 0YuanZhang0 <953963890@qq.com>
Date: Wed, 25 Sep 2019 19:57:07 +0800
Subject: [PATCH] add_nets (#3416)

---
 .../server/bert_server/pdnlp/nets/bert.py        | 231 ++++++++++++
 .../pdnlp/nets/transformer_encoder.py            | 353 ++++++++++++++++++
 2 files changed, 584 insertions(+)
 create mode 100644 PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/bert.py
 create mode 100644 PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/transformer_encoder.py

diff --git a/PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/bert.py b/PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/bert.py
new file mode 100644
index 00000000..67f1f51c
--- /dev/null
+++ b/PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/bert.py
@@ -0,0 +1,231 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BERT model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+import json
+import numpy as np
+import paddle.fluid as fluid
+
+from pdnlp.nets.transformer_encoder import encoder as encoder
+from pdnlp.nets.transformer_encoder import pre_process_layer as pre_process_layer
+
+
+class BertModel(object):
+    def __init__(self,
+                 src_ids,
+                 position_ids,
+                 sentence_ids,
+                 input_mask,
+                 config,
+                 weight_sharing=True,
+                 use_fp16=False,
+                 model_name=''):
+
+        self._emb_size = config["hidden_size"]
+        self._n_layer = config["num_hidden_layers"]
+        self._n_head = config["num_attention_heads"]
+        self._voc_size = config["vocab_size"]
+        self._max_position_seq_len = config["max_position_embeddings"]
+        self._sent_types = config["type_vocab_size"]
+        self._hidden_act = config["hidden_act"]
+        self._prepostprocess_dropout = config["hidden_dropout_prob"]
+        self._attention_dropout = config["attention_probs_dropout_prob"]
+        self._weight_sharing = weight_sharing
+
+        self.model_name = model_name
+
+        self._word_emb_name = self.model_name + "word_embedding"
+        self._pos_emb_name = self.model_name + "pos_embedding"
+        self._sent_emb_name = self.model_name + "sent_embedding"
+        self._dtype = "float16" if use_fp16 else "float32"
+
+        # Initialize all weights by the truncated normal initializer; all biases
+        # will be initialized to constant zero by default.
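+        # For reference, `config` is expected to be a dict (typically parsed from
+        # a BERT json config file) providing at least the keys read above, plus
+        # "self_att_scale" used in _build_model. A hypothetical BERT-Base style
+        # example, with illustrative values only:
+        #     {"hidden_size": 768, "num_hidden_layers": 12,
+        #      "num_attention_heads": 12, "vocab_size": 30522,
+        #      "max_position_embeddings": 512, "type_vocab_size": 2,
+        #      "hidden_act": "gelu", "hidden_dropout_prob": 0.1,
+        #      "attention_probs_dropout_prob": 0.1, "initializer_range": 0.02,
+        #      "self_att_scale": 10000.0}
+        # "initializer_range" is the stddev (`scale`) of the truncated normal below.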
+        self._param_initializer = fluid.initializer.TruncatedNormal(
+            scale=config["initializer_range"])
+
+        self._build_model(src_ids, position_ids, sentence_ids, input_mask,
+                          config)
+
+    def _build_model(self, src_ids, position_ids, sentence_ids, input_mask,
+                     config):
+        # padding id in vocabulary must be set to 0
+        emb_out = fluid.layers.embedding(
+            input=src_ids,
+            size=[self._voc_size, self._emb_size],
+            dtype=self._dtype,
+            param_attr=fluid.ParamAttr(
+                name=self._word_emb_name, initializer=self._param_initializer),
+            is_sparse=False)
+
+        self.emb_out = emb_out
+
+        position_emb_out = fluid.layers.embedding(
+            input=position_ids,
+            size=[self._max_position_seq_len, self._emb_size],
+            dtype=self._dtype,
+            param_attr=fluid.ParamAttr(
+                name=self._pos_emb_name, initializer=self._param_initializer))
+
+        self.position_emb_out = position_emb_out
+
+        sent_emb_out = fluid.layers.embedding(
+            sentence_ids,
+            size=[self._sent_types, self._emb_size],
+            dtype=self._dtype,
+            param_attr=fluid.ParamAttr(
+                name=self._sent_emb_name, initializer=self._param_initializer))
+
+        self.sent_emb_out = sent_emb_out
+
+        emb_out = emb_out + position_emb_out
+        emb_out = emb_out + sent_emb_out
+
+        emb_out = pre_process_layer(
+            emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
+
+        if self._dtype == "float16":
+            input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)
+
+        self_attn_mask = fluid.layers.matmul(
+            x=input_mask, y=input_mask, transpose_y=True)
+
+        self_attn_mask = fluid.layers.scale(
+            x=self_attn_mask,
+            scale=config["self_att_scale"],
+            bias=-1.0,
+            bias_after_scale=False)
+
+        n_head_self_attn_mask = fluid.layers.stack(
+            x=[self_attn_mask] * self._n_head, axis=1)
+
+        n_head_self_attn_mask.stop_gradient = True
+
+        self._enc_out = encoder(
+            enc_input=emb_out,
+            attn_bias=n_head_self_attn_mask,
+            n_layer=self._n_layer,
+            n_head=self._n_head,
+            d_key=self._emb_size // self._n_head,
+            d_value=self._emb_size // self._n_head,
+            d_model=self._emb_size,
+            d_inner_hid=self._emb_size * 4,
+            prepostprocess_dropout=self._prepostprocess_dropout,
+            attention_dropout=self._attention_dropout,
+            relu_dropout=0,
+            hidden_act=self._hidden_act,
+            preprocess_cmd="",
+            postprocess_cmd="dan",
+            param_initializer=self._param_initializer,
+            name=self.model_name + 'encoder')
+
+    def get_sequence_output(self):
+        return self._enc_out
+
+    def get_pooled_output(self):
+        """Get the first feature of each sequence for classification"""
+
+        next_sent_feat = fluid.layers.slice(
+            input=self._enc_out, axes=[1], starts=[0], ends=[1])
+        next_sent_feat = fluid.layers.fc(
+            input=next_sent_feat,
+            size=self._emb_size,
+            act="tanh",
+            param_attr=fluid.ParamAttr(
+                name=self.model_name + "pooled_fc.w_0",
+                initializer=self._param_initializer),
+            bias_attr="pooled_fc.b_0")
+        return next_sent_feat
+
+    def get_pretraining_output(self, mask_label, mask_pos, labels):
+        """Get the loss & accuracy for pretraining"""
+
+        mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
+
+        # extract the first token feature in each sentence
+        next_sent_feat = self.get_pooled_output()
+        reshaped_emb_out = fluid.layers.reshape(
+            x=self._enc_out, shape=[-1, self._emb_size])
+        # extract masked tokens' feature
+        mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
+
+        # transform: fc
+        mask_trans_feat = fluid.layers.fc(
+            input=mask_feat,
+            size=self._emb_size,
+            act=self._hidden_act,
+            param_attr=fluid.ParamAttr(
+                name=self.model_name + 'mask_lm_trans_fc.w_0',
+                initializer=self._param_initializer),
+            bias_attr=fluid.ParamAttr(
+                name=self.model_name + 'mask_lm_trans_fc.b_0'))
+        # transform: layer norm
+        mask_trans_feat = pre_process_layer(
+            mask_trans_feat, 'n', name=self.model_name + 'mask_lm_trans')
+
+        mask_lm_out_bias_attr = fluid.ParamAttr(
+            name=self.model_name + "mask_lm_out_fc.b_0",
+            initializer=fluid.initializer.Constant(value=0.0))
+        if self._weight_sharing:
+            fc_out = fluid.layers.matmul(
+                x=mask_trans_feat,
+                y=fluid.default_main_program().global_block().var(
+                    self._word_emb_name),
+                transpose_y=True)
+            fc_out += fluid.layers.create_parameter(
+                shape=[self._voc_size],
+                dtype=self._dtype,
+                attr=mask_lm_out_bias_attr,
+                is_bias=True)
+
+        else:
+            fc_out = fluid.layers.fc(
+                input=mask_trans_feat,
+                size=self._voc_size,
+                param_attr=fluid.ParamAttr(
+                    name=self.model_name + "mask_lm_out_fc.w_0",
+                    initializer=self._param_initializer),
+                bias_attr=mask_lm_out_bias_attr)
+
+        mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
+            logits=fc_out, label=mask_label)
+        mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
+
+        next_sent_fc_out = fluid.layers.fc(
+            input=next_sent_feat,
+            size=2,
+            param_attr=fluid.ParamAttr(
+                name=self.model_name + "next_sent_fc.w_0",
+                initializer=self._param_initializer),
+            bias_attr=self.model_name + "next_sent_fc.b_0")
+
+        next_sent_loss, next_sent_softmax = fluid.layers.softmax_with_cross_entropy(
+            logits=next_sent_fc_out, label=labels, return_softmax=True)
+
+        next_sent_acc = fluid.layers.accuracy(
+            input=next_sent_softmax, label=labels)
+
+        mean_next_sent_loss = fluid.layers.mean(next_sent_loss)
+
+        loss = mean_next_sent_loss + mean_mask_lm_loss
+        return next_sent_acc, mean_mask_lm_loss, loss
+
+
+if __name__ == "__main__":
+    print("hello world!")
diff --git a/PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/transformer_encoder.py b/PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/transformer_encoder.py
new file mode 100644
index 00000000..d1297efe
--- /dev/null
+++ b/PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/transformer_encoder.py
@@ -0,0 +1,353 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Transformer encoder."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import partial
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+
+
+def multi_head_attention(queries,
+                         keys,
+                         values,
+                         attn_bias,
+                         d_key,
+                         d_value,
+                         d_model,
+                         n_head=1,
+                         dropout_rate=0.,
+                         cache=None,
+                         param_initializer=None,
+                         name='multi_head_att'):
+    """
+    Multi-Head Attention. Note that attn_bias is added to the logits before
+    computing the softmax activation, to mask certain selected positions so
+    that they will not be considered in the attention weights.
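+
+    As an illustration of the expected layout (shapes only, matching how
+    BertModel builds its mask in bert.py above): with queries/keys/values of
+    shape [batch, seq_len, d_model], the raw attention logits have shape
+    [batch, n_head, seq_len, seq_len], and attn_bias must broadcast against
+    them, carrying 0 for visible positions and a large negative value
+    (e.g. -10000) for masked ones so that their softmax weights vanish.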
+    """
+    keys = queries if keys is None else keys
+    values = keys if values is None else values
+
+    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
+        raise ValueError(
+            "Inputs: queries, keys and values should all be 3-D tensors.")
+
+    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
+        """
+        Add linear projection to queries, keys, and values.
+        """
+        q = layers.fc(input=queries,
+                      size=d_key * n_head,
+                      num_flatten_dims=2,
+                      param_attr=fluid.ParamAttr(
+                          name=name + '_query_fc.w_0',
+                          initializer=param_initializer),
+                      bias_attr=name + '_query_fc.b_0')
+        k = layers.fc(input=keys,
+                      size=d_key * n_head,
+                      num_flatten_dims=2,
+                      param_attr=fluid.ParamAttr(
+                          name=name + '_key_fc.w_0',
+                          initializer=param_initializer),
+                      bias_attr=name + '_key_fc.b_0')
+        v = layers.fc(input=values,
+                      size=d_value * n_head,
+                      num_flatten_dims=2,
+                      param_attr=fluid.ParamAttr(
+                          name=name + '_value_fc.w_0',
+                          initializer=param_initializer),
+                      bias_attr=name + '_value_fc.b_0')
+        return q, k, v
+
+    def __split_heads(x, n_head):
+        """
+        Reshape the last dimension of the input tensor x into two dimensions
+        and then transpose. Specifically, the input is a tensor with shape
+        [bs, max_sequence_length, n_head * hidden_dim] and the output is a
+        tensor with shape [bs, n_head, max_sequence_length, hidden_dim].
+        """
+        hidden_size = x.shape[-1]
+        # The value 0 in shape attr means copying the corresponding dimension
+        # size of the input as the output dimension size.
+        reshaped = layers.reshape(
+            x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
+
+        # permute the dimensions into:
+        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
+        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
+
+    def __combine_heads(x):
+        """
+        Transpose and then reshape the last two dimensions of the input tensor
+        x into one dimension, which is the reverse of __split_heads.
+        """
+        if len(x.shape) == 3: return x
+        if len(x.shape) != 4:
+            raise ValueError("Input(x) should be a 4-D Tensor.")
+
+        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
+        # The value 0 in shape attr means copying the corresponding dimension
+        # size of the input as the output dimension size.
+        return layers.reshape(
+            x=trans_x,
+            shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
+            inplace=True)
+
+    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
+        """
+        Scaled Dot-Product Attention
+        """
+        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
+        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
+        if attn_bias:
+            product += attn_bias
+        weights = layers.softmax(product)
+        if dropout_rate:
+            weights = layers.dropout(
+                weights,
+                dropout_prob=dropout_rate,
+                dropout_implementation="upscale_in_train",
+                is_test=False)
+        out = layers.matmul(weights, v)
+        return out
+
+    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
+
+    if cache is not None:  # use cache and concat time steps
+        # Since the inplace reshape in __split_heads changes the shape of k and
+        # v, which is the cache input for the next time step, reshape the cache
+        # input from the previous time step first.
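+        # (Illustrative shapes, not exercised by the BERT path where cache is
+        # None: an incremental decoder would pass
+        #     cache = {"k": [batch, len_so_far, d_model] tensor,
+        #              "v": [batch, len_so_far, d_model] tensor}
+        # and after the two concats below the cache holds the projected keys
+        # and values for every time step seen so far.)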
+        k = cache["k"] = layers.concat(
+            [layers.reshape(
+                cache["k"], shape=[0, 0, d_model]), k], axis=1)
+        v = cache["v"] = layers.concat(
+            [layers.reshape(
+                cache["v"], shape=[0, 0, d_model]), v], axis=1)
+
+    q = __split_heads(q, n_head)
+    k = __split_heads(k, n_head)
+    v = __split_heads(v, n_head)
+
+    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
+                                                  dropout_rate)
+
+    out = __combine_heads(ctx_multiheads)
+
+    # Project back to the model size.
+    proj_out = layers.fc(input=out,
+                         size=d_model,
+                         num_flatten_dims=2,
+                         param_attr=fluid.ParamAttr(
+                             name=name + '_output_fc.w_0',
+                             initializer=param_initializer),
+                         bias_attr=name + '_output_fc.b_0')
+    return proj_out
+
+
+def positionwise_feed_forward(x,
+                              d_inner_hid,
+                              d_hid,
+                              dropout_rate,
+                              hidden_act,
+                              param_initializer=None,
+                              name='ffn'):
+    """
+    Position-wise Feed-Forward Networks.
+    This module consists of two linear transformations with a ReLU activation
+    in between, which is applied to each position separately and identically.
+    """
+    hidden = layers.fc(input=x,
+                       size=d_inner_hid,
+                       num_flatten_dims=2,
+                       act=hidden_act,
+                       param_attr=fluid.ParamAttr(
+                           name=name + '_fc_0.w_0',
+                           initializer=param_initializer),
+                       bias_attr=name + '_fc_0.b_0')
+    if dropout_rate:
+        hidden = layers.dropout(
+            hidden,
+            dropout_prob=dropout_rate,
+            dropout_implementation="upscale_in_train",
+            is_test=False)
+
+    out = layers.fc(input=hidden,
+                    size=d_hid,
+                    num_flatten_dims=2,
+                    param_attr=fluid.ParamAttr(
+                        name=name + '_fc_1.w_0', initializer=param_initializer),
+                    bias_attr=name + '_fc_1.b_0')
+    return out
+
+
+def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.,
+                           name=''):
+    """
+    Add residual connection, layer normalization and dropout to the out tensor
+    optionally according to the value of process_cmd.
+    This will be used before or after multi-head attention and position-wise
+    feed-forward networks.
+    """
+    for cmd in process_cmd:
+        if cmd == "a":  # add residual connection
+            out = out + prev_out if prev_out else out
+        elif cmd == "n":  # add layer normalization
+            out_dtype = out.dtype
+            if out_dtype == fluid.core.VarDesc.VarType.FP16:
+                out = layers.cast(x=out, dtype="float32")
+            out = layers.layer_norm(
+                out,
+                begin_norm_axis=len(out.shape) - 1,
+                param_attr=fluid.ParamAttr(
+                    name=name + '_layer_norm_scale',
+                    initializer=fluid.initializer.Constant(1.)),
+                bias_attr=fluid.ParamAttr(
+                    name=name + '_layer_norm_bias',
+                    initializer=fluid.initializer.Constant(0.)))
+            if out_dtype == fluid.core.VarDesc.VarType.FP16:
+                out = layers.cast(x=out, dtype="float16")
+        elif cmd == "d":  # add dropout
+            if dropout_rate:
+                out = layers.dropout(
+                    out,
+                    dropout_prob=dropout_rate,
+                    dropout_implementation="upscale_in_train",
+                    is_test=False)
+    return out
+
+
+pre_process_layer = partial(pre_post_process_layer, None)
+post_process_layer = pre_post_process_layer
+
+
+def encoder_layer(enc_input,
+                  attn_bias,
+                  n_head,
+                  d_key,
+                  d_value,
+                  d_model,
+                  d_inner_hid,
+                  prepostprocess_dropout,
+                  attention_dropout,
+                  relu_dropout,
+                  hidden_act,
+                  preprocess_cmd="n",
+                  postprocess_cmd="da",
+                  param_initializer=None,
+                  name=''):
+    """
+    The encoder layers that can be stacked to form a deep encoder.
+    This module consists of a multi-head (self) attention sub-layer followed by
+    a position-wise feed-forward network; both components are wrapped with
+    post_process_layer to add residual connection, layer normalization and
+    dropout.
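+
+    The cmd letters are the ones understood by pre_post_process_layer: "a"
+    adds the residual connection, "n" applies layer normalization and "d"
+    applies dropout. With the defaults here (preprocess_cmd="n",
+    postprocess_cmd="da") each sub-layer is pre-LN; BertModel above instead
+    passes preprocess_cmd="" and postprocess_cmd="dan", i.e. dropout, then
+    residual add, then layer norm, which matches the post-LN ordering of the
+    original BERT/Transformer.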
+    """
+    attn_output = multi_head_attention(
+        pre_process_layer(
+            enc_input,
+            preprocess_cmd,
+            prepostprocess_dropout,
+            name=name + '_pre_att'),
+        None,
+        None,
+        attn_bias,
+        d_key,
+        d_value,
+        d_model,
+        n_head,
+        attention_dropout,
+        param_initializer=param_initializer,
+        name=name + '_multi_head_att')
+    attn_output = post_process_layer(
+        enc_input,
+        attn_output,
+        postprocess_cmd,
+        prepostprocess_dropout,
+        name=name + '_post_att')
+    ffd_output = positionwise_feed_forward(
+        pre_process_layer(
+            attn_output,
+            preprocess_cmd,
+            prepostprocess_dropout,
+            name=name + '_pre_ffn'),
+        d_inner_hid,
+        d_model,
+        relu_dropout,
+        hidden_act,
+        param_initializer=param_initializer,
+        name=name + '_ffn')
+    return post_process_layer(
+        attn_output,
+        ffd_output,
+        postprocess_cmd,
+        prepostprocess_dropout,
+        name=name + '_post_ffn')
+
+
+def encoder(enc_input,
+            attn_bias,
+            n_layer,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            prepostprocess_dropout,
+            attention_dropout,
+            relu_dropout,
+            hidden_act,
+            preprocess_cmd="n",
+            postprocess_cmd="da",
+            param_initializer=None,
+            name='',
+            return_all=False):
+    """
+    The encoder is composed of a stack of identical layers returned by calling
+    encoder_layer.
+    """
+    enc_outputs = []
+    for i in range(n_layer):
+        enc_output = encoder_layer(
+            enc_input,
+            attn_bias,
+            n_head,
+            d_key,
+            d_value,
+            d_model,
+            d_inner_hid,
+            prepostprocess_dropout,
+            attention_dropout,
+            relu_dropout,
+            hidden_act,
+            preprocess_cmd,
+            postprocess_cmd,
+            param_initializer=param_initializer,
+            name=name + '_layer_' + str(i))
+        enc_input = enc_output
+        if i < n_layer - 1:
+            enc_outputs.append(enc_output)
+
+    enc_output = pre_process_layer(
+        enc_output, preprocess_cmd, prepostprocess_dropout, name="post_encoder")
+    enc_outputs.append(enc_output)
+
+    if not return_all:
+        return enc_output
+    else:
+        return enc_output, enc_outputs
-- 
GitLab
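For reference, a minimal usage sketch of the BertModel added above (not part of the patch). It assumes PaddlePaddle Fluid 1.x, that the pdnlp package is importable from the bert_server directory, and a hand-written bert_config dict; the hyper-parameter values and max_seq_len below are illustrative BERT-Base style placeholders, not values taken from this repository.

import paddle.fluid as fluid

from pdnlp.nets.bert import BertModel

# Illustrative BERT-Base style hyper-parameters; a real deployment would load
# these from the released bert_config.json instead.
bert_config = {
    "hidden_size": 768, "num_hidden_layers": 12, "num_attention_heads": 12,
    "vocab_size": 30522, "max_position_embeddings": 512, "type_vocab_size": 2,
    "hidden_act": "gelu", "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1, "initializer_range": 0.02,
    "self_att_scale": 10000.0,
}

max_seq_len = 128
# shape excludes the batch dimension (append_batch_size defaults to True)
src_ids = fluid.layers.data(name="src_ids", shape=[max_seq_len, 1], dtype="int64")
position_ids = fluid.layers.data(name="position_ids", shape=[max_seq_len, 1], dtype="int64")
sentence_ids = fluid.layers.data(name="sentence_ids", shape=[max_seq_len, 1], dtype="int64")
input_mask = fluid.layers.data(name="input_mask", shape=[max_seq_len, 1], dtype="float32")

# Build the forward graph.
bert = BertModel(src_ids, position_ids, sentence_ids, input_mask,
                 config=bert_config, weight_sharing=True, use_fp16=False)

sequence_output = bert.get_sequence_output()  # [batch, seq_len, hidden_size]
pooled_output = bert.get_pooled_output()      # [batch, hidden_size]

# Initialize parameters; pretrained weights would normally be loaded afterwards.
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())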