Unverified commit 64eb26ae, authored by Guo Sheng, committed by GitHub

Add validation for dygraph Transformer. (#4628)

Add cross-attention cache for dygraph Transformer.
Add greedy search for dygraph Transformer.
Parent 2746e74b
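The cross-attention cache added in this commit avoids recomputing the encoder-side key/value projections at every decoding step. Below is a minimal, framework-free sketch of the idea; the `project_k`/`project_v` callables, the dictionary layout, and the `[batch, time, dim]` shapes are illustrative assumptions, not the repository API. The actual change lives in `MultiHeadAttention._prepare_qkv` further down.

```python
import numpy as np

def cached_kv(cache, project_k, project_v, keys, values, static_kv):
    # Cross-attention K/V depend only on the (fixed) encoder output, so they are
    # projected once on the first step and reused afterwards; self-attention K/V
    # grow by one position per generated token. Layout here is [batch, time, dim],
    # a simplification of the [batch, n_head, time, d_key] cache in the model.
    if static_kv:  # encoder-decoder (cross) attention
        if "static_k" not in cache:
            cache["static_k"], cache["static_v"] = project_k(keys), project_v(values)
        return cache["static_k"], cache["static_v"]
    # decoder self-attention: append this step's K/V to the cached history
    k, v = project_k(keys), project_v(values)
    cache["k"] = k if "k" not in cache else np.concatenate([cache["k"], k], axis=1)
    cache["v"] = v if "v" not in cache else np.concatenate([cache["v"], v], axis=1)
    return cache["k"], cache["v"]
```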
...@@ -76,6 +76,7 @@ python -u train.py \
    --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
    --special_token '<s>' '<e>' '<unk>' \
    --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+   --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
    --batch_size 4096
```
...@@ -91,6 +92,7 @@ python -u train.py \
    --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
    --special_token '<s>' '<e>' '<unk>' \
    --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+   --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
    --batch_size 4096 \
    --n_head 16 \
    --d_model 1024 \
...@@ -121,10 +123,11 @@ Paddle dygraph supports multi-process multi-GPU model training; training is launched as follows
```sh
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \
    --epoch 30 \
-   --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
-   --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+   --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+   --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
    --special_token '<s>' '<e>' '<unk>' \
-   --training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+   --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+   --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
    --batch_size 4096 \
    --print_step 100 \
    --use_cuda True \
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
"""
TrainTaskConfig
"""
# the number of epochs to train.
pass_num = 20
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size = 32
# the hyper parameters for the Adam optimizer.
# This static learning_rate will be multiplied by the LearningRateScheduler
# derived learning rate to get the final learning rate.
learning_rate = 2.0
beta1 = 0.9
beta2 = 0.997
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
class InferTaskConfig(object):
# the number of examples in one run for sequence generation.
batch_size = 4
# the parameters for beam search.
beam_size = 4
alpha = 0.6
# max decoded length, should be less than ModelHyperParams.max_length
max_out_len = 30
class ModelHyperParams(object):
"""
ModelHyperParams
"""
# The following five vocabulary-related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionary.
trg_vocab_size = 10000
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences deciding the size of position encoding table.
max_length = 50
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of heads used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = False
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, 1), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, 1), "int64"],
# This input is used to mask out the loss of padding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states update)
"init_idx": [(batch_size, ), "int32"],
}
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table", )
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
"trg_word",
# "init_score",
# "init_idx",
"trg_src_attn_bias", )
def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except Exception: # for file path
pass
setattr(g_cfg, key, value)
break
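For reference, a small usage sketch of `merge_cfg_from_list`: keys and values alternate (typically coming from command-line overrides), each key is routed to the first config class that defines an attribute of that name, and numeric or boolean strings are `eval`'d while file paths fall through as plain strings. The override values below are made up for illustration.

```python
if __name__ == "__main__":
    # hypothetical overrides, e.g. parsed from "--max_length 256 --weight_sharing True"
    overrides = ["max_length", "256", "weight_sharing", "True", "warmup_steps", "4000"]
    merge_cfg_from_list(overrides, [TrainTaskConfig, ModelHyperParams])
    print(ModelHyperParams.max_length,      # 256
          ModelHyperParams.weight_sharing,  # True
          TrainTaskConfig.warmup_steps)     # 4000
```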
...@@ -18,12 +18,9 @@ import numpy as np
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
-from paddle.fluid.layers.utils import map_structure
 from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
 from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
-from config import word_emb_param_names, pos_enc_param_names

 def position_encoding_init(n_position, d_pos_vec):
     """
...@@ -34,10 +31,10 @@ def position_encoding_init(n_position, d_pos_vec):
    num_timescales = channels // 2
    log_timescale_increment = (np.log(float(1e4) / float(1)) /
                               (num_timescales - 1))
    inv_timescales = np.exp(
        np.arange(num_timescales)) * -log_timescale_increment
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(
        inv_timescales, 0)
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
    position_enc = signal
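A self-contained NumPy version of the sinusoidal table computed in the hunk above may make the timescale math easier to follow; the function head and the `position`/`channels` setup are truncated in this hunk, so their reconstruction here is an assumption.

```python
import numpy as np

def sinusoid_position_encoding(n_position, d_pos_vec):
    # timescales are geometrically spaced between 1 and 1e4, as in the code above
    position = np.arange(n_position, dtype="float32")
    num_timescales = d_pos_vec // 2
    log_timescale_increment = np.log(1e4) / (num_timescales - 1)
    inv_timescales = np.exp(np.arange(num_timescales) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    # pad one zero column when d_pos_vec is odd so the table is [n_position, d_pos_vec]
    return np.pad(signal, [[0, 0], [0, d_pos_vec % 2]], "constant").astype("float32")
```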
...@@ -48,7 +45,6 @@ class NoamDecay(LearningRateDecay): ...@@ -48,7 +45,6 @@ class NoamDecay(LearningRateDecay):
""" """
learning rate scheduler learning rate scheduler
""" """
def __init__(self, def __init__(self,
d_model, d_model,
warmup_steps, warmup_steps,
...@@ -73,7 +69,6 @@ class PrePostProcessLayer(Layer): ...@@ -73,7 +69,6 @@ class PrePostProcessLayer(Layer):
""" """
PrePostProcessLayer PrePostProcessLayer
""" """
def __init__(self, process_cmd, d_model, dropout_rate): def __init__(self, process_cmd, d_model, dropout_rate):
super(PrePostProcessLayer, self).__init__() super(PrePostProcessLayer, self).__init__()
self.process_cmd = process_cmd self.process_cmd = process_cmd
...@@ -84,8 +79,8 @@ class PrePostProcessLayer(Layer): ...@@ -84,8 +79,8 @@ class PrePostProcessLayer(Layer):
elif cmd == "n": # add layer normalization elif cmd == "n": # add layer normalization
self.functors.append( self.functors.append(
self.add_sublayer( self.add_sublayer(
"layer_norm_%d" % len( "layer_norm_%d" %
self.sublayers(include_sublayers=False)), len(self.sublayers(include_sublayers=False)),
LayerNorm( LayerNorm(
normalized_shape=d_model, normalized_shape=d_model,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
...@@ -93,9 +88,9 @@ class PrePostProcessLayer(Layer): ...@@ -93,9 +88,9 @@ class PrePostProcessLayer(Layer):
bias_attr=fluid.ParamAttr( bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.))))) initializer=fluid.initializer.Constant(0.)))))
elif cmd == "d": # add dropout elif cmd == "d": # add dropout
if dropout_rate: self.functors.append(lambda x: layers.dropout(
self.functors.append(lambda x: layers.dropout( x, dropout_prob=dropout_rate, is_test=False)
x, dropout_prob=dropout_rate, is_test=False)) if dropout_rate else x)
def forward(self, x, residual=None): def forward(self, x, residual=None):
for i, cmd in enumerate(self.process_cmd): for i, cmd in enumerate(self.process_cmd):
...@@ -110,7 +105,6 @@ class MultiHeadAttention(Layer):
    """
    Multi-Head Attention
    """
    def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
...@@ -118,49 +112,73 @@ class MultiHeadAttention(Layer):
        self.d_value = d_value
        self.d_model = d_model
        self.dropout_rate = dropout_rate
        self.q_fc = Linear(input_dim=d_model,
                           output_dim=d_key * n_head,
                           bias_attr=False)
        self.k_fc = Linear(input_dim=d_model,
                           output_dim=d_key * n_head,
                           bias_attr=False)
        self.v_fc = Linear(input_dim=d_model,
                           output_dim=d_value * n_head,
                           bias_attr=False)
        self.proj_fc = Linear(input_dim=d_value * n_head,
                              output_dim=d_model,
                              bias_attr=False)

    def _prepare_qkv(self, queries, keys, values, cache=None):
        if keys is None:  # self-attention
            keys, values = queries, queries
            static_kv = False
        else:  # cross-attention
            static_kv = True

        q = self.q_fc(queries)
        q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
        q = layers.transpose(x=q, perm=[0, 2, 1, 3])

        if cache is not None and static_kv and "static_k" in cache:
            # encoder-decoder attention in inference, K/V already cached
            k = cache["static_k"]
            v = cache["static_v"]
        else:
            k = self.k_fc(keys)
            v = self.v_fc(values)
            k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
            k = layers.transpose(x=k, perm=[0, 2, 1, 3])
            v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
            v = layers.transpose(x=v, perm=[0, 2, 1, 3])

        if cache is not None:
            if static_kv and "static_k" not in cache:
                # encoder-decoder attention in inference, K/V not cached yet
                cache["static_k"], cache["static_v"] = k, v
            elif not static_kv:
                # decoder self-attention in inference
                cache_k, cache_v = cache["k"], cache["v"]
                k = layers.concat([cache_k, k], axis=2)
                v = layers.concat([cache_v, v], axis=2)
                cache["k"], cache["v"] = k, v

        return q, k, v

    def forward(self, queries, keys, values, attn_bias, cache=None):
        # compute q, k, v
        q, k, v = self._prepare_qkv(queries, keys, values, cache)

        # scaled dot-product attention
        product = layers.matmul(x=q,
                                y=k,
                                transpose_y=True,
                                alpha=self.d_model**-0.5)
        if attn_bias is not None:
            product += attn_bias
        weights = layers.softmax(product)
        if self.dropout_rate:
            weights = layers.dropout(weights,
                                     dropout_prob=self.dropout_rate,
                                     is_test=False)

        out = layers.matmul(weights, v)

        # combine heads
        out = layers.transpose(out, perm=[0, 2, 1, 3])
...@@ -175,7 +193,6 @@ class FFN(Layer): ...@@ -175,7 +193,6 @@ class FFN(Layer):
""" """
Feed-Forward Network Feed-Forward Network
""" """
def __init__(self, d_inner_hid, d_model, dropout_rate): def __init__(self, d_inner_hid, d_model, dropout_rate):
super(FFN, self).__init__() super(FFN, self).__init__()
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
...@@ -185,8 +202,9 @@ class FFN(Layer): ...@@ -185,8 +202,9 @@ class FFN(Layer):
def forward(self, x): def forward(self, x):
hidden = self.fc1(x) hidden = self.fc1(x)
if self.dropout_rate: if self.dropout_rate:
hidden = layers.dropout( hidden = layers.dropout(hidden,
hidden, dropout_prob=self.dropout_rate, is_test=False) dropout_prob=self.dropout_rate,
is_test=False)
out = self.fc2(hidden) out = self.fc2(hidden)
return out return out
...@@ -195,7 +213,6 @@ class EncoderLayer(Layer): ...@@ -195,7 +213,6 @@ class EncoderLayer(Layer):
""" """
EncoderLayer EncoderLayer
""" """
def __init__(self, def __init__(self,
n_head, n_head,
d_key, d_key,
...@@ -224,8 +241,8 @@ class EncoderLayer(Layer): ...@@ -224,8 +241,8 @@ class EncoderLayer(Layer):
prepostprocess_dropout) prepostprocess_dropout)
def forward(self, enc_input, attn_bias): def forward(self, enc_input, attn_bias):
attn_output = self.self_attn( attn_output = self.self_attn(self.preprocesser1(enc_input), None, None,
self.preprocesser1(enc_input), None, None, attn_bias) attn_bias)
attn_output = self.postprocesser1(attn_output, enc_input) attn_output = self.postprocesser1(attn_output, enc_input)
ffn_output = self.ffn(self.preprocesser2(attn_output)) ffn_output = self.ffn(self.preprocesser2(attn_output))
...@@ -237,7 +254,6 @@ class Encoder(Layer): ...@@ -237,7 +254,6 @@ class Encoder(Layer):
""" """
encoder encoder
""" """
def __init__(self, def __init__(self,
n_layer, n_layer,
n_head, n_head,
...@@ -277,7 +293,6 @@ class Embedder(Layer): ...@@ -277,7 +293,6 @@ class Embedder(Layer):
""" """
Word Embedding + Position Encoding Word Embedding + Position Encoding
""" """
def __init__(self, vocab_size, emb_dim, bos_idx=0): def __init__(self, vocab_size, emb_dim, bos_idx=0):
super(Embedder, self).__init__() super(Embedder, self).__init__()
...@@ -296,7 +311,6 @@ class WrapEncoder(Layer): ...@@ -296,7 +311,6 @@ class WrapEncoder(Layer):
""" """
embedder + encoder embedder + encoder
""" """
def __init__(self, src_vocab_size, max_length, n_layer, n_head, d_key, def __init__(self, src_vocab_size, max_length, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd, attention_dropout, relu_dropout, preprocess_cmd,
...@@ -324,9 +338,9 @@ class WrapEncoder(Layer): ...@@ -324,9 +338,9 @@ class WrapEncoder(Layer):
pos_enc = self.pos_encoder(src_pos) pos_enc = self.pos_encoder(src_pos)
pos_enc.stop_gradient = True pos_enc.stop_gradient = True
emb = word_emb + pos_enc emb = word_emb + pos_enc
enc_input = layers.dropout( enc_input = layers.dropout(emb,
emb, dropout_prob=self.emb_dropout, dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb is_test=False) if self.emb_dropout else emb
enc_output = self.encoder(enc_input, src_slf_attn_bias) enc_output = self.encoder(enc_input, src_slf_attn_bias)
return enc_output return enc_output
...@@ -336,7 +350,6 @@ class DecoderLayer(Layer):
    """
    decoder
    """
    def __init__(self,
                 n_head,
                 d_key,
...@@ -376,13 +389,13 @@ class DecoderLayer(Layer):
                self_attn_bias,
                cross_attn_bias,
                cache=None):
        self_attn_output = self.self_attn(self.preprocesser1(dec_input), None,
                                          None, self_attn_bias, cache)
        self_attn_output = self.postprocesser1(self_attn_output, dec_input)

        cross_attn_output = self.cross_attn(
            self.preprocesser2(self_attn_output), enc_output, enc_output,
            cross_attn_bias, cache)
        cross_attn_output = self.postprocesser2(cross_attn_output,
                                                self_attn_output)
...@@ -396,7 +409,6 @@ class Decoder(Layer):
    """
    decoder
    """
    def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
                 prepostprocess_dropout, attention_dropout, relu_dropout,
                 preprocess_cmd, postprocess_cmd):
...@@ -422,8 +434,8 @@ class Decoder(Layer):
                caches=None):
        for i, decoder_layer in enumerate(self.decoder_layers):
            dec_output = decoder_layer(dec_input, enc_output, self_attn_bias,
                                       cross_attn_bias,
                                       None if caches is None else caches[i])
            dec_input = dec_output

        return self.processer(dec_output)
...@@ -433,7 +445,6 @@ class WrapDecoder(Layer): ...@@ -433,7 +445,6 @@ class WrapDecoder(Layer):
""" """
embedder + decoder embedder + decoder
""" """
def __init__(self, trg_vocab_size, max_length, n_layer, n_head, d_key, def __init__(self, trg_vocab_size, max_length, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd, attention_dropout, relu_dropout, preprocess_cmd,
...@@ -461,8 +472,9 @@ class WrapDecoder(Layer): ...@@ -461,8 +472,9 @@ class WrapDecoder(Layer):
word_embedder.weight, word_embedder.weight,
transpose_y=True) transpose_y=True)
else: else:
self.linear = Linear( self.linear = Linear(input_dim=d_model,
input_dim=d_model, output_dim=trg_vocab_size, bias_attr=False) output_dim=trg_vocab_size,
bias_attr=False)
def forward(self, def forward(self,
trg_word, trg_word,
...@@ -476,14 +488,15 @@ class WrapDecoder(Layer): ...@@ -476,14 +488,15 @@ class WrapDecoder(Layer):
pos_enc = self.pos_encoder(trg_pos) pos_enc = self.pos_encoder(trg_pos)
pos_enc.stop_gradient = True pos_enc.stop_gradient = True
emb = word_emb + pos_enc emb = word_emb + pos_enc
dec_input = layers.dropout( dec_input = layers.dropout(emb,
emb, dropout_prob=self.emb_dropout, dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb is_test=False) if self.emb_dropout else emb
dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias, dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias,
trg_src_attn_bias, caches) trg_src_attn_bias, caches)
dec_output = layers.reshape( dec_output = layers.reshape(
dec_output, dec_output,
shape=[-1, dec_output.shape[-1]], ) shape=[-1, dec_output.shape[-1]],
)
logits = self.linear(dec_output) logits = self.linear(dec_output)
return logits return logits
...@@ -494,10 +507,9 @@ class CrossEntropyCriterion(object): ...@@ -494,10 +507,9 @@ class CrossEntropyCriterion(object):
def __call__(self, predict, label, weights): def __call__(self, predict, label, weights):
if self.label_smooth_eps: if self.label_smooth_eps:
label_out = layers.label_smooth( label_out = layers.label_smooth(label=layers.one_hot(
label=layers.one_hot( input=label, depth=predict.shape[-1]),
input=label, depth=predict.shape[-1]), epsilon=self.label_smooth_eps)
epsilon=self.label_smooth_eps)
cost = layers.softmax_with_cross_entropy( cost = layers.softmax_with_cross_entropy(
logits=predict, logits=predict,
...@@ -515,7 +527,6 @@ class Transformer(Layer): ...@@ -515,7 +527,6 @@ class Transformer(Layer):
""" """
model model
""" """
def __init__(self, def __init__(self,
src_vocab_size, src_vocab_size,
trg_vocab_size, trg_vocab_size,
...@@ -535,25 +546,29 @@ class Transformer(Layer): ...@@ -535,25 +546,29 @@ class Transformer(Layer):
bos_id=0, bos_id=0,
eos_id=1): eos_id=1):
super(Transformer, self).__init__() super(Transformer, self).__init__()
src_word_embedder = Embedder( src_word_embedder = Embedder(vocab_size=src_vocab_size,
vocab_size=src_vocab_size, emb_dim=d_model, bos_idx=bos_id) emb_dim=d_model,
self.encoder = WrapEncoder( bos_idx=bos_id)
src_vocab_size, max_length, n_layer, n_head, d_key, d_value, self.encoder = WrapEncoder(src_vocab_size, max_length, n_layer, n_head,
d_model, d_inner_hid, prepostprocess_dropout, attention_dropout, d_key, d_value, d_model, d_inner_hid,
relu_dropout, preprocess_cmd, postprocess_cmd, src_word_embedder) prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd, src_word_embedder)
if weight_sharing: if weight_sharing:
assert src_vocab_size == trg_vocab_size, ( assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing." "Vocabularies in source and target should be same for weight sharing."
) )
trg_word_embedder = src_word_embedder trg_word_embedder = src_word_embedder
else: else:
trg_word_embedder = Embedder( trg_word_embedder = Embedder(vocab_size=trg_vocab_size,
vocab_size=trg_vocab_size, emb_dim=d_model, bos_idx=bos_id) emb_dim=d_model,
self.decoder = WrapDecoder( bos_idx=bos_id)
trg_vocab_size, max_length, n_layer, n_head, d_key, d_value, self.decoder = WrapDecoder(trg_vocab_size, max_length, n_layer, n_head,
d_model, d_inner_hid, prepostprocess_dropout, attention_dropout, d_key, d_value, d_model, d_inner_hid,
relu_dropout, preprocess_cmd, postprocess_cmd, weight_sharing, prepostprocess_dropout, attention_dropout,
trg_word_embedder) relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing,
trg_word_embedder)
self.trg_vocab_size = trg_vocab_size self.trg_vocab_size = trg_vocab_size
self.n_layer = n_layer self.n_layer = n_layer
...@@ -583,18 +598,14 @@ class Transformer(Layer):
        Beam search with two queues, the alive and the finished, each with a
        beam-size capacity. It includes `grow_topk`, `grow_alive` and
        `grow_finish` as steps.
        1. `grow_topk` selects the top `2*beam_size` candidates to avoid all
           getting EOS.
        2. `grow_alive` selects the top `beam_size` non-EOS candidates as the
           inputs of the next decoding step.
        3. `grow_finish` compares the already finished candidates in the
           finished queue with the newly finished candidates from `grow_topk`,
           and selects the top `beam_size` finished candidates.
        """
def expand_to_beam_size(tensor, beam_size): def expand_to_beam_size(tensor, beam_size):
tensor = layers.reshape(tensor, tensor = layers.reshape(tensor,
[tensor.shape[0], 1] + tensor.shape[1:]) [tensor.shape[0], 1] + tensor.shape[1:])
...@@ -616,23 +627,19 @@ class Transformer(Layer): ...@@ -616,23 +627,19 @@ class Transformer(Layer):
### initialize states of beam search ### ### initialize states of beam search ###
## init for the alive ## ## init for the alive ##
initial_log_probs = to_variable( initial_log_probs = to_variable(
np.array( np.array([[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
[[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1]) alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
alive_seq = to_variable( alive_seq = to_variable(
np.tile( np.tile(np.array([[[bos_id]]], dtype="int64"),
np.array( (batch_size, beam_size, 1)))
[[[bos_id]]], dtype="int64"), (batch_size, beam_size, 1)))
## init for the finished ## ## init for the finished ##
finished_scores = to_variable( finished_scores = to_variable(
np.array( np.array([[-inf] * beam_size], dtype="float32"))
[[-inf] * beam_size], dtype="float32"))
finished_scores = layers.expand(finished_scores, [batch_size, 1]) finished_scores = layers.expand(finished_scores, [batch_size, 1])
finished_seq = to_variable( finished_seq = to_variable(
np.tile( np.tile(np.array([[[bos_id]]], dtype="int64"),
np.array( (batch_size, beam_size, 1)))
[[[bos_id]]], dtype="int64"), (batch_size, beam_size, 1)))
finished_flags = layers.zeros_like(finished_scores) finished_flags = layers.zeros_like(finished_scores)
### initialize inputs and states of transformer decoder ### ### initialize inputs and states of transformer decoder ###
...@@ -644,11 +651,13 @@ class Transformer(Layer): ...@@ -644,11 +651,13 @@ class Transformer(Layer):
enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size)) enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size))
## init states (caches) for transformer, need to be updated according to selected beam ## init states (caches) for transformer, need to be updated according to selected beam
caches = [{ caches = [{
"k": layers.fill_constant( "k":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_key], shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype, dtype=enc_output.dtype,
value=0), value=0),
"v": layers.fill_constant( "v":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_value], shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype, dtype=enc_output.dtype,
value=0), value=0),
...@@ -667,11 +676,11 @@ class Transformer(Layer): ...@@ -667,11 +676,11 @@ class Transformer(Layer):
beam_size, beam_size,
batch_size, batch_size,
need_flat=True): need_flat=True):
batch_idx = layers.range( batch_idx = layers.range(0, batch_size, 1,
0, batch_size, 1, dtype="int64") * beam_size dtype="int64") * beam_size
flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd
idx = layers.reshape( idx = layers.reshape(layers.elementwise_add(beam_idx, batch_idx, 0),
layers.elementwise_add(beam_idx, batch_idx, 0), [-1]) [-1])
new_flat_tensor = layers.gather(flat_tensor, idx) new_flat_tensor = layers.gather(flat_tensor, idx)
new_tensor_nd = layers.reshape( new_tensor_nd = layers.reshape(
new_flat_tensor, new_flat_tensor,
...@@ -714,8 +723,8 @@ class Transformer(Layer): ...@@ -714,8 +723,8 @@ class Transformer(Layer):
curr_scores = log_probs / length_penalty curr_scores = log_probs / length_penalty
flat_curr_scores = layers.reshape(curr_scores, [batch_size, -1]) flat_curr_scores = layers.reshape(curr_scores, [batch_size, -1])
topk_scores, topk_ids = layers.topk( topk_scores, topk_ids = layers.topk(flat_curr_scores,
flat_curr_scores, k=beam_size * 2) k=beam_size * 2)
topk_log_probs = topk_scores * length_penalty topk_log_probs = topk_scores * length_penalty
...@@ -726,11 +735,13 @@ class Transformer(Layer): ...@@ -726,11 +735,13 @@ class Transformer(Layer):
topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index, topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index,
beam_size, batch_size) beam_size, batch_size)
topk_seq = layers.concat( topk_seq = layers.concat(
[topk_seq, layers.reshape(topk_ids, topk_ids.shape + [1])], [topk_seq,
layers.reshape(topk_ids, topk_ids.shape + [1])],
axis=2) axis=2)
states = update_states(states, topk_beam_index, beam_size) states = update_states(states, topk_beam_index, beam_size)
eos = layers.fill_constant( eos = layers.fill_constant(shape=topk_ids.shape,
shape=topk_ids.shape, dtype="int64", value=eos_id) dtype="int64",
value=eos_id)
topk_finished = layers.cast(layers.equal(topk_ids, eos), "float32") topk_finished = layers.cast(layers.equal(topk_ids, eos), "float32")
#topk_seq: [batch_size, 2*beam_size, i+1] #topk_seq: [batch_size, 2*beam_size, i+1]
...@@ -752,35 +763,37 @@ class Transformer(Layer): ...@@ -752,35 +763,37 @@ class Transformer(Layer):
def grow_finished(finished_seq, finished_scores, finished_flags, def grow_finished(finished_seq, finished_scores, finished_flags,
curr_seq, curr_scores, curr_finished): curr_seq, curr_scores, curr_finished):
# finished scores # finished scores
finished_seq = layers.concat( finished_seq = layers.concat([
[ finished_seq,
finished_seq, layers.fill_constant( layers.fill_constant(shape=[batch_size, beam_size, 1],
shape=[batch_size, beam_size, 1], dtype="int64",
dtype="int64", value=eos_id)
value=eos_id) ],
], axis=2)
axis=2)
# Set the scores of the unfinished seq in curr_seq to large negative # Set the scores of the unfinished seq in curr_seq to large negative
# values # values
curr_scores += (1. - curr_finished) * -inf curr_scores += (1. - curr_finished) * -inf
# concatenating the sequences and scores along beam axis # concatenating the sequences and scores along beam axis
curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1) curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1)
curr_finished_scores = layers.concat( curr_finished_scores = layers.concat([finished_scores, curr_scores],
[finished_scores, curr_scores], axis=1) axis=1)
curr_finished_flags = layers.concat( curr_finished_flags = layers.concat([finished_flags, curr_finished],
[finished_flags, curr_finished], axis=1) axis=1)
_, topk_indexes = layers.topk(curr_finished_scores, k=beam_size) _, topk_indexes = layers.topk(curr_finished_scores, k=beam_size)
finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes, finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes,
beam_size * 3, batch_size) beam_size * 3, batch_size)
finished_scores = gather_2d_by_gather( finished_scores = gather_2d_by_gather(curr_finished_scores,
curr_finished_scores, topk_indexes, beam_size * 3, batch_size) topk_indexes, beam_size * 3,
finished_flags = gather_2d_by_gather( batch_size)
curr_finished_flags, topk_indexes, beam_size * 3, batch_size) finished_flags = gather_2d_by_gather(curr_finished_flags,
topk_indexes, beam_size * 3,
batch_size)
return finished_seq, finished_scores, finished_flags return finished_seq, finished_scores, finished_flags
for i in range(max_len): for i in range(max_len):
trg_pos = layers.fill_constant( trg_pos = layers.fill_constant(shape=trg_word.shape,
shape=trg_word.shape, dtype="int64", value=i) dtype="int64",
value=i)
logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias, logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
enc_output, caches) enc_output, caches)
topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk( topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
...@@ -809,6 +822,36 @@ class Transformer(Layer):
                    eos_id=1,
                    beam_size=4,
                    max_len=256):
if beam_size == 1:
return self._greedy_search(src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=bos_id,
eos_id=eos_id,
max_len=max_len)
else:
return self._beam_search(src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=bos_id,
eos_id=eos_id,
beam_size=beam_size,
max_len=max_len)
def _beam_search(self,
src_word,
src_pos,
src_slf_attn_bias,
trg_word,
trg_src_attn_bias,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=256):
def expand_to_beam_size(tensor, beam_size): def expand_to_beam_size(tensor, beam_size):
tensor = layers.reshape(tensor, tensor = layers.reshape(tensor,
[tensor.shape[0], 1] + tensor.shape[1:]) [tensor.shape[0], 1] + tensor.shape[1:])
...@@ -817,29 +860,34 @@ class Transformer(Layer): ...@@ -817,29 +860,34 @@ class Transformer(Layer):
return layers.expand(tensor, tile_dims) return layers.expand(tensor, tile_dims)
def merge_batch_beams(tensor): def merge_batch_beams(tensor):
return layers.reshape( return layers.reshape(tensor, [tensor.shape[0] * tensor.shape[1]] +
tensor, [tensor.shape[0] * tensor.shape[1]] + tensor.shape[2:]) tensor.shape[2:])
def split_batch_beams(tensor): def split_batch_beams(tensor):
return fluid.layers.reshape( return layers.reshape(tensor,
tensor, shape=[-1, beam_size] + list(tensor.shape[1:])) shape=[-1, beam_size] +
list(tensor.shape[1:]))
def mask_probs(probs, finished, noend_mask_tensor): def mask_probs(probs, finished, noend_mask_tensor):
# TODO: use where_op # TODO: use where_op
finished = layers.cast(finished, dtype=probs.dtype) finished = layers.cast(finished, dtype=probs.dtype)
probs = layers.elementwise_mul( probs = layers.elementwise_mul(layers.expand(
layers.expand( layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]),
layers.unsqueeze(finished, [2]), noend_mask_tensor,
[1, 1, self.trg_vocab_size]), axis=-1) - layers.elementwise_mul(
noend_mask_tensor, probs, (finished - 1), axis=0)
axis=-1) - layers.elementwise_mul(
probs, (finished - 1), axis=0)
return probs return probs
        def gather(x, indices, batch_pos):
            topk_coordinates = layers.stack([batch_pos, indices], axis=2)
            return layers.gather_nd(x, topk_coordinates)

        def update_states(func, caches):
            for cache in caches:  # no need to update static_kv
                cache["k"] = func(cache["k"])
                cache["v"] = func(cache["v"])
            return caches
# run encoder # run encoder
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias) enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
...@@ -847,33 +895,32 @@ class Transformer(Layer): ...@@ -847,33 +895,32 @@ class Transformer(Layer):
inf = float(1. * 1e7) inf = float(1. * 1e7)
batch_size = enc_output.shape[0] batch_size = enc_output.shape[0]
max_len = (enc_output.shape[1] + 20) if max_len is None else max_len max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
vocab_size_tensor = layers.fill_constant( vocab_size_tensor = layers.fill_constant(shape=[1],
shape=[1], dtype="int64", value=self.trg_vocab_size) dtype="int64",
value=self.trg_vocab_size)
end_token_tensor = to_variable( end_token_tensor = to_variable(
np.full( np.full([batch_size, beam_size], eos_id, dtype="int64"))
[batch_size, beam_size], eos_id, dtype="int64"))
noend_array = [-inf] * self.trg_vocab_size noend_array = [-inf] * self.trg_vocab_size
noend_array[eos_id] = 0 noend_array[eos_id] = 0
noend_mask_tensor = to_variable(np.array(noend_array, dtype="float32")) noend_mask_tensor = to_variable(np.array(noend_array, dtype="float32"))
batch_pos = layers.expand( batch_pos = layers.expand(
layers.unsqueeze( layers.unsqueeze(
to_variable(np.arange( to_variable(np.arange(0, batch_size, 1, dtype="int64")), [1]),
0, batch_size, 1, dtype="int64")), [1]), [1, beam_size]) [1, beam_size])
predict_ids = [] predict_ids = []
parent_ids = [] parent_ids = []
### initialize states of beam search ### ### initialize states of beam search ###
log_probs = to_variable( log_probs = to_variable(
np.array( np.array([[0.] + [-inf] * (beam_size - 1)] * batch_size,
[[0.] + [-inf] * (beam_size - 1)] * batch_size, dtype="float32"))
dtype="float32")) finished = to_variable(np.full([batch_size, beam_size], 0,
finished = to_variable( dtype="bool"))
np.full(
[batch_size, beam_size], 0, dtype="bool"))
### initialize inputs and states of transformer decoder ### ### initialize inputs and states of transformer decoder ###
## init inputs for decoder, shaped `[batch_size*beam_size, ...]` ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
trg_word = layers.fill_constant( trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1],
shape=[batch_size * beam_size, 1], dtype="int64", value=bos_id) dtype="int64",
value=bos_id)
trg_pos = layers.zeros_like(trg_word) trg_pos = layers.zeros_like(trg_word)
trg_src_attn_bias = merge_batch_beams( trg_src_attn_bias = merge_batch_beams(
expand_to_beam_size(trg_src_attn_bias, beam_size)) expand_to_beam_size(trg_src_attn_bias, beam_size))
...@@ -881,42 +928,45 @@ class Transformer(Layer): ...@@ -881,42 +928,45 @@ class Transformer(Layer):
expand_to_beam_size(enc_output, beam_size)) expand_to_beam_size(enc_output, beam_size))
## init states (caches) for transformer, need to be updated according to selected beam ## init states (caches) for transformer, need to be updated according to selected beam
caches = [{ caches = [{
"k": layers.fill_constant( "k":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_key], shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype, dtype=enc_output.dtype,
value=0), value=0),
"v": layers.fill_constant( "v":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_value], shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype, dtype=enc_output.dtype,
value=0), value=0),
} for i in range(self.n_layer)] } for i in range(self.n_layer)]
        for i in range(max_len):
            trg_pos = layers.fill_constant(shape=trg_word.shape,
                                           dtype="int64",
                                           value=i)
            caches = update_states(  # can not be reshaped since the 0 size
                lambda x: x if i == 0 else merge_batch_beams(x), caches)
            logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
                                  enc_output, caches)
            caches = update_states(split_batch_beams, caches)
            step_log_probs = split_batch_beams(
                layers.log(layers.softmax(logits)))
            step_log_probs = mask_probs(step_log_probs, finished,
                                        noend_mask_tensor)
            log_probs = layers.elementwise_add(x=step_log_probs,
                                               y=log_probs,
                                               axis=0)
            log_probs = layers.reshape(log_probs,
                                       [-1, beam_size * self.trg_vocab_size])
            scores = log_probs
            topk_scores, topk_indices = layers.topk(input=scores, k=beam_size)
            beam_indices = layers.elementwise_floordiv(topk_indices,
                                                       vocab_size_tensor)
            token_indices = layers.elementwise_mod(topk_indices,
                                                   vocab_size_tensor)

            # update states
            caches = update_states(lambda x: gather(x, beam_indices, batch_pos),
                                   caches)
            log_probs = gather(log_probs, topk_indices, batch_pos)
            finished = gather(finished, beam_indices, batch_pos)
...@@ -937,3 +987,75 @@ class Transformer(Layer):
        finished_scores = topk_scores
        return finished_seq, finished_scores
    def _greedy_search(self,
                       src_word,
                       src_pos,
                       src_slf_attn_bias,
                       trg_word,
                       trg_src_attn_bias,
                       bos_id=0,
                       eos_id=1,
                       max_len=256):
        # run encoder
        enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)

        # constant number
        batch_size = enc_output.shape[0]
        max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
        end_token_tensor = layers.fill_constant(shape=[batch_size, 1],
                                                dtype="int64",
                                                value=eos_id)

        predict_ids = []
        log_probs = layers.fill_constant(shape=[batch_size, 1],
                                         dtype="float32",
                                         value=0)
        trg_word = layers.fill_constant(shape=[batch_size, 1],
                                        dtype="int64",
                                        value=bos_id)
        finished = layers.fill_constant(shape=[batch_size, 1],
                                        dtype="bool",
                                        value=0)

        ## init states (caches) for transformer
        caches = [{
            "k":
            layers.fill_constant(shape=[batch_size, self.n_head, 0, self.d_key],
                                 dtype=enc_output.dtype,
                                 value=0),
            "v":
            layers.fill_constant(
                shape=[batch_size, self.n_head, 0, self.d_value],
                dtype=enc_output.dtype,
                value=0),
        } for i in range(self.n_layer)]

        for i in range(max_len):
            trg_pos = layers.fill_constant(shape=trg_word.shape,
                                           dtype="int64",
                                           value=i)
            logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
                                  enc_output, caches)
            step_log_probs = layers.log(layers.softmax(logits))
            log_probs = layers.elementwise_add(x=step_log_probs,
                                               y=log_probs,
                                               axis=0)
            scores = log_probs
            topk_scores, topk_indices = layers.topk(input=scores, k=1)

            finished = layers.logical_or(
                finished, layers.equal(topk_indices, end_token_tensor))
            trg_word = topk_indices
            log_probs = topk_scores

            predict_ids.append(topk_indices)
            if layers.reduce_all(finished).numpy():
                break

        predict_ids = layers.stack(predict_ids, axis=0)
        finished_seq = layers.transpose(predict_ids, [1, 2, 0])
        finished_scores = topk_scores

        return finished_seq, finished_scores
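With this change, `beam_search` dispatches to `_greedy_search` whenever `beam_size == 1`, so greedy decoding reuses the same incremental cache as beam search. The framework-free sketch below mirrors the loop structure of `_greedy_search` above; `step_logits_fn` stands in for the cached decoder call and is an assumption of this sketch, not part of the repository API.

```python
import numpy as np

def greedy_decode(step_logits_fn, batch_size, bos_id=0, eos_id=1, max_len=256):
    trg_word = np.full([batch_size, 1], bos_id, dtype="int64")
    log_probs = np.zeros([batch_size, 1], dtype="float32")
    finished = np.zeros([batch_size, 1], dtype=bool)
    predict_ids = []
    for step in range(max_len):
        logits = step_logits_fn(trg_word, step)           # [batch_size, vocab_size]
        probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)
        step_log_probs = np.log(probs) + log_probs        # accumulate sequence scores
        topk_indices = step_log_probs.argmax(axis=-1)[:, None]
        log_probs = np.take_along_axis(step_log_probs, topk_indices, axis=-1)
        finished = finished | (topk_indices == eos_id)    # mark sequences that hit <eos>
        trg_word = topk_indices                           # feed the argmax token back in
        predict_ids.append(topk_indices)
        if finished.all():                                # stop once every sequence ended
            break
    return np.concatenate(predict_ids, axis=1), log_probs
```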
...@@ -24,7 +24,6 @@ import paddle.fluid as fluid
 from utils.configure import PDConfig
 from utils.check import check_gpu, check_version
-from utils.load import load_dygraph

 # include task-specific libs
 import reader
...@@ -58,6 +57,25 @@ def do_train(args):
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if args.validation_file:
        val_processor = reader.DataProcessor(
            fpattern=args.validation_file,
            src_vocab_fpath=args.src_vocab_fpath,
            trg_vocab_fpath=args.trg_vocab_fpath,
            token_delimiter=args.token_delimiter,
            use_token_batch=args.use_token_batch,
            batch_size=args.batch_size,
            device_count=trainer_count,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=False,
            shuffle_batch=False,
            start_mark=args.special_token[0],
            end_mark=args.special_token[1],
            unk_mark=args.special_token[2],
            max_length=args.max_length,
            n_head=args.n_head)
        val_batch_generator = val_processor.data_generator(phase="train")
    if trainer_count > 1:  # for multi-process gpu training
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
...@@ -74,6 +92,9 @@ def do_train(args):
        # define data loader
        train_loader = fluid.io.DataLoader.from_generator(capacity=10)
        train_loader.set_batch_generator(batch_generator, places=place)
        if args.validation_file:
            val_loader = fluid.io.DataLoader.from_generator(capacity=10)
            val_loader.set_batch_generator(val_batch_generator, places=place)
        # define model
        transformer = Transformer(
...@@ -98,13 +119,13 @@ def do_train(args):
        ## init from some checkpoint, to resume the previous training
        if args.init_from_checkpoint:
            model_dict, opt_dict = fluid.load_dygraph(
                os.path.join(args.init_from_checkpoint, "transformer"))
            transformer.load_dict(model_dict)
            optimizer.set_dict(opt_dict)
        ## init from some pretrain models, to better solve the current task
        if args.init_from_pretrain_model:
            model_dict, _ = fluid.load_dygraph(
                os.path.join(args.init_from_pretrain_model, "transformer"))
            transformer.load_dict(model_dict)
...@@ -174,13 +195,38 @@ def do_train(args):
                        total_avg_cost - loss_normalizer,
                        np.exp([min(total_avg_cost, 100)]),
                        args.print_step / (time.time() - avg_batch_time)))
                    ce_ppl.append(np.exp([min(total_avg_cost, 100)]))
                    avg_batch_time = time.time()

                if step_idx % args.save_step == 0 and step_idx != 0:
                    # validation
                    if args.validation_file:
                        transformer.eval()
                        total_sum_cost = 0
                        total_token_num = 0
                        for input_data in val_loader():
                            (src_word, src_pos, src_slf_attn_bias, trg_word,
                             trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
                             lbl_word, lbl_weight) = input_data
                            logits = transformer(src_word, src_pos,
                                                 src_slf_attn_bias, trg_word,
                                                 trg_pos, trg_slf_attn_bias,
                                                 trg_src_attn_bias)
                            sum_cost, avg_cost, token_num = criterion(
                                logits, lbl_word, lbl_weight)
                            total_sum_cost += sum_cost.numpy()
                            total_token_num += token_num.numpy()
                        total_avg_cost = total_sum_cost / total_token_num
                        logging.info("validation, step_idx: %d, avg loss: %f, "
                                     "normalized loss: %f, ppl: %f" %
                                     (step_idx, total_avg_cost,
                                      total_avg_cost - loss_normalizer,
                                      np.exp([min(total_avg_cost, 100)])))
                        transformer.train()

                    if args.save_model and (
                            trainer_count == 1
                            or fluid.dygraph.parallel.Env().dev_id == 0):
                        model_dir = os.path.join(args.save_model,
                                                 "step_" + str(step_idx))
                        if not os.path.exists(model_dir):
......
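The validation pass above accumulates token-weighted cross-entropy sums and token counts over the whole validation set before reporting perplexity; a tiny sketch of that final reduction, with the same exponent clipping the training log uses:

```python
import numpy as np

def corpus_perplexity(sum_costs, token_nums):
    # average token-level cross entropy over all validation tokens, then exponentiate
    total_avg_cost = float(np.sum(sum_costs)) / float(np.sum(token_nums))
    return float(np.exp(min(total_avg_cost, 100.0)))
```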
...@@ -19,6 +19,8 @@ inference_model_dir: "infer_model"
 random_seed: None
 # The pattern to match training data files.
 training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
+# The pattern to match validation data files.
+validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de"
 # The pattern to match test data files.
 predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
 # The file to output the translation results of predict_file to.
......