Commit 64a7e1a0 authored by G guosheng

Update transformer

Parent 0baa8f68
@@ -88,14 +88,17 @@ def do_predict(args):
# define model
inputs = [
Input(
[None, None], "int64", name="src_word"), Input(
[None, None], "int64", name="src_pos"), Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"), Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias")
[None, None], "int64", name="src_word"),
Input(
[None, None], "int64", name="src_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
transformer = InferTransformer(
args.src_vocab_size,
......
@@ -19,6 +19,15 @@ import tarfile
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.io import BatchSampler, DataLoader
class TokenBatchSampler(BatchSampler):
    def __init__(self):
        pass

    def __iter__(self):
        pass
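The `TokenBatchSampler` added above is still an empty stub. Purely as a hedged sketch of what a token-budget sampler could look like (none of the names or parameters below appear in this commit), one might group dataset indices so that each batch stays under a fixed token count; the sketch is written as a plain iterable rather than wiring into `BatchSampler`'s constructor:

```python
class TokenBudgetSamplerSketch(object):
    """Hypothetical sketch: yield lists of dataset indices whose total
    token count stays under a fixed budget."""

    def __init__(self, seq_lens, max_tokens=4096):
        self.seq_lens = seq_lens      # token length of each example
        self.max_tokens = max_tokens  # token budget per batch

    def __iter__(self):
        batch, batch_tokens = [], 0
        for idx, length in enumerate(self.seq_lens):
            if batch and batch_tokens + length > self.max_tokens:
                yield batch
                batch, batch_tokens = [], 0
            batch.append(idx)
            batch_tokens += length
        if batch:
            yield batch

# Example: sequences of length 3, 5, 4, 6 with an 8-token budget
# are grouped as [0, 1], [2], [3].
print(list(TokenBudgetSamplerSketch([3, 5, 4, 6], max_tokens=8)))
```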
def pad_batch_data(insts,
@@ -54,7 +63,8 @@ def pad_batch_data(insts,
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.ones(
(inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data,
1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
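An aside on the masking built in this hunk (illustration only, not part of the diff): `np.triu(..., 1)` keeps ones strictly above the diagonal, so row `i` is marked exactly at the "subsequent" positions `j > i`; after being tiled per head, this bias is presumably scaled to a large negative value so those positions vanish from the attention softmax.

```python
import numpy as np

# Standalone illustration for max_len = 4.
max_len = 4
mask = np.triu(np.ones((max_len, max_len)), 1)
print(mask)
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]
# Row i marks the future positions that token i must not attend to.
```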
@@ -306,6 +316,7 @@ class DataProcessor(object):
:param seed: The seed for random.
:type seed: int
"""
def __init__(self,
src_vocab_fpath,
trg_vocab_fpath,
@@ -360,21 +371,23 @@ class DataProcessor(object):
def load_src_trg_ids(self, fpattern, tar_fname):
converters = [
Converter(vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
Converter(
vocab=self._src_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=False)
]
if not self._only_src:
converters.append(
Converter(vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
Converter(
vocab=self._trg_vocab,
beg=self._bos_idx,
end=self._eos_idx,
unk=self._unk_idx,
delimiter=self._token_delimiter,
add_beg=True))
converters = ComposedConverter(converters)
@@ -402,9 +415,8 @@ class DataProcessor(object):
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
else:
for fpath in fpaths:
@@ -414,9 +426,8 @@ class DataProcessor(object):
with open(fpath, "rb") as f:
for line in f:
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
if (not self._only_src and len(fields) == 2) or (
self._only_src and len(fields) == 1):
yield fields
@staticmethod
@@ -477,7 +488,8 @@ class DataProcessor(object):
if self._only_src:
yield [[self._src_seq_ids[idx]] for idx in batch_ids]
else:
yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
yield [(self._src_seq_ids[idx],
self._trg_seq_ids[idx][:-1],
self._trg_seq_ids[idx][1:]) for idx in batch_ids]
return __impl__
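A quick aside on the tuple yielded above (illustration only, not part of the diff): the target id sequence is reused twice with a one-position shift, so `trg[:-1]` serves as the decoder input while `trg[1:]` supplies the gold labels it must predict at each step.

```python
# Toy ids picked arbitrarily for the example: <bos>=0, <eos>=1.
trg_seq = [0, 11, 12, 13, 1]
decoder_input, gold = trg_seq[:-1], trg_seq[1:]
assert decoder_input == [0, 11, 12, 13]
assert gold == [11, 12, 13, 1]
```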
@@ -512,8 +524,8 @@ class DataProcessor(object):
for item in data_reader():
inst_num_per_part = len(item) // count
for i in range(count):
yield item[inst_num_per_part * i:inst_num_per_part *
(i + 1)]
yield item[inst_num_per_part * i:inst_num_per_part * (i + 1
)]
return __impl__
@@ -535,7 +547,7 @@ class DataProcessor(object):
for data in data_reader():
data_inputs = prepare_train_input(data, src_pad_idx,
trg_pad_idx, n_head)
yield data_inputs
yield data_inputs[:-2], data_inputs[-2:]
def __for_predict__():
for data in data_reader():
......
@@ -32,9 +32,35 @@ from utils.check import check_gpu, check_version
import reader
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
from model import Input
from callbacks import ProgBarLogger
class LoggerCallback(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
super(LoggerCallback, self).__init__(log_freq, verbose)
self.loss_normalizer = loss_normalizer
def on_train_begin(self, logs=None):
super(LoggerCallback, self).on_train_begin(logs)
self.train_metrics += ["normalized loss", "ppl"]
def on_train_batch_end(self, step, logs=None):
logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
logs["ppl"] = np.exp(min(logs["loss"][0], 100))
super(LoggerCallback, self).on_train_batch_end(step, logs)
def on_eval_begin(self, logs=None):
super(LoggerCallback, self).on_eval_begin(logs)
self.eval_metrics += ["normalized loss", "ppl"]
def on_eval_batch_end(self, step, logs=None):
logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
logs["ppl"] = np.exp(min(logs["loss"][0], 100))
super(LoggerCallback, self).on_eval_batch_end(step, logs)
def do_train(args):
init_context('dynamic' if FLAGS.dynamic else 'static')
trainer_count = 1 #get_nranks()
@contextlib.contextmanager
@@ -102,24 +128,31 @@ def do_train(args):
# define model
inputs = [
Input(
[None, None], "int64", name="src_word"), Input(
[None, None], "int64", name="src_pos"), Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"), Input(
[None, None], "int64", name="trg_word"), Input(
[None, None], "int64", name="trg_pos"), Input(
[None, args.n_head, None, None],
"float32",
name="trg_slf_attn_bias"), Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias")
[None, None], "int64", name="src_word"),
Input(
[None, None], "int64", name="src_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"),
Input(
[None, None], "int64", name="trg_word"),
Input(
[None, None], "int64", name="trg_pos"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_slf_attn_bias"),
Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias"),
]
labels = [
Input(
[None, 1], "int64", name="label"), Input(
[None, 1], "float32", name="weight")
[None, 1], "int64", name="label"),
Input(
[None, 1], "float32", name="weight"),
]
transformer = Transformer(
@@ -149,7 +182,8 @@ def do_train(args):
## init from a pretrained model to better fit the current task
if args.init_from_pretrain_model:
transformer.load(
os.path.join(args.init_from_pretrain_model, "transformer"))
os.path.join(args.init_from_pretrain_model, "transformer"),
reset_optimizer=True)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
@@ -157,63 +191,17 @@ def do_train(args):
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
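For reference (not shown in the diff): with label smoothing `eps` over a target vocabulary of size `V`, the smoothed target puts probability `1 - eps` on the gold token and `eps / (V - 1)` on every other token, so the lowest attainable cross-entropy is the entropy of that distribution, which is what `loss_normalizer` computes here (the `1e-20` only guards the log). A minimal sketch with made-up values:

```python
import numpy as np

eps, V = 0.1, 30000  # stand-ins for args.label_smooth_eps and args.trg_vocab_size
h_min = -((1. - eps) * np.log(1. - eps) +
          eps * np.log(eps / (V - 1) + 1e-20))
# Subtracting h_min from the raw loss gives the "normalized loss" reported by
# LoggerCallback; np.exp(loss) is the corresponding per-token perplexity.
print(h_min)
```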
step_idx = 0
# train loop
for pass_id in range(args.epoch):
pass_start_time = time.time()
batch_id = 0
for input_data in train_loader():
losses = transformer.train(input_data[:-2], input_data[-2:])
if step_idx % args.print_step == 0:
total_avg_cost = np.sum(losses)
if step_idx == 0:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s"
%
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
avg_batch_time = time.time()
if step_idx % args.save_step == 0 and step_idx != 0:
# validation: how to accumulate with Model loss
if args.validation_file:
total_avg_cost = 0
for idx, input_data in enumerate(val_loader()):
losses = transformer.eval(input_data[:-2],
input_data[-2:])
total_avg_cost += np.sum(losses)
total_avg_cost /= idx + 1
logging.info("validation, step_idx: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
transformer.save(
os.path.join(args.save_model, "step_" + str(step_idx),
"transformer"))
batch_id += 1
step_idx += 1
time_consumed = time.time() - pass_start_time
if args.save_model:
transformer.save(
os.path.join(args.save_model, "step_final", "transformer"))
transformer.fit(train_loader=train_loader,
eval_loader=val_loader,
epochs=1,
eval_freq=1,
save_freq=1,
verbose=2,
callbacks=[
LoggerCallback(
log_freq=args.print_step,
loss_normalizer=loss_normalizer)
])
if __name__ == "__main__":
......