From 64eb26ae5910a6905d8b60d853789574d63b51b2 Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Fri, 15 May 2020 10:49:46 +0800 Subject: [PATCH] Add validation for dygraph Transformer. (#4628) Add cross-attention cache for dygraph Transformer. Add greedy search for dygraph Transformer. --- dygraph/transformer/README.md | 9 +- dygraph/transformer/config.py | 199 ------------ dygraph/transformer/model.py | 470 +++++++++++++++++---------- dygraph/transformer/train.py | 62 +++- dygraph/transformer/transformer.yaml | 2 + 5 files changed, 358 insertions(+), 384 deletions(-) delete mode 100644 dygraph/transformer/config.py diff --git a/dygraph/transformer/README.md b/dygraph/transformer/README.md index 6cec2d79..4b8247ac 100644 --- a/dygraph/transformer/README.md +++ b/dygraph/transformer/README.md @@ -76,6 +76,7 @@ python -u train.py \ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ --special_token '' '' '' \ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ + --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --batch_size 4096 ``` @@ -91,6 +92,7 @@ python -u train.py \ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ --special_token '' '' '' \ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ + --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --batch_size 4096 \ --n_head 16 \ --d_model 1024 \ @@ -121,10 +123,11 @@ Paddle动态图支持多进程多卡进行模型训练,启动训练的方式 ```sh python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \ --epoch 30 \ - --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \ - --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \ + --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ + --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ --special_token '' '' '' \ - --training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ + --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ + --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --batch_size 4096 \ --print_step 100 \ --use_cuda True \ diff --git a/dygraph/transformer/config.py b/dygraph/transformer/config.py deleted file mode 100644 index 2841e04d..00000000 --- a/dygraph/transformer/config.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class TrainTaskConfig(object): - """ - TrainTaskConfig - """ - # the epoch number to train. - pass_num = 20 - # the number of sequences contained in a mini-batch. - # deprecated, set batch_size in args. - batch_size = 32 - # the hyper parameters for Adam optimizer. - # This static learning_rate will be multiplied to the LearningRateScheduler - # derived learning rate the to get the final learning rate. 
- learning_rate = 2.0 - beta1 = 0.9 - beta2 = 0.997 - eps = 1e-9 - # the parameters for learning rate scheduling. - warmup_steps = 8000 - # the weight used to mix up the ground-truth distribution and the fixed - # uniform distribution in label smoothing when training. - # Set this as zero if label smoothing is not wanted. - label_smooth_eps = 0.1 - - -class InferTaskConfig(object): - # the number of examples in one run for sequence generation. - batch_size = 4 - # the parameters for beam search. - beam_size = 4 - alpha = 0.6 - # max decoded length, should be less than ModelHyperParams.max_length - max_out_len = 30 - - -class ModelHyperParams(object): - """ - ModelHyperParams - """ - # These following five vocabularies related configurations will be set - # automatically according to the passed vocabulary path and special tokens. - # size of source word dictionary. - src_vocab_size = 10000 - # size of target word dictionay - trg_vocab_size = 10000 - # index for token - bos_idx = 0 - # index for token - eos_idx = 1 - # index for token - unk_idx = 2 - - # max length of sequences deciding the size of position encoding table. - max_length = 50 - # the dimension for word embeddings, which is also the last dimension of - # the input and output of multi-head attention, position-wise feed-forward - # networks, encoder and decoder. - d_model = 512 - # size of the hidden layer in position-wise feed-forward networks. - d_inner_hid = 2048 - # the dimension that keys are projected to for dot-product attention. - d_key = 64 - # the dimension that values are projected to for dot-product attention. - d_value = 64 - # number of head used in multi-head attention. - n_head = 8 - # number of sub-layers to be stacked in the encoder and decoder. - n_layer = 6 - # dropout rates of different modules. - prepostprocess_dropout = 0.1 - attention_dropout = 0.1 - relu_dropout = 0.1 - # to process before each sub-layer - preprocess_cmd = "n" # layer normalization - # to process after each sub-layer - postprocess_cmd = "da" # dropout + residual connection - # the flag indicating whether to share embedding and softmax weights. - # vocabularies in source and target should be same for weight sharing. - weight_sharing = False - - -# The placeholder for batch_size in compile time. Must be -1 currently to be -# consistent with some ops' infer-shape output in compile time, such as the -# sequence_expand op used in beamsearch decoder. -batch_size = -1 -# The placeholder for squence length in compile time. -seq_len = ModelHyperParams.max_length -# Here list the data shapes and data types of all inputs. -# The shapes here act as placeholder and are set to pass the infer-shape in -# compile time. -input_descs = { - # The actual data shape of src_word is: - # [batch_size, max_src_len_in_batch, 1] - "src_word": [(batch_size, seq_len, 1), "int64", 2], - # The actual data shape of src_pos is: - # [batch_size, max_src_len_in_batch, 1] - "src_pos": [(batch_size, seq_len, 1), "int64"], - # This input is used to remove attention weights on paddings in the - # encoder. - # The actual data shape of src_slf_attn_bias is: - # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch] - "src_slf_attn_bias": - [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"], - # The actual data shape of trg_word is: - # [batch_size, max_trg_len_in_batch, 1] - "trg_word": [(batch_size, seq_len, 1), "int64", - 2], # lod_level is only used in fast decoder. 
- # The actual data shape of trg_pos is: - # [batch_size, max_trg_len_in_batch, 1] - "trg_pos": [(batch_size, seq_len, 1), "int64"], - # This input is used to remove attention weights on paddings and - # subsequent words in the decoder. - # The actual data shape of trg_slf_attn_bias is: - # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch] - "trg_slf_attn_bias": - [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"], - # This input is used to remove attention weights on paddings of the source - # input in the encoder-decoder attention. - # The actual data shape of trg_src_attn_bias is: - # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch] - "trg_src_attn_bias": - [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"], - # This input is used in independent decoder program for inference. - # The actual data shape of enc_output is: - # [batch_size, max_src_len_in_batch, d_model] - "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"], - # The actual data shape of label_word is: - # [batch_size * max_trg_len_in_batch, 1] - "lbl_word": [(batch_size * seq_len, 1), "int64"], - # This input is used to mask out the loss of paddding tokens. - # The actual data shape of label_weight is: - # [batch_size * max_trg_len_in_batch, 1] - "lbl_weight": [(batch_size * seq_len, 1), "float32"], - # This input is used in beam-search decoder. - "init_score": [(batch_size, 1), "float32", 2], - # This input is used in beam-search decoder for the first gather - # (cell states updation) - "init_idx": [(batch_size, ), "int32"], -} - -# Names of word embedding table which might be reused for weight sharing. -word_emb_param_names = ( - "src_word_emb_table", - "trg_word_emb_table", ) -# Names of position encoding table which will be initialized externally. -pos_enc_param_names = ( - "src_pos_enc_table", - "trg_pos_enc_table", ) -# separated inputs for different usages. -encoder_data_input_fields = ( - "src_word", - "src_pos", - "src_slf_attn_bias", ) -decoder_data_input_fields = ( - "trg_word", - "trg_pos", - "trg_slf_attn_bias", - "trg_src_attn_bias", - "enc_output", ) -label_data_input_fields = ( - "lbl_word", - "lbl_weight", ) -# In fast decoder, trg_pos (only containing the current time step) is generated -# by ops and trg_slf_attn_bias is not needed. -fast_decoder_data_input_fields = ( - "trg_word", - # "init_score", - # "init_idx", - "trg_src_attn_bias", ) - - -def merge_cfg_from_list(cfg_list, g_cfgs): - """ - Set the above global configurations using the cfg_list. 
- """ - assert len(cfg_list) % 2 == 0 - for key, value in zip(cfg_list[0::2], cfg_list[1::2]): - for g_cfg in g_cfgs: - if hasattr(g_cfg, key): - try: - value = eval(value) - except Exception: # for file path - pass - setattr(g_cfg, key, value) - break diff --git a/dygraph/transformer/model.py b/dygraph/transformer/model.py index 1693a8a7..70f64f50 100644 --- a/dygraph/transformer/model.py +++ b/dygraph/transformer/model.py @@ -18,12 +18,9 @@ import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers -from paddle.fluid.layers.utils import map_structure from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay -from config import word_emb_param_names, pos_enc_param_names - def position_encoding_init(n_position, d_pos_vec): """ @@ -34,10 +31,10 @@ def position_encoding_init(n_position, d_pos_vec): num_timescales = channels // 2 log_timescale_increment = (np.log(float(1e4) / float(1)) / (num_timescales - 1)) - inv_timescales = np.exp(np.arange( - num_timescales)) * -log_timescale_increment - scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, - 0) + inv_timescales = np.exp( + np.arange(num_timescales)) * -log_timescale_increment + scaled_time = np.expand_dims(position, 1) * np.expand_dims( + inv_timescales, 0) signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant') position_enc = signal @@ -48,7 +45,6 @@ class NoamDecay(LearningRateDecay): """ learning rate scheduler """ - def __init__(self, d_model, warmup_steps, @@ -73,7 +69,6 @@ class PrePostProcessLayer(Layer): """ PrePostProcessLayer """ - def __init__(self, process_cmd, d_model, dropout_rate): super(PrePostProcessLayer, self).__init__() self.process_cmd = process_cmd @@ -84,8 +79,8 @@ class PrePostProcessLayer(Layer): elif cmd == "n": # add layer normalization self.functors.append( self.add_sublayer( - "layer_norm_%d" % len( - self.sublayers(include_sublayers=False)), + "layer_norm_%d" % + len(self.sublayers(include_sublayers=False)), LayerNorm( normalized_shape=d_model, param_attr=fluid.ParamAttr( @@ -93,9 +88,9 @@ class PrePostProcessLayer(Layer): bias_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(0.))))) elif cmd == "d": # add dropout - if dropout_rate: - self.functors.append(lambda x: layers.dropout( - x, dropout_prob=dropout_rate, is_test=False)) + self.functors.append(lambda x: layers.dropout( + x, dropout_prob=dropout_rate, is_test=False) + if dropout_rate else x) def forward(self, x, residual=None): for i, cmd in enumerate(self.process_cmd): @@ -110,7 +105,6 @@ class MultiHeadAttention(Layer): """ Multi-Head Attention """ - def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.): super(MultiHeadAttention, self).__init__() self.n_head = n_head @@ -118,49 +112,73 @@ class MultiHeadAttention(Layer): self.d_value = d_value self.d_model = d_model self.dropout_rate = dropout_rate - self.q_fc = Linear( - input_dim=d_model, output_dim=d_key * n_head, bias_attr=False) - self.k_fc = Linear( - input_dim=d_model, output_dim=d_key * n_head, bias_attr=False) - self.v_fc = Linear( - input_dim=d_model, output_dim=d_value * n_head, bias_attr=False) - self.proj_fc = Linear( - input_dim=d_value * n_head, output_dim=d_model, bias_attr=False) - - def forward(self, queries, keys, values, attn_bias, cache=None): - # compute q ,k ,v - keys = queries if keys is None else keys - values = 
keys if values is None else values + self.q_fc = Linear(input_dim=d_model, + output_dim=d_key * n_head, + bias_attr=False) + self.k_fc = Linear(input_dim=d_model, + output_dim=d_key * n_head, + bias_attr=False) + self.v_fc = Linear(input_dim=d_model, + output_dim=d_value * n_head, + bias_attr=False) + self.proj_fc = Linear(input_dim=d_value * n_head, + output_dim=d_model, + bias_attr=False) + + def _prepare_qkv(self, queries, keys, values, cache=None): + if keys is None: # self-attention + keys, values = queries, queries + static_kv = False + else: # cross-attention + static_kv = True q = self.q_fc(queries) - k = self.k_fc(keys) - v = self.v_fc(values) - - # split head q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key]) q = layers.transpose(x=q, perm=[0, 2, 1, 3]) - k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key]) - k = layers.transpose(x=k, perm=[0, 2, 1, 3]) - v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value]) - v = layers.transpose(x=v, perm=[0, 2, 1, 3]) + + if cache is not None and static_kv and "static_k" in cache: + # for encoder-decoder attention in inference and has cached + k = cache["static_k"] + v = cache["static_v"] + else: + k = self.k_fc(keys) + v = self.v_fc(values) + k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key]) + k = layers.transpose(x=k, perm=[0, 2, 1, 3]) + v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value]) + v = layers.transpose(x=v, perm=[0, 2, 1, 3]) if cache is not None: - cache_k, cache_v = cache["k"], cache["v"] - k = layers.concat([cache_k, k], axis=2) - v = layers.concat([cache_v, v], axis=2) - cache["k"], cache["v"] = k, v + if static_kv and not "static_k" in cache: + # for encoder-decoder attention in inference and has not cached + cache["static_k"], cache["static_v"] = k, v + elif not static_kv: + # for decoder self-attention in inference + cache_k, cache_v = cache["k"], cache["v"] + k = layers.concat([cache_k, k], axis=2) + v = layers.concat([cache_v, v], axis=2) + cache["k"], cache["v"] = k, v + + return q, k, v + + def forward(self, queries, keys, values, attn_bias, cache=None): + # compute q ,k ,v + q, k, v = self._prepare_qkv(queries, keys, values, cache) # scale dot product attention - product = layers.matmul( - x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5) + product = layers.matmul(x=q, + y=k, + transpose_y=True, + alpha=self.d_model**-0.5) if attn_bias is not None: product += attn_bias weights = layers.softmax(product) if self.dropout_rate: - weights = layers.dropout( - weights, dropout_prob=self.dropout_rate, is_test=False) + weights = layers.dropout(weights, + dropout_prob=self.dropout_rate, + is_test=False) - out = layers.matmul(weights, v) + out = layers.matmul(weights, v) # combine heads out = layers.transpose(out, perm=[0, 2, 1, 3]) @@ -175,7 +193,6 @@ class FFN(Layer): """ Feed-Forward Network """ - def __init__(self, d_inner_hid, d_model, dropout_rate): super(FFN, self).__init__() self.dropout_rate = dropout_rate @@ -185,8 +202,9 @@ class FFN(Layer): def forward(self, x): hidden = self.fc1(x) if self.dropout_rate: - hidden = layers.dropout( - hidden, dropout_prob=self.dropout_rate, is_test=False) + hidden = layers.dropout(hidden, + dropout_prob=self.dropout_rate, + is_test=False) out = self.fc2(hidden) return out @@ -195,7 +213,6 @@ class EncoderLayer(Layer): """ EncoderLayer """ - def __init__(self, n_head, d_key, @@ -224,8 +241,8 @@ class EncoderLayer(Layer): prepostprocess_dropout) def forward(self, enc_input, attn_bias): - attn_output = self.self_attn( - 
self.preprocesser1(enc_input), None, None, attn_bias) + attn_output = self.self_attn(self.preprocesser1(enc_input), None, None, + attn_bias) attn_output = self.postprocesser1(attn_output, enc_input) ffn_output = self.ffn(self.preprocesser2(attn_output)) @@ -237,7 +254,6 @@ class Encoder(Layer): """ encoder """ - def __init__(self, n_layer, n_head, @@ -277,7 +293,6 @@ class Embedder(Layer): """ Word Embedding + Position Encoding """ - def __init__(self, vocab_size, emb_dim, bos_idx=0): super(Embedder, self).__init__() @@ -296,7 +311,6 @@ class WrapEncoder(Layer): """ embedder + encoder """ - def __init__(self, src_vocab_size, max_length, n_layer, n_head, d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout, preprocess_cmd, @@ -324,9 +338,9 @@ class WrapEncoder(Layer): pos_enc = self.pos_encoder(src_pos) pos_enc.stop_gradient = True emb = word_emb + pos_enc - enc_input = layers.dropout( - emb, dropout_prob=self.emb_dropout, - is_test=False) if self.emb_dropout else emb + enc_input = layers.dropout(emb, + dropout_prob=self.emb_dropout, + is_test=False) if self.emb_dropout else emb enc_output = self.encoder(enc_input, src_slf_attn_bias) return enc_output @@ -336,7 +350,6 @@ class DecoderLayer(Layer): """ decoder """ - def __init__(self, n_head, d_key, @@ -376,13 +389,13 @@ class DecoderLayer(Layer): self_attn_bias, cross_attn_bias, cache=None): - self_attn_output = self.self_attn( - self.preprocesser1(dec_input), None, None, self_attn_bias, cache) + self_attn_output = self.self_attn(self.preprocesser1(dec_input), None, + None, self_attn_bias, cache) self_attn_output = self.postprocesser1(self_attn_output, dec_input) cross_attn_output = self.cross_attn( self.preprocesser2(self_attn_output), enc_output, enc_output, - cross_attn_bias) + cross_attn_bias, cache) cross_attn_output = self.postprocesser2(cross_attn_output, self_attn_output) @@ -396,7 +409,6 @@ class Decoder(Layer): """ decoder """ - def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd): @@ -422,8 +434,8 @@ class Decoder(Layer): caches=None): for i, decoder_layer in enumerate(self.decoder_layers): dec_output = decoder_layer(dec_input, enc_output, self_attn_bias, - cross_attn_bias, None - if caches is None else caches[i]) + cross_attn_bias, + None if caches is None else caches[i]) dec_input = dec_output return self.processer(dec_output) @@ -433,7 +445,6 @@ class WrapDecoder(Layer): """ embedder + decoder """ - def __init__(self, trg_vocab_size, max_length, n_layer, n_head, d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout, preprocess_cmd, @@ -461,8 +472,9 @@ class WrapDecoder(Layer): word_embedder.weight, transpose_y=True) else: - self.linear = Linear( - input_dim=d_model, output_dim=trg_vocab_size, bias_attr=False) + self.linear = Linear(input_dim=d_model, + output_dim=trg_vocab_size, + bias_attr=False) def forward(self, trg_word, @@ -476,14 +488,15 @@ class WrapDecoder(Layer): pos_enc = self.pos_encoder(trg_pos) pos_enc.stop_gradient = True emb = word_emb + pos_enc - dec_input = layers.dropout( - emb, dropout_prob=self.emb_dropout, - is_test=False) if self.emb_dropout else emb + dec_input = layers.dropout(emb, + dropout_prob=self.emb_dropout, + is_test=False) if self.emb_dropout else emb dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias, trg_src_attn_bias, caches) dec_output = layers.reshape( dec_output, - shape=[-1, 
dec_output.shape[-1]], ) + shape=[-1, dec_output.shape[-1]], + ) logits = self.linear(dec_output) return logits @@ -494,10 +507,9 @@ class CrossEntropyCriterion(object): def __call__(self, predict, label, weights): if self.label_smooth_eps: - label_out = layers.label_smooth( - label=layers.one_hot( - input=label, depth=predict.shape[-1]), - epsilon=self.label_smooth_eps) + label_out = layers.label_smooth(label=layers.one_hot( + input=label, depth=predict.shape[-1]), + epsilon=self.label_smooth_eps) cost = layers.softmax_with_cross_entropy( logits=predict, @@ -515,7 +527,6 @@ class Transformer(Layer): """ model """ - def __init__(self, src_vocab_size, trg_vocab_size, @@ -535,25 +546,29 @@ class Transformer(Layer): bos_id=0, eos_id=1): super(Transformer, self).__init__() - src_word_embedder = Embedder( - vocab_size=src_vocab_size, emb_dim=d_model, bos_idx=bos_id) - self.encoder = WrapEncoder( - src_vocab_size, max_length, n_layer, n_head, d_key, d_value, - d_model, d_inner_hid, prepostprocess_dropout, attention_dropout, - relu_dropout, preprocess_cmd, postprocess_cmd, src_word_embedder) + src_word_embedder = Embedder(vocab_size=src_vocab_size, + emb_dim=d_model, + bos_idx=bos_id) + self.encoder = WrapEncoder(src_vocab_size, max_length, n_layer, n_head, + d_key, d_value, d_model, d_inner_hid, + prepostprocess_dropout, attention_dropout, + relu_dropout, preprocess_cmd, + postprocess_cmd, src_word_embedder) if weight_sharing: assert src_vocab_size == trg_vocab_size, ( "Vocabularies in source and target should be same for weight sharing." ) trg_word_embedder = src_word_embedder else: - trg_word_embedder = Embedder( - vocab_size=trg_vocab_size, emb_dim=d_model, bos_idx=bos_id) - self.decoder = WrapDecoder( - trg_vocab_size, max_length, n_layer, n_head, d_key, d_value, - d_model, d_inner_hid, prepostprocess_dropout, attention_dropout, - relu_dropout, preprocess_cmd, postprocess_cmd, weight_sharing, - trg_word_embedder) + trg_word_embedder = Embedder(vocab_size=trg_vocab_size, + emb_dim=d_model, + bos_idx=bos_id) + self.decoder = WrapDecoder(trg_vocab_size, max_length, n_layer, n_head, + d_key, d_value, d_model, d_inner_hid, + prepostprocess_dropout, attention_dropout, + relu_dropout, preprocess_cmd, + postprocess_cmd, weight_sharing, + trg_word_embedder) self.trg_vocab_size = trg_vocab_size self.n_layer = n_layer @@ -583,18 +598,14 @@ class Transformer(Layer): Beam search with the alive and finished two queues, both have a beam size capicity separately. It includes `grow_topk` `grow_alive` `grow_finish` as steps. - 1. `grow_topk` selects the top `2*beam_size` candidates to avoid all getting EOS. - 2. `grow_alive` selects the top `beam_size` non-EOS candidates as the inputs of next decoding step. - 3. `grow_finish` compares the already finished candidates in the finished queue and newly added finished candidates from `grow_topk`, and selects the top `beam_size` finished candidates. """ - def expand_to_beam_size(tensor, beam_size): tensor = layers.reshape(tensor, [tensor.shape[0], 1] + tensor.shape[1:]) @@ -616,23 +627,19 @@ class Transformer(Layer): ### initialize states of beam search ### ## init for the alive ## initial_log_probs = to_variable( - np.array( - [[0.] + [-inf] * (beam_size - 1)], dtype="float32")) + np.array([[0.] 
+ [-inf] * (beam_size - 1)], dtype="float32")) alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1]) alive_seq = to_variable( - np.tile( - np.array( - [[[bos_id]]], dtype="int64"), (batch_size, beam_size, 1))) + np.tile(np.array([[[bos_id]]], dtype="int64"), + (batch_size, beam_size, 1))) ## init for the finished ## finished_scores = to_variable( - np.array( - [[-inf] * beam_size], dtype="float32")) + np.array([[-inf] * beam_size], dtype="float32")) finished_scores = layers.expand(finished_scores, [batch_size, 1]) finished_seq = to_variable( - np.tile( - np.array( - [[[bos_id]]], dtype="int64"), (batch_size, beam_size, 1))) + np.tile(np.array([[[bos_id]]], dtype="int64"), + (batch_size, beam_size, 1))) finished_flags = layers.zeros_like(finished_scores) ### initialize inputs and states of transformer decoder ### @@ -644,11 +651,13 @@ class Transformer(Layer): enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size)) ## init states (caches) for transformer, need to be updated according to selected beam caches = [{ - "k": layers.fill_constant( + "k": + layers.fill_constant( shape=[batch_size * beam_size, self.n_head, 0, self.d_key], dtype=enc_output.dtype, value=0), - "v": layers.fill_constant( + "v": + layers.fill_constant( shape=[batch_size * beam_size, self.n_head, 0, self.d_value], dtype=enc_output.dtype, value=0), @@ -667,11 +676,11 @@ class Transformer(Layer): beam_size, batch_size, need_flat=True): - batch_idx = layers.range( - 0, batch_size, 1, dtype="int64") * beam_size + batch_idx = layers.range(0, batch_size, 1, + dtype="int64") * beam_size flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd - idx = layers.reshape( - layers.elementwise_add(beam_idx, batch_idx, 0), [-1]) + idx = layers.reshape(layers.elementwise_add(beam_idx, batch_idx, 0), + [-1]) new_flat_tensor = layers.gather(flat_tensor, idx) new_tensor_nd = layers.reshape( new_flat_tensor, @@ -714,8 +723,8 @@ class Transformer(Layer): curr_scores = log_probs / length_penalty flat_curr_scores = layers.reshape(curr_scores, [batch_size, -1]) - topk_scores, topk_ids = layers.topk( - flat_curr_scores, k=beam_size * 2) + topk_scores, topk_ids = layers.topk(flat_curr_scores, + k=beam_size * 2) topk_log_probs = topk_scores * length_penalty @@ -726,11 +735,13 @@ class Transformer(Layer): topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index, beam_size, batch_size) topk_seq = layers.concat( - [topk_seq, layers.reshape(topk_ids, topk_ids.shape + [1])], + [topk_seq, + layers.reshape(topk_ids, topk_ids.shape + [1])], axis=2) states = update_states(states, topk_beam_index, beam_size) - eos = layers.fill_constant( - shape=topk_ids.shape, dtype="int64", value=eos_id) + eos = layers.fill_constant(shape=topk_ids.shape, + dtype="int64", + value=eos_id) topk_finished = layers.cast(layers.equal(topk_ids, eos), "float32") #topk_seq: [batch_size, 2*beam_size, i+1] @@ -752,35 +763,37 @@ class Transformer(Layer): def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq, curr_scores, curr_finished): # finished scores - finished_seq = layers.concat( - [ - finished_seq, layers.fill_constant( - shape=[batch_size, beam_size, 1], - dtype="int64", - value=eos_id) - ], - axis=2) + finished_seq = layers.concat([ + finished_seq, + layers.fill_constant(shape=[batch_size, beam_size, 1], + dtype="int64", + value=eos_id) + ], + axis=2) # Set the scores of the unfinished seq in curr_seq to large negative # values curr_scores += (1. 
- curr_finished) * -inf # concatenating the sequences and scores along beam axis curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1) - curr_finished_scores = layers.concat( - [finished_scores, curr_scores], axis=1) - curr_finished_flags = layers.concat( - [finished_flags, curr_finished], axis=1) + curr_finished_scores = layers.concat([finished_scores, curr_scores], + axis=1) + curr_finished_flags = layers.concat([finished_flags, curr_finished], + axis=1) _, topk_indexes = layers.topk(curr_finished_scores, k=beam_size) finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes, beam_size * 3, batch_size) - finished_scores = gather_2d_by_gather( - curr_finished_scores, topk_indexes, beam_size * 3, batch_size) - finished_flags = gather_2d_by_gather( - curr_finished_flags, topk_indexes, beam_size * 3, batch_size) + finished_scores = gather_2d_by_gather(curr_finished_scores, + topk_indexes, beam_size * 3, + batch_size) + finished_flags = gather_2d_by_gather(curr_finished_flags, + topk_indexes, beam_size * 3, + batch_size) return finished_seq, finished_scores, finished_flags for i in range(max_len): - trg_pos = layers.fill_constant( - shape=trg_word.shape, dtype="int64", value=i) + trg_pos = layers.fill_constant(shape=trg_word.shape, + dtype="int64", + value=i) logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias, enc_output, caches) topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk( @@ -809,6 +822,36 @@ class Transformer(Layer): eos_id=1, beam_size=4, max_len=256): + if beam_size == 1: + return self._greedy_search(src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=bos_id, + eos_id=eos_id, + max_len=max_len) + else: + return self._beam_search(src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=bos_id, + eos_id=eos_id, + beam_size=beam_size, + max_len=max_len) + + def _beam_search(self, + src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=0, + eos_id=1, + beam_size=4, + max_len=256): def expand_to_beam_size(tensor, beam_size): tensor = layers.reshape(tensor, [tensor.shape[0], 1] + tensor.shape[1:]) @@ -817,29 +860,34 @@ class Transformer(Layer): return layers.expand(tensor, tile_dims) def merge_batch_beams(tensor): - return layers.reshape( - tensor, [tensor.shape[0] * tensor.shape[1]] + tensor.shape[2:]) + return layers.reshape(tensor, [tensor.shape[0] * tensor.shape[1]] + + tensor.shape[2:]) def split_batch_beams(tensor): - return fluid.layers.reshape( - tensor, shape=[-1, beam_size] + list(tensor.shape[1:])) + return layers.reshape(tensor, + shape=[-1, beam_size] + + list(tensor.shape[1:])) def mask_probs(probs, finished, noend_mask_tensor): # TODO: use where_op finished = layers.cast(finished, dtype=probs.dtype) - probs = layers.elementwise_mul( - layers.expand( - layers.unsqueeze(finished, [2]), - [1, 1, self.trg_vocab_size]), - noend_mask_tensor, - axis=-1) - layers.elementwise_mul( - probs, (finished - 1), axis=0) + probs = layers.elementwise_mul(layers.expand( + layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]), + noend_mask_tensor, + axis=-1) - layers.elementwise_mul( + probs, (finished - 1), axis=0) return probs def gather(x, indices, batch_pos): - topk_coordinates = fluid.layers.stack([batch_pos, indices], axis=2) + topk_coordinates = layers.stack([batch_pos, indices], axis=2) return layers.gather_nd(x, topk_coordinates) + def update_states(func, caches): + for cache in caches: # no need to update static_kv + cache["k"] 
= func(cache["k"]) + cache["v"] = func(cache["v"]) + return caches + # run encoder enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias) @@ -847,33 +895,32 @@ class Transformer(Layer): inf = float(1. * 1e7) batch_size = enc_output.shape[0] max_len = (enc_output.shape[1] + 20) if max_len is None else max_len - vocab_size_tensor = layers.fill_constant( - shape=[1], dtype="int64", value=self.trg_vocab_size) + vocab_size_tensor = layers.fill_constant(shape=[1], + dtype="int64", + value=self.trg_vocab_size) end_token_tensor = to_variable( - np.full( - [batch_size, beam_size], eos_id, dtype="int64")) + np.full([batch_size, beam_size], eos_id, dtype="int64")) noend_array = [-inf] * self.trg_vocab_size noend_array[eos_id] = 0 noend_mask_tensor = to_variable(np.array(noend_array, dtype="float32")) batch_pos = layers.expand( layers.unsqueeze( - to_variable(np.arange( - 0, batch_size, 1, dtype="int64")), [1]), [1, beam_size]) + to_variable(np.arange(0, batch_size, 1, dtype="int64")), [1]), + [1, beam_size]) predict_ids = [] parent_ids = [] ### initialize states of beam search ### log_probs = to_variable( - np.array( - [[0.] + [-inf] * (beam_size - 1)] * batch_size, - dtype="float32")) - finished = to_variable( - np.full( - [batch_size, beam_size], 0, dtype="bool")) + np.array([[0.] + [-inf] * (beam_size - 1)] * batch_size, + dtype="float32")) + finished = to_variable(np.full([batch_size, beam_size], 0, + dtype="bool")) ### initialize inputs and states of transformer decoder ### ## init inputs for decoder, shaped `[batch_size*beam_size, ...]` - trg_word = layers.fill_constant( - shape=[batch_size * beam_size, 1], dtype="int64", value=bos_id) + trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1], + dtype="int64", + value=bos_id) trg_pos = layers.zeros_like(trg_word) trg_src_attn_bias = merge_batch_beams( expand_to_beam_size(trg_src_attn_bias, beam_size)) @@ -881,42 +928,45 @@ class Transformer(Layer): expand_to_beam_size(enc_output, beam_size)) ## init states (caches) for transformer, need to be updated according to selected beam caches = [{ - "k": layers.fill_constant( + "k": + layers.fill_constant( shape=[batch_size * beam_size, self.n_head, 0, self.d_key], dtype=enc_output.dtype, value=0), - "v": layers.fill_constant( + "v": + layers.fill_constant( shape=[batch_size * beam_size, self.n_head, 0, self.d_value], dtype=enc_output.dtype, value=0), } for i in range(self.n_layer)] for i in range(max_len): - trg_pos = layers.fill_constant( - shape=trg_word.shape, dtype="int64", value=i) - caches = map_structure( # can not be reshaped since the 0 size + trg_pos = layers.fill_constant(shape=trg_word.shape, + dtype="int64", + value=i) + caches = update_states( # can not be reshaped since the 0 size lambda x: x if i == 0 else merge_batch_beams(x), caches) logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias, enc_output, caches) - caches = map_structure(split_batch_beams, caches) + caches = update_states(split_batch_beams, caches) step_log_probs = split_batch_beams( - fluid.layers.log(fluid.layers.softmax(logits))) + layers.log(layers.softmax(logits))) step_log_probs = mask_probs(step_log_probs, finished, noend_mask_tensor) - log_probs = layers.elementwise_add( - x=step_log_probs, y=log_probs, axis=0) + log_probs = layers.elementwise_add(x=step_log_probs, + y=log_probs, + axis=0) log_probs = layers.reshape(log_probs, [-1, beam_size * self.trg_vocab_size]) scores = log_probs - topk_scores, topk_indices = fluid.layers.topk( - input=scores, k=beam_size) - beam_indices = 
fluid.layers.elementwise_floordiv(topk_indices, - vocab_size_tensor) - token_indices = fluid.layers.elementwise_mod(topk_indices, - vocab_size_tensor) + topk_scores, topk_indices = layers.topk(input=scores, k=beam_size) + beam_indices = layers.elementwise_floordiv(topk_indices, + vocab_size_tensor) + token_indices = layers.elementwise_mod(topk_indices, + vocab_size_tensor) # update states - caches = map_structure(lambda x: gather(x, beam_indices, batch_pos), + caches = update_states(lambda x: gather(x, beam_indices, batch_pos), caches) log_probs = gather(log_probs, topk_indices, batch_pos) finished = gather(finished, beam_indices, batch_pos) @@ -937,3 +987,75 @@ class Transformer(Layer): finished_scores = topk_scores return finished_seq, finished_scores + + def _greedy_search(self, + src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=0, + eos_id=1, + max_len=256): + # run encoder + enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias) + + # constant number + batch_size = enc_output.shape[0] + max_len = (enc_output.shape[1] + 20) if max_len is None else max_len + end_token_tensor = layers.fill_constant(shape=[batch_size, 1], + dtype="int64", + value=eos_id) + + predict_ids = [] + log_probs = layers.fill_constant(shape=[batch_size, 1], + dtype="float32", + value=0) + trg_word = layers.fill_constant(shape=[batch_size, 1], + dtype="int64", + value=bos_id) + finished = layers.fill_constant(shape=[batch_size, 1], + dtype="bool", + value=0) + + ## init states (caches) for transformer + caches = [{ + "k": + layers.fill_constant(shape=[batch_size, self.n_head, 0, self.d_key], + dtype=enc_output.dtype, + value=0), + "v": + layers.fill_constant( + shape=[batch_size, self.n_head, 0, self.d_value], + dtype=enc_output.dtype, + value=0), + } for i in range(self.n_layer)] + + for i in range(max_len): + trg_pos = layers.fill_constant(shape=trg_word.shape, + dtype="int64", + value=i) + logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias, + enc_output, caches) + step_log_probs = layers.log(layers.softmax(logits)) + log_probs = layers.elementwise_add(x=step_log_probs, + y=log_probs, + axis=0) + scores = log_probs + topk_scores, topk_indices = layers.topk(input=scores, k=1) + + finished = layers.logical_or( + finished, layers.equal(topk_indices, end_token_tensor)) + trg_word = topk_indices + log_probs = topk_scores + + predict_ids.append(topk_indices) + + if layers.reduce_all(finished).numpy(): + break + + predict_ids = layers.stack(predict_ids, axis=0) + finished_seq = layers.transpose(predict_ids, [1, 2, 0]) + finished_scores = topk_scores + + return finished_seq, finished_scores diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py index 39392c82..75cbb277 100644 --- a/dygraph/transformer/train.py +++ b/dygraph/transformer/train.py @@ -24,7 +24,6 @@ import paddle.fluid as fluid from utils.configure import PDConfig from utils.check import check_gpu, check_version -from utils.load import load_dygraph # include task-specific libs import reader @@ -58,6 +57,25 @@ def do_train(args): max_length=args.max_length, n_head=args.n_head) batch_generator = processor.data_generator(phase="train") + if args.validation_file: + val_processor = reader.DataProcessor( + fpattern=args.validation_file, + src_vocab_fpath=args.src_vocab_fpath, + trg_vocab_fpath=args.trg_vocab_fpath, + token_delimiter=args.token_delimiter, + use_token_batch=args.use_token_batch, + batch_size=args.batch_size, + device_count=trainer_count, + pool_size=args.pool_size, + 
sort_type=args.sort_type, + shuffle=False, + shuffle_batch=False, + start_mark=args.special_token[0], + end_mark=args.special_token[1], + unk_mark=args.special_token[2], + max_length=args.max_length, + n_head=args.n_head) + val_batch_generator = val_processor.data_generator(phase="train") if trainer_count > 1: # for multi-process gpu training batch_generator = fluid.contrib.reader.distributed_batch_reader( batch_generator) @@ -74,6 +92,9 @@ def do_train(args): # define data loader train_loader = fluid.io.DataLoader.from_generator(capacity=10) train_loader.set_batch_generator(batch_generator, places=place) + if args.validation_file: + val_loader = fluid.io.DataLoader.from_generator(capacity=10) + val_loader.set_batch_generator(val_batch_generator, places=place) # define model transformer = Transformer( @@ -98,13 +119,13 @@ def do_train(args): ## init from some checkpoint, to resume the previous training if args.init_from_checkpoint: - model_dict, opt_dict = load_dygraph( + model_dict, opt_dict = fluid.load_dygraph( os.path.join(args.init_from_checkpoint, "transformer")) transformer.load_dict(model_dict) optimizer.set_dict(opt_dict) ## init from some pretrain models, to better solve the current task if args.init_from_pretrain_model: - model_dict, _ = load_dygraph( + model_dict, _ = fluid.load_dygraph( os.path.join(args.init_from_pretrain_model, "transformer")) transformer.load_dict(model_dict) @@ -174,13 +195,38 @@ def do_train(args): total_avg_cost - loss_normalizer, np.exp([min(total_avg_cost, 100)]), args.print_step / (time.time() - avg_batch_time))) - ce_ppl.append(np.exp([min(total_avg_cost, 100)])) avg_batch_time = time.time() - if step_idx % args.save_step == 0 and step_idx != 0 and ( - trainer_count == 1 - or fluid.dygraph.parallel.Env().dev_id == 0): - if args.save_model: + + if step_idx % args.save_step == 0 and step_idx != 0: + # validation + if args.validation_file: + transformer.eval() + total_sum_cost = 0 + total_token_num = 0 + for input_data in val_loader(): + (src_word, src_pos, src_slf_attn_bias, trg_word, + trg_pos, trg_slf_attn_bias, trg_src_attn_bias, + lbl_word, lbl_weight) = input_data + logits = transformer(src_word, src_pos, + src_slf_attn_bias, trg_word, + trg_pos, trg_slf_attn_bias, + trg_src_attn_bias) + sum_cost, avg_cost, token_num = criterion( + logits, lbl_word, lbl_weight) + total_sum_cost += sum_cost.numpy() + total_token_num += token_num.numpy() + total_avg_cost = total_sum_cost / total_token_num + logging.info("validation, step_idx: %d, avg loss: %f, " + "normalized loss: %f, ppl: %f" % + (step_idx, total_avg_cost, + total_avg_cost - loss_normalizer, + np.exp([min(total_avg_cost, 100)]))) + transformer.train() + + if args.save_model and ( + trainer_count == 1 + or fluid.dygraph.parallel.Env().dev_id == 0): model_dir = os.path.join(args.save_model, "step_" + str(step_idx)) if not os.path.exists(model_dir): diff --git a/dygraph/transformer/transformer.yaml b/dygraph/transformer/transformer.yaml index 76151f20..15d9e783 100644 --- a/dygraph/transformer/transformer.yaml +++ b/dygraph/transformer/transformer.yaml @@ -19,6 +19,8 @@ inference_model_dir: "infer_model" random_seed: None # The pattern to match training data files. training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de" +# The pattern to match validation data files. +validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de" # The pattern to match test data files. 
 predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
 # The file to output the translation results of predict_file to.
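
Note on the "cross-attention cache" part of this patch: in `MultiHeadAttention._prepare_qkv`, decoder self-attention keys/values are appended step by step under `cache["k"]`/`cache["v"]`, while the encoder-derived keys/values are projected once and reused under `cache["static_k"]`/`cache["static_v"]`. The standalone NumPy sketch below only illustrates the resulting cache shapes; it is not part of the patch, it reuses one weight pair for both attention types for brevity, and all names in it are hypothetical.

```python
# Illustrative sketch (not part of the patch): the caching scheme of
# MultiHeadAttention._prepare_qkv reduced to plain NumPy shape bookkeeping.
import numpy as np

d_model, n_head, d_key = 8, 2, 4
rng = np.random.RandomState(0)

# hypothetical projection weights standing in for k_fc / v_fc
w_k = rng.randn(d_model, n_head * d_key)
w_v = rng.randn(d_model, n_head * d_key)

def split_heads(x):
    # [batch, length, n_head * d_key] -> [batch, n_head, length, d_key]
    b, l, _ = x.shape
    return x.reshape(b, l, n_head, d_key).transpose(0, 2, 1, 3)

enc_output = rng.randn(1, 5, d_model)  # encoder output, computed once

cache = {}  # per-layer cache, mirroring the dicts built in the patch
for step in range(3):
    dec_input = rng.randn(1, 1, d_model)  # current target token's hidden state

    # decoder self-attention: cached k/v grow by concatenation each step
    k_step = split_heads(dec_input @ w_k)
    v_step = split_heads(dec_input @ w_v)
    cache["k"] = k_step if "k" not in cache else np.concatenate([cache["k"], k_step], axis=2)
    cache["v"] = v_step if "v" not in cache else np.concatenate([cache["v"], v_step], axis=2)

    # encoder-decoder (cross) attention: project encoder output only once
    if "static_k" not in cache:
        cache["static_k"] = split_heads(enc_output @ w_k)
        cache["static_v"] = split_heads(enc_output @ w_v)

print(cache["k"].shape)         # (1, 2, 3, 4): grows with the decoded length
print(cache["static_k"].shape)  # (1, 2, 5, 4): fixed, reused at every step
```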
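
Note on the greedy-search path: when `beam_size == 1`, `Transformer._greedy_search` decodes by repeatedly feeding back the argmax token and stopping once every sequence has emitted the end token. The sketch below shows only that control flow with a stand-in `decode_step` in place of the real decoder; it is not part of the patch and all names are hypothetical.

```python
# Illustrative sketch (not part of the patch): greedy decoding control flow,
# with a fake single-sequence decode_step instead of the cached Transformer decoder.
import numpy as np

bos_id, eos_id, vocab_size, max_len = 0, 1, 10, 8
rng = np.random.RandomState(0)

def decode_step(token_id, step):
    # Stand-in for self.decoder(...): returns fake logits over the vocabulary.
    return rng.randn(vocab_size)

trg_word = bos_id
predict_ids = []
for step in range(max_len):
    logits = decode_step(trg_word, step)
    next_id = int(np.argmax(logits))  # beam_size == 1: just take the argmax
    predict_ids.append(next_id)
    trg_word = next_id                # feed the chosen token back in
    if next_id == eos_id:             # stop once the end token is produced
        break

print(predict_ids)
```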