From 52a8ce19af1c610f463ec22cd3f1b515e7478e2d Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 4 Nov 2019 10:29:04 +0800 Subject: [PATCH] Update Transformer details --- PaddleNLP/PaddleMT/transformer/reader.py | 1 + PaddleNLP/PaddleMT/transformer/transformer.py | 97 ++++++++++++++----- 2 files changed, 73 insertions(+), 25 deletions(-) diff --git a/PaddleNLP/PaddleMT/transformer/reader.py b/PaddleNLP/PaddleMT/transformer/reader.py index e69b4a25..b3c09b7b 100644 --- a/PaddleNLP/PaddleMT/transformer/reader.py +++ b/PaddleNLP/PaddleMT/transformer/reader.py @@ -89,6 +89,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) trg_word = trg_word.reshape(-1, trg_max_len) + trg_word = trg_word[:, 1:] # pad by fluid.layers.pad trg_pos = trg_pos.reshape(-1, trg_max_len) trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], diff --git a/PaddleNLP/PaddleMT/transformer/transformer.py b/PaddleNLP/PaddleMT/transformer/transformer.py index be20001b..f310a50e 100644 --- a/PaddleNLP/PaddleMT/transformer/transformer.py +++ b/PaddleNLP/PaddleMT/transformer/transformer.py @@ -17,6 +17,7 @@ import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers +from paddle.fluid.layer_helper import LayerHelper from desc import * @@ -70,8 +71,8 @@ def position_encoding_init(n_position, d_pos_vec): num_timescales = channels // 2 log_timescale_increment = (np.log(float(1e4) / float(1)) / (num_timescales - 1)) - inv_timescales = np.exp(np.arange( - num_timescales)) * -log_timescale_increment + inv_timescales = np.exp( + np.arange(num_timescales) * -log_timescale_increment) scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0) signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) @@ -80,6 +81,39 @@ def position_encoding_init(n_position, d_pos_vec): return position_enc.astype("float32") +def layer_norm(x, + begin_norm_axis=1, + epsilon=1e-5, + param_attr=None, + bias_attr=None): + helper = LayerHelper('layer_norm', **locals()) + mean = layers.reduce_mean(x, + dim=range(begin_norm_axis, len(x.shape)), + keep_dim=True) + shift_x = layers.elementwise_sub(x=x, y=mean, axis=0) + variance = layers.reduce_mean(layers.square(shift_x), + dim=range(begin_norm_axis, len(x.shape)), + keep_dim=True) + r_stdev = layers.rsqrt(variance + epsilon) + norm_x = layers.elementwise_mul(x=shift_x, y=r_stdev, axis=0) + param_shape = norm_x.shape[begin_norm_axis:] + param_dtype = norm_x.dtype + scale = helper.create_parameter( + attr=param_attr, + shape=param_shape, + dtype=param_dtype, + default_initializer=fluid.initializer.Constant(1.)) + bias = helper.create_parameter( + attr=bias_attr, + shape=param_shape, + dtype=param_dtype, + is_bias=True, + default_initializer=fluid.initializer.Constant(0.)) + out = layers.elementwise_mul(x=norm_x, y=scale, axis=-1) + out = layers.elementwise_add(x=out, y=bias, axis=-1) + return out + + def multi_head_attention(queries, keys, values, @@ -212,18 +246,18 @@ def multi_head_attention(queries, product += attn_bias weights = layers.softmax(product) if dropout_rate: - weights = layers.dropout( - weights, - dropout_prob=dropout_rate, - seed=dropout_seed, - is_test=False) + weights = layers.dropout(weights, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False, + dropout_implementation="upscale_in_train") out = layers.matmul(weights, v) return out q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value) q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value) - ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model, + ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate) out = __combine_heads(ctx_multiheads) @@ -247,8 +281,11 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate): num_flatten_dims=2, act="relu") if dropout_rate: - hidden = layers.dropout( - hidden, dropout_prob=dropout_rate, seed=dropout_seed, is_test=False) + hidden = layers.dropout(hidden, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False, + dropout_implementation="upscale_in_train") out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2) return out @@ -264,18 +301,17 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.): if cmd == "a": # add residual connection out = out + prev_out if prev_out else out elif cmd == "n": # add layer normalization - out = layers.layer_norm( - out, - begin_norm_axis=len(out.shape) - 1, - param_attr=fluid.initializer.Constant(1.), - bias_attr=fluid.initializer.Constant(0.)) + out = layer_norm(out, + begin_norm_axis=len(out.shape) - 1, + param_attr=fluid.initializer.Constant(1.), + bias_attr=fluid.initializer.Constant(0.)) elif cmd == "d": # add dropout if dropout_rate: - out = layers.dropout( - out, - dropout_prob=dropout_rate, - seed=dropout_seed, - is_test=False) + out = layers.dropout(out, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False, + dropout_implementation="upscale_in_train") return out @@ -290,6 +326,7 @@ def prepare_encoder_decoder(src_word, src_max_len, dropout_rate=0., bos_idx=0, + pad_bos=False, word_emb_param_name=None, pos_enc_param_name=None): """Add word embeddings and position encodings. @@ -305,6 +342,9 @@ def prepare_encoder_decoder(src_word, initializer=fluid.initializer.Normal( 0., src_emb_dim**-0.5))) + if pad_bos: # if inputs not include bos, set embedding of bos to 0 + src_word_emb = layers.pad(src_word_emb, [0, 0, 1, 0, 0, 0]) + src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) src_pos_enc = fluid.embedding(src_pos, size=[src_max_len, src_emb_dim], @@ -312,9 +352,12 @@ def prepare_encoder_decoder(src_word, name=pos_enc_param_name, trainable=False)) src_pos_enc.stop_gradient = True enc_input = src_word_emb + src_pos_enc - return layers.dropout( - enc_input, dropout_prob=dropout_rate, seed=dropout_seed, - is_test=False) if dropout_rate else enc_input + return layers.dropout(enc_input, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False, + dropout_implementation="upscale_in_train" + ) if dropout_rate else enc_input prepare_encoder = partial( @@ -568,7 +611,9 @@ def transformer(model_input, preprocess_cmd, postprocess_cmd, weight_sharing, - enc_output=enc_output) + enc_output=enc_output, + bos_idx=bos_idx, + is_test=is_test) # Padding index do not contribute to the total loss. The weights is used to # cancel padding index in calculating the loss. @@ -655,7 +700,8 @@ def wrap_decoder(dec_inputs, enc_output=None, caches=None, gather_idx=None, - bos_idx=0): + bos_idx=0, + is_test=False): """ The wrapper assembles together all needed layers for the decoder. """ @@ -669,6 +715,7 @@ def wrap_decoder(dec_inputs, max_length, prepostprocess_dropout, bos_idx=bos_idx, + pad_bos=not is_test, # target inputs don't include bos for training word_emb_param_name=word_emb_param_names[0] if weight_sharing else word_emb_param_names[1]) dec_output = decoder( -- GitLab