Commit bb45f859 authored by G gongweibao

add

Parent a9b22c5d
class TrainTaskConfig(object):
use_gpu = True
# the number of epochs to train.
pass_num = 50
# the number of sequences contained in a mini-batch.
batch_size = 56
# the hyper parameters for Adam optimizer.
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.98
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 15000
# the flag indicating whether to use the average loss or the sum loss when training.
use_avg_cost = False
# the directory for saving trained models.
model_dir = "trained_models"
class InferTaskConfig(object):
use_gpu = True
# the number of examples in one run for sequence generation.
batch_size = 1
# the parameters for beam search.
beam_size = 5
max_length = 30
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False
# the directory for loading the trained model.
model_path = "trained_models/pass_20.infer.model"
class ModelHyperParams(object):
# This model directly uses paddle.dataset.wmt16, in which the <bos>, <eos> and
# <unk> tokens have already been added. As for the <pad> token, any token
# included in the dict can be used to pad, since the loss on paddings will be
# masked out and has no effect on parameter gradients.
# size of source word dictionary.
src_vocab_size = 30001
# size of target word dictionary.
trg_vocab_size = 30001
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences.
# The size of the position encoding table should be at least max_length + 1,
# since the sinusoid position encoding starts from 1 and 0 can be used as the
# padding token for position encoding.
max_length = 150
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of heads used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rate used by all dropout layers.
dropout = 0.1
# Names of the position encoding tables, which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# Names of all data layers in encoder listed in order.
encoder_input_data_names = (
"src_word",
"src_pos",
"src_slf_attn_bias",
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )
# Names of all data layers in decoder listed in order.
decoder_input_data_names = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape",
"enc_output", )
# Names of label related data layers listed in order.
label_data_names = (
"lbl_word",
"lbl_weight", )
import numpy as np
import paddle
import paddle.fluid as fluid
import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from config import InferTaskConfig, ModelHyperParams, \
encoder_input_data_names, decoder_input_data_names
from train import pad_batch_data
import nist_data_provider
def translate_batch(exe,
src_words,
encoder,
enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
d_model,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
"""
# Prepare data for encoder and run the encoder.
enc_in_data = pad_batch_data(
src_words,
src_pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=False)
# Append the data shape input to reshape the output of embedding layer.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1], d_model], dtype="int32")
]
# Append the shape inputs to reshape before and after softmax in encoder
# self attention.
enc_in_data = enc_in_data + [
np.array(
[-1, enc_in_data[2].shape[-1]], dtype="int32"), np.array(
enc_in_data[2].shape, dtype="int32")
]
enc_output = exe.run(encoder,
feed=dict(zip(enc_in_names, enc_in_data)),
fetch_list=enc_out_names)[0]
# Beam Search.
# To store the beam info.
scores = np.zeros((batch_size, beam_size), dtype="float32")
prev_branchs = [[] for i in range(batch_size)]
next_ids = [[] for i in range(batch_size)]
# Use beam_inst_map to map beam idx to the instance idx in the batch, since the
# size of the fed batch changes during decoding.
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(range(batch_size))
}
# Use active_beams to record the beams that are still alive.
active_beams = range(batch_size)
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
seqs = []
for i in range(n_best):
k = i
seq = []
for j in range(len(prev_branchs) - 1, -1, -1):
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
# Add the <bos>, since next_ids does not include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq)
return seqs
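# For illustration only (toy data, not from the model): with beam_size = 2 and two
# decoding steps, suppose prev_branchs = [[0, 0], [1, 0]] and next_ids = [[5, 6], [7, 8]].
# Backtracing the best candidate (i = 0) visits the last step first: step 1 emits
# next_ids[1][0] = 7 and points back to branch prev_branchs[1][0] = 1, step 0 then
# emits next_ids[0][1] = 6, so the reversed path is [6, 7] and the decoded sequence
# after prepending <bos> is [bos_idx, 6, 7].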
def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
"""
Initialize the input data for decoder.
"""
trg_words = np.array(
[[bos_idx]] * batch_size * beam_size, dtype="int64")
trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[2].shape[
-1], enc_in_data[2], 1
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len,
trg_max_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_max_len, trg_max_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(
src_slf_attn_bias[:, :, ::src_max_length, :][:, np.newaxis],
[1, beam_size, 1, trg_max_len, 1]).reshape([
-1, src_slf_attn_bias.shape[1], trg_max_len,
src_slf_attn_bias.shape[-1]
])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[batch_size * beam_size, trg_max_len, d_model], dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = np.tile(
enc_output[:, np.newaxis], [1, beam_size, 1, 1]).reshape(
[-1, enc_output.shape[-2], enc_output.shape[-1]])
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
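# Shape note (derived from the code above): after init_dec_in_data every beam
# candidate is treated as an independent instance, i.e. trg_words and trg_pos have
# shape [batch_size * beam_size, 1], trg_slf_attn_bias has shape
# [batch_size * beam_size, n_head, 1, 1], trg_src_attn_bias has shape
# [batch_size * beam_size, n_head, 1, src_max_length], and enc_output is tiled to
# [batch_size * beam_size, src_max_length, d_model].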
def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
"""
Update the decoder input data, mainly by slicing from the previous input
data and dropping the beams of finished instances.
"""
trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output = dec_in_data
trg_cur_len = trg_slf_attn_bias.shape[-1] + 1
trg_words = np.array(
[
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
for beam_idx in active_beams
],
dtype="int64")
trg_words = trg_words.reshape([-1, 1])
trg_pos = np.array(
[range(1, trg_cur_len + 1)] * len(active_beams) * beam_size,
dtype="int64").reshape([-1, 1])
active_beams = [beam_inst_map[beam_idx] for beam_idx in active_beams]
active_beams_indice = (
(np.array(active_beams) * beam_size)[:, np.newaxis] +
np.array(range(beam_size))[np.newaxis, :]).flatten()
# This is used to remove attention on subsequent words.
trg_slf_attn_bias = np.ones((len(active_beams) * beam_size, trg_cur_len,
trg_cur_len))
trg_slf_attn_bias = np.triu(trg_slf_attn_bias, 1).reshape(
[-1, 1, trg_cur_len, trg_cur_len])
trg_slf_attn_bias = (np.tile(trg_slf_attn_bias, [1, n_head, 1, 1]) *
[-1e9]).astype("float32")
# This is used to remove attention on the paddings of source sequences.
trg_src_attn_bias = np.tile(trg_src_attn_bias[
active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
[1, 1, trg_cur_len, 1])
# Append the shape input to reshape the output of embedding layer.
trg_data_shape = np.array(
[len(active_beams) * beam_size, trg_cur_len, d_model],
dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# decoder self attention.
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
# Append the shape inputs to reshape before and after softmax in
# encoder-decoder attention.
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
enc_output = enc_output[active_beams_indice, :, :]
return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \
trg_data_shape, trg_slf_attn_pre_softmax_shape, \
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape, \
trg_src_attn_post_softmax_shape, enc_output
dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
enc_output)
for i in range(max_length):
predict_all = exe.run(decoder,
feed=dict(zip(dec_in_names, dec_in_data)),
fetch_list=dec_out_names)[0]
predict_all = np.log(
predict_all.reshape([len(beam_inst_map) * beam_size, i + 1, -1])
[:, -1, :])
predict_all = (predict_all + scores[active_beams].reshape(
[len(beam_inst_map) * beam_size, -1])).reshape(
[len(beam_inst_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for beam_idx in range(batch_size):
if beam_idx not in beam_inst_map:
continue
inst_idx = beam_inst_map[beam_idx]
predict = (predict_all[inst_idx, :, :]
if i != 0 else predict_all[inst_idx, 0, :]).flatten()
top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::-1]]
top_scores = predict[top_scores_ids]
scores[beam_idx] = top_scores
prev_branchs[beam_idx].append(top_scores_ids /
predict_all.shape[-1])
next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
if next_ids[beam_idx][-1][0] != eos_idx:
active_beams.append(beam_idx)
if len(active_beams) == 0:
break
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams,
beam_inst_map)
beam_inst_map = {
beam_idx: inst_idx
for inst_idx, beam_idx in enumerate(active_beams)
}
# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]
return seqs, scores[:, :n_best].tolist()
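# For each instance in the batch, translate_batch returns the n_best decoded id
# sequences (each starting with <bos>) together with their accumulated
# log-probability scores; the caller post-processes and prints them below.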
def main():
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
encoder_program = fluid.Program()
with fluid.program_guard(main_program=encoder_program):
enc_output = encoder(
ModelHyperParams.src_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)
decoder_program = fluid.Program()
with fluid.program_guard(main_program=decoder_program):
predict = decoder(
ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout)
# Load model parameters of encoder and decoder separately from the saved
# transformer model.
encoder_var_names = []
for op in encoder_program.block(0).ops:
encoder_var_names += op.input_arg_names
encoder_param_names = filter(
lambda var_name: isinstance(encoder_program.block(0).var(var_name),
fluid.framework.Parameter),
encoder_var_names)
encoder_params = map(encoder_program.block(0).var, encoder_param_names)
decoder_var_names = []
for op in decoder_program.block(0).ops:
decoder_var_names += op.input_arg_names
decoder_param_names = filter(
lambda var_name: isinstance(decoder_program.block(0).var(var_name),
fluid.framework.Parameter),
decoder_var_names)
decoder_params = map(decoder_program.block(0).var, decoder_param_names)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)
# This is used here to switch dropout to test mode.
encoder_program = fluid.io.get_inference_program(
target_vars=[enc_output], main_program=encoder_program)
decoder_program = fluid.io.get_inference_program(
target_vars=[predict], main_program=decoder_program)
'''test_data = paddle.batch(
paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)'''
test_data = paddle.batch(
nist_data_provider.test("nist06n.test", ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
trg_idx2word = nist_data_provider.get_dict(
"data",
dict_size=ModelHyperParams.trg_vocab_size,
lang="en",
reverse=True)
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence: truncate at the first <eos>
and, by default, remove the <bos> and <eos> tokens.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)
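# For example, with bos_idx = 0 and eos_idx = 1, post_process_seq([0, 12, 7, 1, 9])
# first truncates at the first <eos> to [0, 12, 7, 1] and then filters out the
# special tokens, returning [12, 7] (output_bos and output_eos are False by default).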
for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe,
[item[0] for item in data],
encoder_program,
encoder_input_data_names,
[enc_output.name],
decoder_program,
decoder_input_data_names,
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.d_model,
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.eos_idx, # Use eos_idx to pad.
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)):
# Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
if __name__ == "__main__":
main()
import os
from functools import partial
from collections import defaultdict
__all__ = [
"train",
"test",
"get_dict",
]
DATA_HOME = "/root/data/nist06n/"
START_MARK = "_GO"
END_MARK = "_EOS"
UNK_MARK = "_UNK"
def __build_dict(data_file, dict_size, save_path, lang="cn"):
word_dict = defaultdict(int)
data_files = [os.path.join(data_file, f) for f in os.listdir(data_file)
] if os.path.isdir(data_file) else [data_file]
for file_path in data_files:
with open(file_path, mode="r") as f:
for line in f.readlines():
line_split = line.strip().split("\t")
if len(line_split) != 2: continue
sen = line_split[0] if lang == "cn" else line_split[1]
for w in sen.split():
word_dict[w] += 1
with open(save_path, "w") as fout:
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK,
UNK_MARK))
for idx, word in enumerate(
sorted(word_dict.iteritems(), key=lambda x: x[1],
reverse=True)):
if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0]))
def __load_dict(data_file, dict_size, lang, dict_file=None, reverse=False):
dict_file = "%s_%d.dict" % (lang,
dict_size) if dict_file is None else dict_file
dict_path = os.path.join(DATA_HOME, dict_file)
data_path = os.path.join(DATA_HOME, data_file)
if not os.path.exists(dict_path) or (len(open(dict_path, "r").readlines())
!= dict_size):
__build_dict(data_path, dict_size, dict_path, lang)
word_dict = {}
with open(dict_path, "r") as fdict:
for idx, line in enumerate(fdict):
if reverse:
word_dict[idx] = line.strip()
else:
word_dict[line.strip()] = idx
return word_dict
def reader_creator(data_file,
src_lang,
src_dict_size,
trg_dict_size,
src_dict_file=None,
trg_dict_file=None,
len_filter=200):
def reader():
src_dict = __load_dict(data_file, src_dict_size, "cn", src_dict_file)
trg_dict = __load_dict(data_file, trg_dict_size, "en", trg_dict_file)
# The indices for the start mark, end mark, and unk token are the same in
# the source and target languages. Here the source language dictionary is
# used to determine their indices.
start_id = src_dict[START_MARK]
end_id = src_dict[END_MARK]
unk_id = src_dict[UNK_MARK]
src_col = 0 if src_lang == "cn" else 1
trg_col = 1 - src_col
data_path = os.path.join(DATA_HOME, data_file)
data_files = [
os.path.join(data_path, f) for f in os.listdir(data_path)
] if os.path.isdir(data_path) else [data_path]
for file_path in data_files:
with open(file_path, mode="r") as f:
for line in f.readlines():
line_split = line.strip().split("\t")
if len(line_split) != 2:
continue
src_words = line_split[src_col].split()
src_ids = [start_id
] + [src_dict.get(w, unk_id)
for w in src_words] + [end_id]
trg_words = line_split[trg_col].split()
trg_ids = [trg_dict.get(w, unk_id) for w in trg_words]
trg_ids_next = trg_ids + [end_id]
trg_ids = [start_id] + trg_ids
if len(src_words) + len(trg_words) < len_filter:
yield src_ids, trg_ids, trg_ids_next
return reader
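# Each sample yielded by the reader is a (src_ids, trg_ids, trg_ids_next) tuple:
# src_ids is the source sentence wrapped with the start and end marks, trg_ids is
# the target sentence prefixed with the start mark, and trg_ids_next is the same
# target sentence shifted left by one position and suffixed with the end mark,
# i.e. the prediction labels aligned with trg_ids.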
def train(data_file,
src_dict_size,
trg_dict_size,
src_lang="cn",
src_dict_file=None,
trg_dict_file=None,
len_filter=200):
return reader_creator(data_file, src_lang, src_dict_size, trg_dict_size,
src_dict_file, trg_dict_file, len_filter)
test = partial(train, len_filter=100000)
def get_dict(data_file, dict_size, lang, dict_file=None, reverse=False):
dict_file = "%s_%d.dict" % (lang,
dict_size) if dict_file is None else dict_file
dict_path = os.path.join(DATA_HOME, dict_file)
assert os.path.exists(dict_path), "Word dictionary does not exist. "
return __load_dict(data_file, dict_size, lang, dict_file, reverse)
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
class LearningRateScheduler(object):
"""
Wrapper for learning rate scheduling as described in the Transformer paper.
LearningRateScheduler adapts the learning rate externally, and the adapted
learning rate will be fed into the main_program as input data.
"""
def __init__(self,
d_model,
warmup_steps,
place,
learning_rate=0.001,
current_steps=0,
name="learning_rate"):
self.current_steps = current_steps
self.warmup_steps = warmup_steps
self.d_model = d_model
self.learning_rate = layers.create_global_var(
name=name,
shape=[1],
value=float(learning_rate),
dtype="float32",
persistable=True)
self.place = place
def update_learning_rate(self, data_input):
self.current_steps += 1
lr_value = np.power(self.d_model, -0.5) * np.min([
np.power(self.current_steps, -0.5),
np.power(self.warmup_steps, -1.5) * self.current_steps
])
lr_tensor = fluid.LoDTensor()
lr_tensor.set(np.array([lr_value], dtype="float32"), self.place)
data_input[self.learning_rate.name] = lr_tensor
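# A standalone sketch of the schedule computed in update_learning_rate above, for
# illustration only (hypothetical warmup_steps=4000 as in the Transformer paper;
# this commit uses warmup_steps=15000 in TrainTaskConfig). It reuses the np
# imported at the top of this module.
def _noam_lr_sketch(step, d_model=512, warmup_steps=4000):
    # lr = d_model^-0.5 * min(step^-0.5, warmup_steps^-1.5 * step): the rate grows
    # linearly during warmup and then decays as 1 / sqrt(step).
    return np.power(d_model, -0.5) * min(
        np.power(step, -0.5), np.power(warmup_steps, -1.5) * step)
# e.g. _noam_lr_sketch(1) ~ 1.7e-7, _noam_lr_sketch(4000) ~ 7.0e-4,
# _noam_lr_sketch(16000) ~ 3.5e-4.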
import os
import time
import numpy as np
import paddle
import paddle.fluid as fluid
from model import transformer, position_encoding_init
from optim import LearningRateScheduler
from config import TrainTaskConfig, ModelHyperParams, pos_enc_param_names, \
encoder_input_data_names, decoder_input_data_names, label_data_names
import nist_data_provider
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in the dict can be used to pad, since the loss on paddings
# will be masked out by weights and has no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array(
[[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
range(1, len(inst) + 1) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data, 1).reshape(
[-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
return return_list if len(return_list) > 1 else return_list[0]
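# A shape sketch with a toy batch (assumed token ids, illustration only):
# pad_batch_data([[3, 5, 7], [2, 4]], pad_idx=0, n_head=8) pads both instances to
# max_len = 3 and returns [word, pos, slf_attn_bias, max_len], where word and pos
# have shape [batch * max_len, 1] = [6, 1], slf_attn_bias has shape
# [batch, n_head, max_len, max_len] = [2, 8, 3, 3] with -1e9 on padding positions,
# and max_len = 3.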
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
# These shape tensors are used in reshape_op.
src_data_shape = np.array([len(insts), src_max_len, d_model], dtype="int32")
trg_data_shape = np.array([len(insts), trg_max_len, d_model], dtype="int32")
src_slf_attn_pre_softmax_shape = np.array(
[-1, src_slf_attn_bias.shape[-1]], dtype="int32")
src_slf_attn_post_softmax_shape = np.array(
src_slf_attn_bias.shape, dtype="int32")
trg_slf_attn_pre_softmax_shape = np.array(
[-1, trg_slf_attn_bias.shape[-1]], dtype="int32")
trg_slf_attn_post_softmax_shape = np.array(
trg_slf_attn_bias.shape, dtype="int32")
trg_src_attn_pre_softmax_shape = np.array(
[-1, trg_src_attn_bias.shape[-1]], dtype="int32")
trg_src_attn_post_softmax_shape = np.array(
trg_src_attn_bias.shape, dtype="int32")
lbl_word, lbl_weight = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False)
input_dict = dict(
zip(input_data_names, [
src_word, src_pos, src_slf_attn_bias, src_data_shape,
src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape,
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias,
trg_data_shape, trg_slf_attn_pre_softmax_shape,
trg_slf_attn_post_softmax_shape, trg_src_attn_pre_softmax_shape,
trg_src_attn_post_softmax_shape, lbl_word, lbl_weight
]))
return input_dict
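# Note that the order of the values above must match the order of
# encoder_input_data_names + decoder_input_data_names[:-1] + label_data_names
# from config.py, which is how prepare_batch_input is called in main() below;
# "enc_output" is excluded by the [:-1] because it is computed inside the
# training program rather than fed as data.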
def main():
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps, place,
TrainTaskConfig.learning_rate)
optimizer = fluid.optimizer.Adam(
learning_rate=lr_scheduler.learning_rate,
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
optimizer.minimize(avg_cost if TrainTaskConfig.use_avg_cost else sum_cost)
train_data = paddle.batch(
paddle.reader.shuffle(
nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
buf_size=100000),
batch_size=TrainTaskConfig.batch_size)
# Program to do validation.
'''test_program = fluid.default_main_program().clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([avg_cost])
val_data = paddle.batch(
nist_data_provider.train("data", ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=TrainTaskConfig.batch_size)'''
def test(exe):
test_total_cost = 0
test_total_token = 0
for batch_id, data in enumerate(val_data()):
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
test_sum_cost, test_token_num = exe.run(
test_program,
feed=data_input,
fetch_list=[sum_cost, token_num],
use_program_cache=True)
test_total_cost += test_sum_cost
test_total_token += test_token_num
test_avg_cost = test_total_cost / test_total_token
test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl
# Initialize the parameters.
exe.run(fluid.framework.default_startup_program())
for pos_enc_param_name in pos_enc_param_names:
pos_enc_param = fluid.global_scope().find_var(
pos_enc_param_name).get_tensor()
pos_enc_param.set(
position_encoding_init(ModelHyperParams.max_length + 1,
ModelHyperParams.d_model), place)
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
fetch_list=[sum_cost, avg_cost],
use_program_cache=True)
sum_cost_val, avg_cost_val = np.array(outs[0]), np.array(outs[1])
print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
(pass_id, batch_id, sum_cost_val, avg_cost_val,
np.exp([min(avg_cost_val[0], 100)])))
# Validate and save the model for inference.
#val_avg_cost, val_ppl = test(exe)
pass_end_time = time.time()
time_consumed = pass_end_time - pass_start_time
print("pass_id = " + str(pass_id) + " time_consumed = " +
str(time_consumed))
#print("epoch: %d, val avg loss: %f, val ppl: %f, "
# "consumed %fs" % (pass_id, val_avg_cost, val_ppl, time_consumed))
fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"),
encoder_input_data_names + decoder_input_data_names[:-1],
[predict], exe)
if __name__ == "__main__":
main()