提交 bb45f859 编写于 作者: G gongweibao

add

上级 a9b22c5d
class TrainTaskConfig(object):
use_gpu = True
# the epoch number to train.
pass_num = 50
# the number of sequences contained in a mini-batch.
batch_size = 56
# the hyper parameters for Adam optimizer.
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.98
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 15000
# the flag indicating to use average loss or sum loss when training.
use_avg_cost = False
# the directory for saving trained models.
model_dir = "trained_models"
class InferTaskConfig(object):
use_gpu = True
# the number of examples in one run for sequence generation.
batch_size = 1
# the parameters for beam search.
beam_size = 5
max_length = 30
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False
# the directory for loading the trained model.
model_path = "trained_models/pass_20.infer.model"
class ModelHyperParams(object):
# This model directly uses paddle.dataset.wmt16 in which <bos>, <eos> and
# <unk> token has alreay been added. As for the <pad> token, any token
# included in dict can be used to pad, since the paddings' loss will be
# masked out and make no effect on parameter gradients.
# size of source word dictionary.
src_vocab_size = 30001
# size of target word dictionay
trg_vocab_size = 30001
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences.
# The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# token for position encoding.
max_length = 150
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rate used by all dropout layers.
dropout = 0.1
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# Names of all data layers in encoder listed in order.
encoder_input_data_names = (
"src_word",
"src_pos",
"src_slf_attn_bias",
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
# Names of all data layers in decoder listed in order.
decoder_input_data_names = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape",
"enc_output", )
# Names of label related data layers listed in order.
label_data_names = (
"lbl_word",
"lbl_weight", )
import numpy as np
import paddle
import paddle.fluid as fluid
import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from config import InferTaskConfig, ModelHyperParams, \
encoder_input_data_names, decoder_input_data_names
from train import pad_batch_data
import nist_data_provider
def translate_batch(exe,
src_words,
encoder,
enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
d_model,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
"""
# Prepare data for encoder and run the encoder.
enc_in_data = pad_batch_data(
src_words,
src_pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=False)
# Append the data shape input to reshape the output of embedding layer.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
]
# Append the shape inputs to reshape before and after softmax in encoder
# self attention.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1]], dtype="int32"), np.array(
enc_in_data[2].shape, dtype="int32")
]
enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0]
# Beam Search.
# To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[] for i in range(batch_size)]
next_ids = [[] for i in range(batch_size)]
# Use beam_inst_map to map beam idx to the instance idx in batch, since the
# size of feeded batch is changing.
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(range(batch_size))
}
# Use active_beams to recode the alive.
active_beams = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
seqs = []
for i in range(n_best):
k = i
seq = []
for j in range(len(prev_branchs) - 1, -1, -1):
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
# Add the <bos>, since next_ids don't include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq)
return seqs
def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
"""
Initialize the input data for decoder.
"""
trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[2].shape[
-1], enc_in_data[2], 1
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_max_len, trg_max_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
[1, beam_size, 1, trg_max_len, 1]).reshape([
-1, src_slf_attn_bias.shape[1], trg_max_len,
src_slf_attn_bias.shape[-1]
])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[batch_size * beam_size, trg_max_len, d_model], dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = np.tile(
enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
[-1, enc_output.shape[-2], enc_output.shape[-1]])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
"""
Update the input data of decoder mainly by slicing from the previous
input data and dropping the finished instance beams.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output = dec_in_data
trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
trg_words = np.array(
[
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
for beam_idx in active_beams
],
dtype="int64")
trg_words = trg_words.reshape([-1, 1])
trg_pos = np.array(
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len,
trg_cur_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_cur_len, trg_cur_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[len(active_beams) * beam_size, trg_cur_len, d_model],
dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
for i in range(max_length):
predict_all = exe.run(decoder,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(
predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
[:, -1, :])
predict_all = (predict_all + scores[active_beams].reshape(
[len(beam_inst_map) * beam_size, -1])).reshape(
[len(beam_inst_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for beam_idx in range(batch_size):
if not beam_inst_map.has_key(beam_idx):
continue
inst_idx = beam_inst_map[beam_idx]
predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::
-1]]
top_scores = predict[top_scores_ids]
scores[beam_idx] = top_scores
prev_branchs[beam_idx].append(top_scores_ids /
predict_all.shape[-1])
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
if len(active_beams) == 0:
break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
beam_inst_map)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(active_beams)
}
# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]
return seqs, scores[:, :n_best].tolist()
def main():
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
encoder_program = fluid.Program()
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)
# Load model parameters of encoder and decoder separately from the saved
# transformer model.
encoder_var_names = []
for op in encoder_program.block(0).ops:
encoder_var_names += op.input_arg_names
encoder_param_names = filter(
lambda var_name: isinstance(encoder_program.block(0).var(var_name),
fluid.framework.Parameter),
encoder_var_names)
encoder_params = map(encoder_program.block(0).var, encoder_param_names)
decoder_var_names = []
for op in decoder_program.block(0).ops:
decoder_var_names += op.input_arg_names
decoder_param_names = filter(
lambda var_name: isinstance(decoder_program.block(0).var(var_name),
fluid.framework.Parameter),
decoder_var_names)
decoder_params = map(decoder_program.block(0).var, decoder_param_names)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
# This is used here to set dropout to the test mode.
encoder_program = fluid.io.get_inference_program(
target_vars=[enc_output], main_program=encoder_program)
decoder_program = fluid.io.get_inference_program(
target_vars=[predict], main_program=decoder_program)
'''test_data = paddle.batch(
paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)'''
test_data = paddle.batch(
nist_data_provider.test("nist06n.test", ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
trg_idx2word = nist_data_provider.get_dict(
"data",
dict_size=ModelHyperParams.trg_vocab_size,
lang="en",
reverse=True)
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe,
[item[0] for item in data],
encoder_program,
encoder_input_data_names,
[enc_output.name],
decoder_program,
decoder_input_data_names,
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)):
# Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
if __name__ == "__main__":
main()
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from config import TrainTaskConfig, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
position_enc = np.array([[
pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
for j in range(d_pos_vec)
] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
return position_enc.astype("float32")
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
they will not considered in attention weights.
"""
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
param_attr=fluid.initializer.Xavier(
uniform=False,
fan_in=d_model * d_key,
fan_out=n_head * d_key),
bias_attr=False,
num_flatten_dims=2)
k = layers.fc(input=keys,
size=d_key * n_head,
param_attr=fluid.initializer.Xavier(
uniform=False,
fan_in=d_model * d_key,
fan_out=n_head * d_key),
bias_attr=False,
num_flatten_dims=2)
v = layers.fc(input=values,
size=d_value * n_head,
param_attr=fluid.initializer.Xavier(
uniform=False,
fan_in=d_model * d_value,
fan_out=n_head * d_value),
bias_attr=False,
num_flatten_dims=2)
return q, k, v
def __split_heads(x, n_head):
"""
Reshape the last dimension of inpunt tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
if n_head == 1:
return x
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, -1, n_head, hidden_size // n_head])
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of inpunt tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))
def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_model**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
weights = layers.reshape(
x=layers.elementwise_add(
x=product, y=attn_bias) if attn_bias else product,
shape=[-1, product.shape[-1]],
actual_shape=pre_softmax_shape,
act="softmax")
weights = layers.reshape(
x=weights, shape=product.shape, actual_shape=post_softmax_shape)
if dropout_rate:
weights = layers.dropout(
weights, dropout_prob=dropout_rate, is_test=False)
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
param_attr=fluid.initializer.Xavier(uniform=False),
bias_attr=False,
num_flatten_dims=2)
return proj_out
def positionwise_feed_forward(x, d_inner_hid, d_hid):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
param_attr=fluid.initializer.Uniform(
low=-(d_hid**-0.5), high=(d_hid**-0.5)),
act="relu")
out = layers.fc(input=hidden,
size=d_hid,
num_flatten_dims=2,
param_attr=fluid.initializer.Uniform(
low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layers.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out, dropout_prob=dropout_rate, is_test=False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def prepare_encoder(src_word,
src_pos,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0.,
src_data_shape=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
param_attr=fluid.initializer.Normal(0., 1.))
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc
enc_input = layers.reshape(
x=enc_input,
shape=[-1, src_max_len, src_emb_dim],
actual_shape=src_data_shape)
return layers.dropout(
enc_input, dropout_prob=dropout_rate,
is_test=False) if dropout_rate else enc_input
prepare_encoder = partial(
prepare_encoder, pos_enc_param_name=pos_enc_param_names[0])
prepare_decoder = partial(
prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output = multi_head_attention(
enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model,
n_head, dropout_rate, pre_softmax_shape, post_softmax_shape)
attn_output = post_process_layer(enc_input, attn_output, "dan",
dropout_rate)
ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
pre_softmax_shape=None,
post_softmax_shape=None):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
pre_softmax_shape,
post_softmax_shape, )
enc_input = enc_output
return enc_output
def decoder_layer(dec_input,
enc_output,
slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output = multi_head_attention(
dec_input,
dec_input,
dec_input,
slf_attn_bias,
d_key,
d_value,
d_model,
n_head,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
enc_attn_output = multi_head_attention(
slf_attn_output,
enc_output,
enc_output,
dec_enc_attn_bias,
d_key,
d_value,
d_model,
n_head,
dropout_rate,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
ffd_output = positionwise_feed_forward(
enc_attn_output,
d_inner_hid,
d_model, )
dec_output = post_process_layer(
enc_attn_output,
ffd_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
return dec_output
def decoder(dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.,
slf_attn_pre_softmax_shape=None,
slf_attn_post_softmax_shape=None,
src_attn_pre_softmax_shape=None,
src_attn_post_softmax_shape=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
dec_output = decoder_layer(
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
dec_input = dec_output
return dec_output
def make_inputs(input_data_names,
n_head,
d_model,
max_length,
is_pos,
slf_attn_bias_flag,
src_attn_bias_flag,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True):
"""
Define the input data layers for the transformer model.
"""
input_layers = []
batch_size = 1 # Only for the infer-shape in compile time.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
# The actual data shape of word is:
# [batch_size * max_len_in_batch, 1]
word = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
input_layers += [word]
# This is used for position data or label weight.
# The actual data shape of pos is:
# [batch_size * max_len_in_batch, 1]
pos = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size * max_length, 1],
dtype="int64" if is_pos else "float32",
append_batch_size=False)
input_layers += [pos]
if slf_attn_bias_flag:
# This input is used to remove attention weights on paddings for the
# encoder and to remove attention weights on subsequent words for the
# decoder.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, n_head, max_len_in_batch, max_len_in_batch]
slf_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
input_layers += [slf_attn_bias]
if src_attn_bias_flag:
# This input is used to remove attention weights on paddings. It's used
# in encoder-decoder attention.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, n_head, trg_max_len_in_batch, src_max_len_in_batch]
src_attn_bias = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
input_layers += [src_attn_bias]
if data_shape_flag:
# This input is used to reshape the output of embedding layer.
data_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[3],
dtype="int32",
append_batch_size=False)
input_layers += [data_shape]
if slf_attn_shape_flag:
# This shape input is used to reshape before softmax in self attention.
slf_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [slf_attn_pre_softmax_shape]
# This shape input is used to reshape after softmax in self attention.
slf_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[4],
dtype="int32",
append_batch_size=False)
input_layers += [slf_attn_post_softmax_shape]
if src_attn_shape_flag:
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
src_attn_pre_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[2],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_pre_softmax_shape]
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
src_attn_post_softmax_shape = layers.data(
name=input_data_names[len(input_layers)],
shape=[4],
dtype="int32",
append_batch_size=False)
input_layers += [src_attn_post_softmax_shape]
if enc_output_flag:
# This input is used in independent decoder program for inference.
# The actual data shape of slf_attn_bias_flag is:
# [batch_size, max_len_in_batch, d_model]
enc_output = layers.data(
name=input_data_names[len(input_layers)],
shape=[batch_size, max_length, d_model],
dtype="float32",
append_batch_size=False)
input_layers += [enc_output]
return input_layers
def transformer(
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, ):
enc_inputs = make_inputs(
encoder_input_data_names,
n_head,
d_model,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=False)
enc_output = wrap_encoder(
src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
enc_inputs, )
dec_inputs = make_inputs(
decoder_input_data_names,
n_head,
d_model,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True)
predict = wrap_decoder(
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
dec_inputs,
enc_output, )
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
gold, weights = make_inputs(
label_data_names,
n_head,
d_model,
max_length,
is_pos=False,
slf_attn_bias_flag=False,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=False,
slf_attn_shape_flag=False,
src_attn_shape_flag=False)
cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict, token_num
def wrap_encoder(src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
enc_inputs=None):
"""
The wrapper assembles together all needed layers for the encoder.
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
make_inputs(
encoder_input_data_names,
n_head,
d_model,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=False,
enc_output_flag=False,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=False)
else:
src_word, src_pos, src_slf_attn_bias, src_data_shape, \
slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape = \
enc_inputs
enc_input = prepare_encoder(
src_word,
src_pos,
src_vocab_size,
d_model,
max_length,
dropout_rate,
src_data_shape, )
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape, )
return enc_output
def wrap_decoder(trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
dec_inputs=None,
enc_output=None):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape, enc_output = make_inputs(
decoder_input_data_names,
n_head,
d_model,
max_length,
is_pos=True,
slf_attn_bias_flag=True,
src_attn_bias_flag=True,
enc_output_flag=True,
data_shape_flag=True,
slf_attn_shape_flag=True,
src_attn_shape_flag=True)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, slf_attn_pre_softmax_shape, \
slf_attn_post_softmax_shape, src_attn_pre_softmax_shape, \
src_attn_post_softmax_shape = dec_inputs
dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
max_length,
dropout_rate,
trg_data_shape, )
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
slf_attn_pre_softmax_shape,
slf_attn_post_softmax_shape,
src_attn_pre_softmax_shape,
src_attn_post_softmax_shape, )
# Return logits for training and probs for inference.
predict = layers.reshape(
x=layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax" if dec_inputs is None else None)
return predict
import os
from functools import partial
from collections import defaultdict
__all__ = [
"train",
"test",
"get_dict",
]
DATA_HOME = "/root/data/nist06n/"
START_MARK = "_GO"
END_MARK = "_EOS"
UNK_MARK = "_UNK"
def __build_dict(data_file, dict_size, save_path, lang="cn"):
word_dict = defaultdict(int)
data_files = [os.path.join(data_file, f) for f in os.listdir(data_file)
] if os.path.isdir(data_file) else [data_file]
for file_path in data_files:
with open(file_path, mode="r") as f:
for line in f.readlines():
line_split = line.strip().split("\t")
if len(line_split) != 2: continue
sen = line_split[0] if lang == "cn" else line_split[1]
for w in sen.split():
word_dict[w] += 1
with open(save_path, "w") as fout:
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK,
UNK_MARK))
for idx, word in enumerate(
sorted(word_dict.iteritems(), key=lambda x: x[1],
reverse=True)):
if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0]))
def __load_dict(data_file, dict_size, lang, dict_file=None, reverse=False):
dict_file = "%s_%d.dict" % (lang,
dict_size) if dict_file is None else dict_file
dict_path = os.path.join(DATA_HOME, dict_file)
data_path = os.path.join(DATA_HOME, data_file)
if not os.path.exists(dict_path) or (len(open(dict_path, "r").readlines())
!= dict_size):
__build_dict(data_path, dict_size, dict_path, lang)
word_dict = {}
with open(dict_path, "r") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip()
else:
word_dict[line.strip()] = idx
return word_dict
def reader_creator(data_file,
src_lang,
src_dict_size,
trg_dict_size,
src_dict_file=None,
trg_dict_file=None,
len_filter=200):
def reader():
src_dict = __load_dict(data_file, src_dict_size, "cn", src_dict_file)
trg_dict = __load_dict(data_file, trg_dict_size, "en", trg_dict_file)
# the indice for start mark, end mark, and unk are the same in source
# language and target language. Here uses the source language
# dictionary to determine their indices.
start_id = src_dict[START_MARK]
end_id = src_dict[END_MARK]
unk_id = src_dict[UNK_MARK]
src_col = 0 if src_lang == "cn" else 1
trg_col = 1 - src_col
data_path = os.path.join(DATA_HOME, data_file)
data_files = [
os.path.join(data_path, f) for f in os.listdir(data_path)
] if os.path.isdir(data_path) else [data_path]
for file_path in data_files:
with open(file_path, mode="r") as f:
for line in f.readlines():
line_split = line.strip().split("\t")
if len(line_split) != 2:
continue
src_words = line_split[src_col].split()
src_ids = [start_id
] + [src_dict.get(w, unk_id)
for w in src_words] + [end_id]
trg_words = line_split[trg_col].split()
trg_ids = [trg_dict.get(w, unk_id) for w in trg_words]
trg_ids_next = trg_ids + [end_id]
trg_ids = [start_id] + trg_ids
if len(src_words) + len(trg_words) < len_filter:
yield src_ids, trg_ids, trg_ids_next
return reader
def train(data_file,
src_dict_size,
trg_dict_size,
src_lang="cn",
src_dict_file=None,
trg_dict_file=None,
len_filter=200):
return reader_creator(data_file, src_lang, src_dict_size, trg_dict_size,
src_dict_file, trg_dict_file, len_filter)
test = partial(train, len_filter=100000)
def get_dict(data_file, dict_size, lang, dict_file=None, reverse=False):
dict_file = "%s_%d.dict" % (lang,
dict_size) if dict_file is None else dict_file
dict_path = os.path.join(DATA_HOME, dict_file)
assert os.path.exists(dict_path), "Word dictionary does not exist. "
return __load_dict(data_file, dict_size, lang, dict_file, reverse)
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
class LearningRateScheduler(object):
"""
Wrapper for learning rate scheduling as described in the Transformer paper.
LearningRateScheduler adapts the learning rate externally and the adapted
learning rate will be feeded into the main_program as input data.
"""
def __init__(self,
d_model,
warmup_steps,
place,
learning_rate=0.001,
current_steps=0,
name="learning_rate"):
self.current_steps = current_steps
self.warmup_steps = warmup_steps
self.d_model = d_model
self.learning_rate = layers.create_global_var(
name=name,
shape=[1],
value=float(learning_rate),
dtype="float32",
persistable=True)
self.place = place
def update_learning_rate(self, data_input):
self.current_steps += 1
lr_value = np.power(self.d_model, -0.5) * np.min([
np.power(self.current_steps, -0.5),
np.power(self.warmup_steps, -1.5) * self.current_steps
])
lr_tensor = fluid.LoDTensor()
lr_tensor.set(np.array([lr_value], dtype="float32"), self.place)
data_input[self.learning_rate.name] = lr_tensor
因为 它太大了无法显示 source diff 。你可以改为 查看blob
import os
import time
import numpy as np
import paddle
import paddle.fluid as fluid
from model import transformer, position_encoding_init
from optim import LearningRateScheduler
from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
import nist_data_provider
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array(
[[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
range(1, len(inst) + 1) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
[-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
return return_list if len(return_list) > 1 else return_list[0]
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
src_slf_attn_bias.shape, dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
lbl_word, lbl_weight = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False)
input_dict = dict(
zip(input_data_names, [
src_word, src_pos, src_slf_attn_bias, src_data_shape,
src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
trg_data_shape, trg_slf_attn_pre_softmax_shape,
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape,
trg_src_attn_post_softmax_shape, lbl_word, lbl_weight
]))
return input_dict
def main():
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps, place,
TrainTaskConfig.learning_rate)
optimizer = fluid.optimizer.Adam(
learning_rate=lr_scheduler.learning_rate,
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
optimizer.minimize(avg_cost if TrainTaskConfig.use_avg_cost else sum_cost)
train_data = paddle.batch(
paddle.reader.shuffle(
nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
buf_size=100000),
batch_size=TrainTaskConfig.batch_size)
# Program to do validation.
'''test_program = fluid.default_main_program().clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([avg_cost])
val_data = paddle.batch(
nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=TrainTaskConfig.batch_size)'''
def test(exe):
test_total_cost = 0
test_total_token = 0
for batch_id, data in enumerate(val_data()):
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
test_sum_cost, test_token_num = exe.run(
test_program,
feed=data_input,
fetch_list=[sum_cost, token_num],
use_program_cache=True)
test_total_cost += test_sum_cost
test_total_token += test_token_num
test_avg_cost = test_total_cost / test_total_token
test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl
# Initialize the parameters.
exe.run(fluid.framework.default_startup_program())
for pos_enc_param_name in pos_enc_param_names:
pos_enc_param = fluid.global_scope().find_var(
pos_enc_param_name).get_tensor()
pos_enc_param.set(
position_encoding_init(ModelHyperParams.max_length + 1,
ModelHyperParams.d_model), place)
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
fetch_list=[sum_cost, avg_cost],
use_program_cache=True)
sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1])
print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
(pass_id, batch_id, sum_cost_val, avg_cost_val,
np.exp([min(avg_cost_val[0], 100)])))
# Validate and save the model for inference.
#val_avg_cost, val_ppl = test(exe)
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
print("pass_id = " + str(pass_id) + " time_consumed = " +
str(time_consumed))
#print("epoch: %d, val avg loss: %f, val ppl: %f, "
# "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"),
encoder_input_data_names + decoder_input_data_names[:-1],
[predict], exe)
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册