Commit 64a7e1a0 authored by G guosheng

Update transformer

Parent 0baa8f68
@@ -88,14 +88,17 @@ def do_predict(args):
     # define model
     inputs = [
         Input(
-            [None, None], "int64", name="src_word"), Input(
-            [None, None], "int64", name="src_pos"), Input(
+            [None, None], "int64", name="src_word"),
+        Input(
+            [None, None], "int64", name="src_pos"),
+        Input(
             [None, args.n_head, None, None],
             "float32",
-            name="src_slf_attn_bias"), Input(
+            name="src_slf_attn_bias"),
+        Input(
             [None, args.n_head, None, None],
             "float32",
-            name="trg_src_attn_bias")
+            name="trg_src_attn_bias"),
     ]
     transformer = InferTransformer(
         args.src_vocab_size,
...
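For context, each `Input` above declares a placeholder whose `None` dimensions are resolved at run time: batch size and sequence length for the token ids, and a `[batch, n_head, len, len]` attention bias. A rough equivalent in raw paddle.fluid 1.x would be the sketch below; it is illustrative only, and the repo's `Input` wrapper may differ.

```python
import paddle.fluid as fluid

# Illustrative fluid-1.x equivalents of the Input declarations above;
# `None` marks dimensions that stay dynamic until feed time.
src_word = fluid.data(name="src_word", shape=[None, None], dtype="int64")
src_slf_attn_bias = fluid.data(
    name="src_slf_attn_bias",
    shape=[None, 8, None, None],  # 8 stands in for args.n_head
    dtype="float32")
```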
@@ -19,6 +19,15 @@ import tarfile
 import numpy as np

 import paddle.fluid as fluid
+from paddle.fluid.io import BatchSampler, DataLoader
+
+
+class TokenBatchSampler(BatchSampler):
+    def __init__(self):
+        pass
+
+    def __iter__(self):
+        pass


 def pad_batch_data(insts,
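The `TokenBatchSampler` added here is still an empty stub. Purely for illustration, a minimal token-budget sampler might group sample indices so each batch stays under a token cap; the class name, parameters, and greedy policy below are assumptions, not the eventual implementation.

```python
# Hypothetical sketch of a token-budget batch sampler; `lengths` and
# `max_tokens` are illustrative parameters, not part of this commit.
class GreedyTokenBatchSampler(object):
    def __init__(self, lengths, max_tokens):
        self.lengths = lengths      # tokens per sample, indexed by sample id
        self.max_tokens = max_tokens

    def __iter__(self):
        batch, budget = [], 0
        for idx, n in enumerate(self.lengths):
            if batch and budget + n > self.max_tokens:
                yield batch         # emit the batch before it overflows
                batch, budget = [], 0
            batch.append(idx)
            budget += n
        if batch:
            yield batch
```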
@@ -54,7 +63,8 @@ def pad_batch_data(insts,
     if is_target:
         # This is used to avoid attention on paddings and subsequent
         # words.
-        slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
+        slf_attn_bias_data = np.ones(
+            (inst_data.shape[0], max_len, max_len))
         slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                      1).reshape([-1, 1, max_len, max_len])
         slf_attn_bias_data = np.tile(slf_attn_bias_data,
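As a quick illustration of the masking above: `np.triu(..., 1)` keeps only the strictly upper triangle, i.e. each position's future tokens, which the surrounding code turns into a large negative additive attention bias so softmax gives those positions near-zero weight.

```python
import numpy as np

# For a length-4 target, the positions marked 1 are the "subsequent
# words" that decoder self-attention must not see.
print(np.triu(np.ones((4, 4)), 1))
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]
```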
@@ -306,6 +316,7 @@ class DataProcessor(object):
     :param seed: The seed for random.
     :type seed: int
     """
+
     def __init__(self,
                  src_vocab_fpath,
                  trg_vocab_fpath,
@@ -360,21 +371,23 @@ class DataProcessor(object):
     def load_src_trg_ids(self, fpattern, tar_fname):
         converters = [
-            Converter(vocab=self._src_vocab,
-                      beg=self._bos_idx,
-                      end=self._eos_idx,
-                      unk=self._unk_idx,
-                      delimiter=self._token_delimiter,
-                      add_beg=False)
+            Converter(
+                vocab=self._src_vocab,
+                beg=self._bos_idx,
+                end=self._eos_idx,
+                unk=self._unk_idx,
+                delimiter=self._token_delimiter,
+                add_beg=False)
         ]
         if not self._only_src:
             converters.append(
-                Converter(vocab=self._trg_vocab,
-                          beg=self._bos_idx,
-                          end=self._eos_idx,
-                          unk=self._unk_idx,
-                          delimiter=self._token_delimiter,
-                          add_beg=True))
+                Converter(
+                    vocab=self._trg_vocab,
+                    beg=self._bos_idx,
+                    end=self._eos_idx,
+                    unk=self._unk_idx,
+                    delimiter=self._token_delimiter,
+                    add_beg=True))

         converters = ComposedConverter(converters)
@@ -402,9 +415,8 @@ class DataProcessor(object):
             f = tarfile.open(fpaths[0], "rb")
             for line in f.extractfile(tar_fname):
                 fields = line.strip(b"\n").split(self._field_delimiter)
-                if (not self._only_src
-                        and len(fields) == 2) or (self._only_src
-                                                  and len(fields) == 1):
+                if (not self._only_src and len(fields) == 2) or (
+                        self._only_src and len(fields) == 1):
                     yield fields
         else:
             for fpath in fpaths:
@@ -414,9 +426,8 @@ class DataProcessor(object):
             with open(fpath, "rb") as f:
                 for line in f:
                     fields = line.strip(b"\n").split(self._field_delimiter)
-                    if (not self._only_src
-                            and len(fields) == 2) or (self._only_src
-                                                      and len(fields) == 1):
+                    if (not self._only_src and len(fields) == 2) or (
+                            self._only_src and len(fields) == 1):
                         yield fields

     @staticmethod
@@ -477,7 +488,8 @@ class DataProcessor(object):
             if self._only_src:
                 yield [[self._src_seq_ids[idx]] for idx in batch_ids]
             else:
-                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
+                yield [(self._src_seq_ids[idx],
+                        self._trg_seq_ids[idx][:-1],
                         self._trg_seq_ids[idx][1:]) for idx in batch_ids]

         return __impl__
@@ -512,8 +524,8 @@ class DataProcessor(object):
         for item in data_reader():
             inst_num_per_part = len(item) // count
             for i in range(count):
-                yield item[inst_num_per_part * i:inst_num_per_part *
-                           (i + 1)]
+                yield item[inst_num_per_part * i:inst_num_per_part * (i + 1
+                                                                      )]

         return __impl__
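For context, the slicing above just splits one batch into `count` equal shards, one per device. A quick standalone illustration:

```python
# Splitting a batch of 8 samples across count = 2 devices.
item, count = list(range(8)), 2
inst_num_per_part = len(item) // count
parts = [item[inst_num_per_part * i:inst_num_per_part * (i + 1)]
         for i in range(count)]
print(parts)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```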
@@ -535,7 +547,7 @@ class DataProcessor(object):
             for data in data_reader():
                 data_inputs = prepare_train_input(data, src_pad_idx,
                                                   trg_pad_idx, n_head)
-                yield data_inputs
+                yield data_inputs[:-2], data_inputs[-2:]

         def __for_predict__():
             for data in data_reader():
...
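The train reader now yields an `(inputs, labels)` pair instead of one flat tuple, matching what a high-level `fit` API consumes. Illustratively, with the field names below standing in as an assumption about `prepare_train_input`'s output order:

```python
# Hypothetical field order; the last two entries are the label ids and
# their smoothing weights, everything before them feeds the network.
data_inputs = ("src_word", "src_pos", "src_slf_attn_bias",
               "trg_word", "trg_pos", "trg_slf_attn_bias",
               "trg_src_attn_bias", "label", "weight")
inputs, labels = data_inputs[:-2], data_inputs[-2:]
assert labels == ("label", "weight")
```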
@@ -32,9 +32,35 @@ from utils.check import check_gpu, check_version
 import reader
 from transformer import Transformer, CrossEntropyCriterion, NoamDecay
 from model import Input
+from callbacks import ProgBarLogger
+
+
+class LoggerCallback(ProgBarLogger):
+    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
+        super(LoggerCallback, self).__init__(log_freq, verbose)
+        self.loss_normalizer = loss_normalizer
+
+    def on_train_begin(self, logs=None):
+        super(LoggerCallback, self).on_train_begin(logs)
+        self.train_metrics += ["normalized loss", "ppl"]
+
+    def on_train_batch_end(self, step, logs=None):
+        logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
+        logs["ppl"] = np.exp(min(logs["loss"][0], 100))
+        super(LoggerCallback, self).on_train_batch_end(step, logs)
+
+    def on_eval_begin(self, logs=None):
+        super(LoggerCallback, self).on_eval_begin(logs)
+        self.eval_metrics += ["normalized loss", "ppl"]
+
+    def on_eval_batch_end(self, step, logs=None):
+        logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
+        logs["ppl"] = np.exp(min(logs["loss"][0], 100))
+        super(LoggerCallback, self).on_eval_batch_end(step, logs)


 def do_train(args):
+    init_context('dynamic' if FLAGS.dynamic else 'static')
     trainer_count = 1 #get_nranks()

     @contextlib.contextmanager
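A note on the two derived metrics the callback adds: "normalized loss" subtracts the best achievable smoothed cross-entropy (the `loss_normalizer` computed later in `do_train`), and "ppl" is the exponential of the loss, with the loss clipped at 100 so the exponential cannot overflow. For example:

```python
import numpy as np

# exp(loss) is perplexity; min(loss, 100) guards against float
# overflow while the loss is still huge early in training.
loss = 9.21  # about log(10000), e.g. near-uniform over a 10k vocab
print(np.exp(min(loss, 100)))  # ~9996.6, i.e. perplexity around 1e4
```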
@@ -102,24 +128,31 @@ def do_train(args):
     # define model
     inputs = [
         Input(
-            [None, None], "int64", name="src_word"), Input(
-            [None, None], "int64", name="src_pos"), Input(
+            [None, None], "int64", name="src_word"),
+        Input(
+            [None, None], "int64", name="src_pos"),
+        Input(
             [None, args.n_head, None, None],
             "float32",
-            name="src_slf_attn_bias"), Input(
-            [None, None], "int64", name="trg_word"), Input(
-            [None, None], "int64", name="trg_pos"), Input(
+            name="src_slf_attn_bias"),
+        Input(
+            [None, None], "int64", name="trg_word"),
+        Input(
+            [None, None], "int64", name="trg_pos"),
+        Input(
             [None, args.n_head, None, None],
             "float32",
-            name="trg_slf_attn_bias"), Input(
+            name="trg_slf_attn_bias"),
+        Input(
             [None, args.n_head, None, None],
             "float32",
-            name="trg_src_attn_bias")
+            name="trg_src_attn_bias"),
     ]
     labels = [
         Input(
-            [None, 1], "int64", name="label"), Input(
-            [None, 1], "float32", name="weight")
+            [None, 1], "int64", name="label"),
+        Input(
+            [None, 1], "float32", name="weight"),
     ]

     transformer = Transformer(
@@ -149,7 +182,8 @@ def do_train(args):
     ## init from some pretrain models, to better solve the current task
     if args.init_from_pretrain_model:
         transformer.load(
-            os.path.join(args.init_from_pretrain_model, "transformer"))
+            os.path.join(args.init_from_pretrain_model, "transformer"),
+            reset_optimizer=True)

     # the best cross-entropy value with label smoothing
@@ -157,63 +191,17 @@ def do_train(args):
     # the best cross-entropy value with label smoothing
     loss_normalizer = -(
         (1. - args.label_smooth_eps) * np.log(
             (1. - args.label_smooth_eps)) + args.label_smooth_eps *
         np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

-    step_idx = 0
-
-    # train loop
-    for pass_id in range(args.epoch):
-        pass_start_time = time.time()
-
-        batch_id = 0
-        for input_data in train_loader():
-            losses = transformer.train(input_data[:-2], input_data[-2:])
-
-            if step_idx % args.print_step == 0:
-                total_avg_cost = np.sum(losses)
-
-                if step_idx == 0:
-                    logging.info(
-                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
-                        "normalized loss: %f, ppl: %f" %
-                        (step_idx, pass_id, batch_id, total_avg_cost,
-                         total_avg_cost - loss_normalizer,
-                         np.exp([min(total_avg_cost, 100)])))
-                    avg_batch_time = time.time()
-                else:
-                    logging.info(
-                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
-                        "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
-                        (step_idx, pass_id, batch_id, total_avg_cost,
-                         total_avg_cost - loss_normalizer,
-                         np.exp([min(total_avg_cost, 100)]),
-                         args.print_step / (time.time() - avg_batch_time)))
-                    avg_batch_time = time.time()
-
-            if step_idx % args.save_step == 0 and step_idx != 0:
-                # validation: how to accumulate with Model loss
-                if args.validation_file:
-                    total_avg_cost = 0
-                    for idx, input_data in enumerate(val_loader()):
-                        losses = transformer.eval(input_data[:-2],
-                                                  input_data[-2:])
-                        total_avg_cost += np.sum(losses)
-                    total_avg_cost /= idx + 1
-                    logging.info("validation, step_idx: %d, avg loss: %f, "
-                                 "normalized loss: %f, ppl: %f" %
-                                 (step_idx, total_avg_cost,
-                                  total_avg_cost - loss_normalizer,
-                                  np.exp([min(total_avg_cost, 100)])))
-
-                transformer.save(
-                    os.path.join(args.save_model, "step_" + str(step_idx),
-                                 "transformer"))
-
-            batch_id += 1
-            step_idx += 1
-
-        time_consumed = time.time() - pass_start_time
-
-    if args.save_model:
-        transformer.save(
-            os.path.join(args.save_model, "step_final", "transformer"))
+    transformer.fit(train_loader=train_loader,
+                    eval_loader=val_loader,
+                    epochs=1,
+                    eval_freq=1,
+                    save_freq=1,
+                    verbose=2,
+                    callbacks=[
+                        LoggerCallback(
+                            log_freq=args.print_step,
+                            loss_normalizer=loss_normalizer)
+                    ])


 if __name__ == "__main__":
...
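For intuition on `loss_normalizer`: it is the entropy of the label-smoothed target distribution, i.e. the lowest cross-entropy any model could reach, which is why subtracting it gives the "normalized loss". A quick check with assumed settings (eps = 0.1 and a 10k-word target vocabulary; both values are illustrative):

```python
import numpy as np

# Assumed values for illustration: label_smooth_eps = 0.1 and a
# target vocabulary of 10000 words.
eps, V = 0.1, 10000
loss_normalizer = -((1. - eps) * np.log(1. - eps) +
                    eps * np.log(eps / (V - 1) + 1e-20))
print(loss_normalizer)  # ~1.246: the floor the training loss tends toward
```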