提交 51ef7e15 编写于 作者: G guosheng

Update transformer to adapt to latest code

上级 a28a6e3b
......@@ -30,6 +30,7 @@ from utils.check import check_gpu, check_version
# include task-specific libs
import reader
from transformer import InferTransformer, position_encoding_init
from model import Input
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
......@@ -50,8 +51,6 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
def do_predict(args):
device_ids = list(range(args.num_devices))
@contextlib.contextmanager
def null_guard():
yield
......@@ -59,22 +58,23 @@ def do_predict(args):
guard = fluid.dygraph.guard() if args.eager_run else null_guard()
# define the data generator
processor = reader.DataProcessor(fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=False,
batch_size=args.batch_size,
device_count=1,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
processor = reader.DataProcessor(
fpattern=args.predict_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=False,
batch_size=args.batch_size,
device_count=1,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="predict")
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
......@@ -86,25 +86,38 @@ def do_predict(args):
test_loader = batch_generator
# define model
transformer = InferTransformer(args.src_vocab_size,
args.trg_vocab_size,
args.max_length + 1,
args.n_layer,
args.n_head,
args.d_key,
args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.bos_idx,
args.eos_idx,
beam_size=args.beam_size,
max_out_len=args.max_out_len)
inputs = [
Input(
[None, None], "int64", name="src_word"), Input(
[None, None], "int64", name="src_pos"), Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"), Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias")
]
transformer = InferTransformer(
args.src_vocab_size,
args.trg_vocab_size,
args.max_length + 1,
args.n_layer,
args.n_head,
args.d_key,
args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.bos_idx,
args.eos_idx,
beam_size=args.beam_size,
max_out_len=args.max_out_len)
transformer.prepare(inputs=inputs)
# load the trained model
assert args.init_from_params, (
......@@ -115,11 +128,8 @@ def do_predict(args):
for input_data in test_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word,
trg_src_attn_bias) = input_data
finished_seq = transformer.test(inputs=(src_word, src_pos,
src_slf_attn_bias,
trg_src_attn_bias),
device='gpu',
device_ids=device_ids)[0]
finished_seq = transformer.test(inputs=(
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias))[0]
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
......
......@@ -54,8 +54,7 @@ class RNNUnit(Layer):
if sys.version_info < (3, ):
integer_types = (
int,
long,
)
long, )
else:
integer_types = (int, )
"""For shape, list/tuple of integer is the finest-grained objection"""
......@@ -67,8 +66,8 @@ class RNNUnit(Layer):
# TODO: Add check for the illegal
if isinstance(seq, dict):
return True
return (isinstance(seq, collections.Sequence)
and not isinstance(seq, six.string_types))
return (isinstance(seq, collections.Sequence) and
not isinstance(seq, six.string_types))
class Shape(object):
def __init__(self, shape):
......@@ -174,6 +173,7 @@ class BasicLSTMUnit(RNNUnit):
forget_bias(float|1.0): forget bias used when computing forget gate
dtype(string): data type used in this unit
"""
def __init__(self,
hidden_size,
input_size,
......@@ -190,9 +190,8 @@ class BasicLSTMUnit(RNNUnit):
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._forget_bias = layers.fill_constant([1],
dtype=dtype,
value=forget_bias)
self._forget_bias = layers.fill_constant(
[1], dtype=dtype, value=forget_bias)
self._forget_bias.stop_gradient = False
self._dtype = dtype
self._input_size = input_size
......@@ -204,10 +203,11 @@ class BasicLSTMUnit(RNNUnit):
],
dtype=self._dtype)
self._bias = self.create_parameter(attr=self._bias_attr,
shape=[4 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
self._bias = self.create_parameter(
attr=self._bias_attr,
shape=[4 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, state):
pre_hidden, pre_cell = state
......@@ -260,9 +260,8 @@ class RNN(fluid.dygraph.Layer):
# TODO: use where_op
new_state = fluid.layers.elementwise_mul(
new_state, step_mask,
axis=0) - fluid.layers.elementwise_mul(state,
(step_mask - 1),
axis=0)
axis=0) - fluid.layers.elementwise_mul(
state, (step_mask - 1), axis=0)
return new_state
flat_inputs = flatten(inputs)
......@@ -300,7 +299,9 @@ class RNN(fluid.dygraph.Layer):
**kwargs)
if sequence_length:
new_states = map_structure(
partial(_maybe_copy, step_mask=mask[i]), states,
partial(
_maybe_copy, step_mask=mask[i]),
states,
new_states)
states = new_states
outputs = map_structure(
......@@ -347,10 +348,9 @@ class EncoderCell(RNNUnit):
self.lstm_cells = list()
for i in range(self.num_layers):
self.lstm_cells.append(
self.add_sublayer(
"layer_%d" % i,
BasicLSTMUnit(input_size if i == 0 else hidden_size,
hidden_size)))
self.add_sublayer("layer_%d" % i,
BasicLSTMUnit(input_size if i == 0 else
hidden_size, hidden_size)))
def forward(self, step_input, states):
new_states = []
......@@ -384,18 +384,14 @@ class MultiHeadAttention(Layer):
self.d_value = d_value
self.d_model = d_model
self.dropout_rate = dropout_rate
self.q_fc = Linear(input_dim=d_model,
output_dim=d_key * n_head,
bias_attr=False)
self.k_fc = Linear(input_dim=d_model,
output_dim=d_key * n_head,
bias_attr=False)
self.v_fc = Linear(input_dim=d_model,
output_dim=d_value * n_head,
bias_attr=False)
self.proj_fc = Linear(input_dim=d_value * n_head,
output_dim=d_model,
bias_attr=False)
self.q_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
self.k_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
self.v_fc = Linear(
input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
self.proj_fc = Linear(
input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
def forward(self, queries, keys, values, attn_bias, cache=None):
# compute q ,k ,v
......@@ -421,17 +417,14 @@ class MultiHeadAttention(Layer):
cache["k"], cache["v"] = k, v
# scale dot product attention
product = layers.matmul(x=q,
y=k,
transpose_y=True,
alpha=self.d_model**-0.5)
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if self.dropout_rate:
weights = layers.dropout(weights,
dropout_prob=self.dropout_rate,
is_test=False)
weights = layers.dropout(
weights, dropout_prob=self.dropout_rate, is_test=False)
out = layers.matmul(weights, v)
......@@ -497,14 +490,13 @@ class DynamicDecode(Layer):
inputs, states, finished = (initial_inputs, initial_states,
initial_finished)
cond = layers.logical_not((layers.reduce_all(initial_finished)))
sequence_lengths = layers.cast(layers.zeros_like(initial_finished),
"int64")
sequence_lengths = layers.cast(
layers.zeros_like(initial_finished), "int64")
outputs = None
step_idx = 0
step_idx_tensor = layers.fill_constant(shape=[1],
dtype="int64",
value=step_idx)
step_idx_tensor = layers.fill_constant(
shape=[1], dtype="int64", value=step_idx)
while cond.numpy():
(step_outputs, next_states, next_inputs,
next_finished) = self.decoder.step(step_idx_tensor, inputs,
......@@ -512,8 +504,8 @@ class DynamicDecode(Layer):
next_finished = layers.logical_or(next_finished, finished)
next_sequence_lengths = layers.elementwise_add(
sequence_lengths,
layers.cast(layers.logical_not(finished),
sequence_lengths.dtype))
layers.cast(
layers.logical_not(finished), sequence_lengths.dtype))
if self.impute_finished: # rectify the states for the finished.
next_states = map_structure(
......@@ -570,6 +562,7 @@ class TransfomerCell(object):
Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
used as RNNCell
"""
def __init__(self, decoder):
self.decoder = decoder
......@@ -593,20 +586,16 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
self.var_dim_in_state = var_dim_in_state
def _merge_batch_beams_with_var_dim(self, x):
if not hasattr(self, "batch_size"):
self.batch_size = layers.shape(x)[0]
if not hasattr(self, "batch_beam_size"):
self.batch_beam_size = self.batch_size * self.beam_size
# init length of cache is 0, and it increases with decoding carrying on,
# thus need to reshape elaborately
var_dim_in_state = self.var_dim_in_state + 1 # count in beam dim
x = layers.transpose(
x,
list(range(var_dim_in_state, len(x.shape))) +
list(range(0, var_dim_in_state)))
x = layers.reshape(x, [0] * (len(x.shape) - var_dim_in_state) +
[self.batch_beam_size] +
list(x.shape[-var_dim_in_state + 2:]))
x = layers.transpose(x,
list(range(var_dim_in_state, len(x.shape))) +
list(range(0, var_dim_in_state)))
x = layers.reshape(
x, [0] * (len(x.shape) - var_dim_in_state
) + [self.batch_size * self.beam_size] +
[int(size) for size in x.shape[-var_dim_in_state + 2:]])
x = layers.transpose(
x,
list(range((len(x.shape) + 1 - var_dim_in_state), len(x.shape))) +
......@@ -616,8 +605,10 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
def _split_batch_beams_with_var_dim(self, x):
var_dim_size = layers.shape(x)[self.var_dim_in_state]
x = layers.reshape(
x, [-1, self.beam_size] + list(x.shape[1:self.var_dim_in_state]) +
[var_dim_size] + list(x.shape[self.var_dim_in_state + 1:]))
x, [-1, self.beam_size] +
[int(size)
for size in x.shape[1:self.var_dim_in_state]] + [var_dim_size] +
[int(size) for size in x.shape[self.var_dim_in_state + 1:]])
return x
def step(self, time, inputs, states, **kwargs):
......@@ -642,137 +633,3 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
beam_search_state.finished)
return (beam_search_output, beam_search_state, next_inputs, finished)
'''
@contextlib.contextmanager
def eager_guard(is_eager):
if is_eager:
with fluid.dygraph.guard():
yield
else:
yield
# print(flatten(np.random.rand(2,8,8)))
random_seed = 123
np.random.seed(random_seed)
# print np.random.rand(2, 8)
batch_size = 2
seq_len = 8
hidden_size = 8
vocab_size, embed_dim, num_layers, hidden_size = 100, 8, 2, 8
bos_id, eos_id, beam_size, max_step_num = 0, 1, 5, 10
time_major = False
eagar_run = False
import torch
with eager_guard(eagar_run):
fluid.default_main_program().random_seed = random_seed
fluid.default_startup_program().random_seed = random_seed
inputs_data = np.random.rand(batch_size, seq_len,
hidden_size).astype("float32")
states_data = np.random.rand(batch_size, hidden_size).astype("float32")
lstm_cell = BasicLSTMUnit(hidden_size=8, input_size=8)
lstm = RNN(cell=lstm_cell, time_major=time_major)
inputs = to_variable(inputs_data) if eagar_run else fluid.data(
name="x", shape=[None, None, hidden_size], dtype="float32")
states = lstm_cell.get_initial_states(batch_ref=inputs,
batch_dim_idx=1 if time_major else 0)
out, _ = lstm(inputs, states)
# print states
# print layers.BeamSearchDecoder.tile_beam_merge_with_batch(out, 5)
# embedder = Embedding(size=(vocab_size, embed_dim))
# output_layer = Linear(hidden_size, vocab_size)
# decoder = layers.BeamSearchDecoder(lstm_cell,
# bos_id,
# eos_id,
# beam_size,
# embedding_fn=embedder,
# output_fn=output_layer)
# dynamic_decoder = DynamicDecode(decoder, max_step_num)
# out,_ = dynamic_decoder(inits=states)
# caches = [{
# "k":
# layers.fill_constant_batch_size_like(out,
# shape=[-1, 8, 0, 64],
# dtype="float32",
# value=0),
# "v":
# layers.fill_constant_batch_size_like(out,
# shape=[-1, 8, 0, 64],
# dtype="float32",
# value=0)
# } for i in range(6)]
cache = layers.fill_constant_batch_size_like(out,
shape=[-1, 8, 0, 64],
dtype="float32",
value=0)
print cache
# out = layers.BeamSearchDecoder.tile_beam_merge_with_batch(cache, 5)
# out = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(cache, 5)
# batch_beam_size = layers.shape(out)[0] * 5
# print out
cell = TransfomerCell(None)
decoder = TransformerBeamSearchDecoder(cell, 0, 1, 5, 2)
cache = decoder._expand_to_beam_size(cache)
print cache
cache = decoder._merge_batch_beams_with_var_dim(cache)
print cache
cache1 = layers.fill_constant_batch_size_like(cache,
shape=[-1, 8, 1, 64],
dtype="float32",
value=0)
print cache1.shape
cache = layers.concat([cache, cache1], axis=2)
out = decoder._split_batch_beams_with_var_dim(cache)
# out = layers.transpose(out,
# list(range(3, len(out.shape))) + list(range(0, 3)))
# print out
# out = layers.reshape(out, list(out.shape[:2]) + [batch_beam_size, 8])
# print out
# out = layers.transpose(out, [2,3,0,1])
print out.shape
if eagar_run:
print "hehe" #out #.numpy()
else:
executor.run(fluid.default_startup_program())
inputs = fluid.data(name="x",
shape=[None, None, hidden_size],
dtype="float32")
out_np = executor.run(feed={"x": inputs_data},
fetch_list=[out.name])[0]
print np.array(out_np).shape
exit(0)
# dygraph
# inputs = to_variable(inputs_data)
# states = lstm_cell.get_initial_states(batch_ref=inputs,
# batch_dim_idx=1 if time_major else 0)
# print lstm(inputs, states)[0].numpy()
# graph
executor.run(fluid.default_startup_program())
inputs = fluid.data(name="x",
shape=[None, None, hidden_size],
dtype="float32")
states = lstm_cell.get_initial_states(batch_ref=inputs,
batch_dim_idx=1 if time_major else 0)
out, _ = lstm(inputs, states)
out_np = executor.run(feed={"x": inputs_data}, fetch_list=[out.name])[0]
print np.array(out_np)
#print fluid.io.save_inference_model(dirname="test_model", feeded_var_names=["x"], target_vars=[out], executor=executor, model_filename="model.pdmodel", params_filename="params.pdparams")
# test_program, feed_target_names, fetch_targets = fluid.io.load_inference_model(dirname="test_model", executor=executor, model_filename="model.pdmodel", params_filename="params.pdparams")
# out = executor.run(program=test_program, feed={"x": np.random.rand(2, 8, 8).astype("float32")}, fetch_list=fetch_targets)[0]
'''
\ No newline at end of file
......@@ -31,10 +31,11 @@ from utils.check import check_gpu, check_version
# include task-specific libs
import reader
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
from model import Input
def do_train(args):
device_ids = list(range(args.num_devices))
trainer_count = 1 #get_nranks()
@contextlib.contextmanager
def null_guard():
......@@ -43,23 +44,27 @@ def do_train(args):
guard = fluid.dygraph.guard() if args.eager_run else null_guard()
# define the data generator
processor = reader.DataProcessor(fpattern=args.training_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=args.num_devices,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
processor = reader.DataProcessor(
fpattern=args.training_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=trainer_count,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
max_length=args.max_length,
n_head=args.n_head)
batch_generator = processor.data_generator(phase="train")
if trainer_count > 1: # for multi-process gpu training
batch_generator = fluid.contrib.reader.distributed_batch_reader(
batch_generator)
if args.validation_file:
val_processor = reader.DataProcessor(
fpattern=args.validation_file,
......@@ -68,7 +73,7 @@ def do_train(args):
token_delimiter=args.token_delimiter,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
device_count=args.num_devices,
device_count=trainer_count,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=False,
......@@ -82,7 +87,6 @@ def do_train(args):
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = processor.get_vocab_summary()
with guard:
# set seed for CE
random_seed = eval(str(args.random_seed))
......@@ -96,6 +100,28 @@ def do_train(args):
val_loader = val_batch_generator
# define model
inputs = [
Input(
[None, None], "int64", name="src_word"), Input(
[None, None], "int64", name="src_pos"), Input(
[None, args.n_head, None, None],
"float32",
name="src_slf_attn_bias"), Input(
[None, None], "int64", name="trg_word"), Input(
[None, None], "int64", name="trg_pos"), Input(
[None, args.n_head, None, None],
"float32",
name="trg_slf_attn_bias"), Input(
[None, args.n_head, None, None],
"float32",
name="trg_src_attn_bias")
]
labels = [
Input(
[None, 1], "int64", name="label"), Input(
[None, 1], "float32", name="weight")
]
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
......@@ -112,7 +138,9 @@ def do_train(args):
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters()),
CrossEntropyCriterion(args.label_smooth_eps))
CrossEntropyCriterion(args.label_smooth_eps),
inputs=inputs,
labels=labels)
## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint:
......@@ -126,9 +154,8 @@ def do_train(args):
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) +
args.label_smooth_eps * np.log(args.label_smooth_eps /
(args.trg_vocab_size - 1) + 1e-20))
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
step_idx = 0
# train loop
......@@ -136,10 +163,7 @@ def do_train(args):
pass_start_time = time.time()
batch_id = 0
for input_data in train_loader():
outputs, losses = transformer.train(input_data[:-2],
input_data[-2:],
device='gpu',
device_ids=device_ids)
losses = transformer.train(input_data[:-2], input_data[-2:])
if step_idx % args.print_step == 0:
total_avg_cost = np.sum(losses)
......@@ -149,30 +173,27 @@ def do_train(args):
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
"normalized loss: %f, ppl: %f, speed: %.2f step/s"
%
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
avg_batch_time = time.time()
if step_idx % args.save_step == 0 and step_idx != 0:
# validation: how to accumulate with Model loss
if args.validation_file:
total_avg_cost = 0
for idx, input_data in enumerate(val_loader()):
outputs, losses = transformer.eval(
input_data[:-2],
input_data[-2:],
device='gpu',
device_ids=device_ids)
losses = transformer.eval(input_data[:-2],
input_data[-2:])
total_avg_cost += np.sum(losses)
total_avg_cost /= idx + 1
logging.info("validation, step_idx: %d, avg loss: %f, "
......@@ -181,10 +202,9 @@ def do_train(args):
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
transformer.save(
os.path.join(args.save_model,
"step_" + str(step_idx),
"transformer"))
transformer.save(
os.path.join(args.save_model, "step_" + str(step_idx),
"transformer"))
batch_id += 1
step_idx += 1
......
......@@ -20,7 +20,7 @@ import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
from model import Model, shape_hints, CrossEntropy, Loss
from model import Model, CrossEntropy, Loss
def position_encoding_init(n_position, d_pos_vec):
......@@ -32,10 +32,10 @@ def position_encoding_init(n_position, d_pos_vec):
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(
np.arange(num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(
inv_timescales, 0)
inv_timescales = np.exp(np.arange(
num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
......@@ -46,6 +46,7 @@ class NoamDecay(LearningRateDecay):
"""
learning rate scheduler
"""
def __init__(self,
d_model,
warmup_steps,
......@@ -70,6 +71,7 @@ class PrePostProcessLayer(Layer):
"""
PrePostProcessLayer
"""
def __init__(self, process_cmd, d_model, dropout_rate):
super(PrePostProcessLayer, self).__init__()
self.process_cmd = process_cmd
......@@ -80,8 +82,8 @@ class PrePostProcessLayer(Layer):
elif cmd == "n": # add layer normalization
self.functors.append(
self.add_sublayer(
"layer_norm_%d" %
len(self.sublayers(include_sublayers=False)),
"layer_norm_%d" % len(
self.sublayers(include_sublayers=False)),
LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
......@@ -106,6 +108,7 @@ class MultiHeadAttention(Layer):
"""
Multi-Head Attention
"""
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
......@@ -113,18 +116,14 @@ class MultiHeadAttention(Layer):
self.d_value = d_value
self.d_model = d_model
self.dropout_rate = dropout_rate
self.q_fc = Linear(input_dim=d_model,
output_dim=d_key * n_head,
bias_attr=False)
self.k_fc = Linear(input_dim=d_model,
output_dim=d_key * n_head,
bias_attr=False)
self.v_fc = Linear(input_dim=d_model,
output_dim=d_value * n_head,
bias_attr=False)
self.proj_fc = Linear(input_dim=d_value * n_head,
output_dim=d_model,
bias_attr=False)
self.q_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
self.k_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
self.v_fc = Linear(
input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
self.proj_fc = Linear(
input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
def _prepare_qkv(self, queries, keys, values, cache=None):
if keys is None: # self-attention
......@@ -167,17 +166,14 @@ class MultiHeadAttention(Layer):
q, k, v = self._prepare_qkv(queries, keys, values, cache)
# scale dot product attention
product = layers.matmul(x=q,
y=k,
transpose_y=True,
alpha=self.d_model**-0.5)
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if self.dropout_rate:
weights = layers.dropout(weights,
dropout_prob=self.dropout_rate,
is_test=False)
weights = layers.dropout(
weights, dropout_prob=self.dropout_rate, is_test=False)
out = layers.matmul(weights, v)
......@@ -203,18 +199,19 @@ class FFN(Layer):
"""
Feed-Forward Network
"""
def __init__(self, d_inner_hid, d_model, dropout_rate):
super(FFN, self).__init__()
self.dropout_rate = dropout_rate
self.fc1 = Linear(input_dim=d_model, output_dim=d_inner_hid, act="relu")
self.fc1 = Linear(
input_dim=d_model, output_dim=d_inner_hid, act="relu")
self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
def forward(self, x):
hidden = self.fc1(x)
if self.dropout_rate:
hidden = layers.dropout(hidden,
dropout_prob=self.dropout_rate,
is_test=False)
hidden = layers.dropout(
hidden, dropout_prob=self.dropout_rate, is_test=False)
out = self.fc2(hidden)
return out
......@@ -223,6 +220,7 @@ class EncoderLayer(Layer):
"""
EncoderLayer
"""
def __init__(self,
n_head,
d_key,
......@@ -251,8 +249,8 @@ class EncoderLayer(Layer):
prepostprocess_dropout)
def forward(self, enc_input, attn_bias):
attn_output = self.self_attn(self.preprocesser1(enc_input), None, None,
attn_bias)
attn_output = self.self_attn(
self.preprocesser1(enc_input), None, None, attn_bias)
attn_output = self.postprocesser1(attn_output, enc_input)
ffn_output = self.ffn(self.preprocesser2(attn_output))
......@@ -264,6 +262,7 @@ class Encoder(Layer):
"""
encoder
"""
def __init__(self,
n_layer,
n_head,
......@@ -303,6 +302,7 @@ class Embedder(Layer):
"""
Word Embedding + Position Encoding
"""
def __init__(self, vocab_size, emb_dim, bos_idx=0):
super(Embedder, self).__init__()
......@@ -321,6 +321,7 @@ class WrapEncoder(Layer):
"""
embedder + encoder
"""
def __init__(self, src_vocab_size, max_length, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
......@@ -348,9 +349,9 @@ class WrapEncoder(Layer):
pos_enc = self.pos_encoder(src_pos)
pos_enc.stop_gradient = True
emb = word_emb + pos_enc
enc_input = layers.dropout(emb,
dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb
enc_input = layers.dropout(
emb, dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb
enc_output = self.encoder(enc_input, src_slf_attn_bias)
return enc_output
......@@ -360,6 +361,7 @@ class DecoderLayer(Layer):
"""
decoder
"""
def __init__(self,
n_head,
d_key,
......@@ -399,8 +401,8 @@ class DecoderLayer(Layer):
self_attn_bias,
cross_attn_bias,
cache=None):
self_attn_output = self.self_attn(self.preprocesser1(dec_input), None,
None, self_attn_bias, cache)
self_attn_output = self.self_attn(
self.preprocesser1(dec_input), None, None, self_attn_bias, cache)
self_attn_output = self.postprocesser1(self_attn_output, dec_input)
cross_attn_output = self.cross_attn(
......@@ -419,6 +421,7 @@ class Decoder(Layer):
"""
decoder
"""
def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd):
......@@ -444,8 +447,8 @@ class Decoder(Layer):
caches=None):
for i, decoder_layer in enumerate(self.decoder_layers):
dec_output = decoder_layer(dec_input, enc_output, self_attn_bias,
cross_attn_bias,
None if caches is None else caches[i])
cross_attn_bias, None
if caches is None else caches[i])
dec_input = dec_output
return self.processer(dec_output)
......@@ -463,6 +466,7 @@ class WrapDecoder(Layer):
"""
embedder + decoder
"""
def __init__(self, trg_vocab_size, max_length, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
......@@ -490,9 +494,8 @@ class WrapDecoder(Layer):
word_embedder.weight,
transpose_y=True)
else:
self.linear = Linear(input_dim=d_model,
output_dim=trg_vocab_size,
bias_attr=False)
self.linear = Linear(
input_dim=d_model, output_dim=trg_vocab_size, bias_attr=False)
def forward(self,
trg_word,
......@@ -506,53 +509,30 @@ class WrapDecoder(Layer):
pos_enc = self.pos_encoder(trg_pos)
pos_enc.stop_gradient = True
emb = word_emb + pos_enc
dec_input = layers.dropout(emb,
dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb
dec_input = layers.dropout(
emb, dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb
dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias,
trg_src_attn_bias, caches)
dec_output = layers.reshape(
dec_output,
shape=[-1, dec_output.shape[-1]],
)
shape=[-1, dec_output.shape[-1]], )
logits = self.linear(dec_output)
return logits
# class CrossEntropyCriterion(object):
# def __init__(self, label_smooth_eps):
# self.label_smooth_eps = label_smooth_eps
# def __call__(self, predict, label, weights):
# if self.label_smooth_eps:
# label_out = layers.label_smooth(label=layers.one_hot(
# input=label, depth=predict.shape[-1]),
# epsilon=self.label_smooth_eps)
# cost = layers.softmax_with_cross_entropy(
# logits=predict,
# label=label_out,
# soft_label=True if self.label_smooth_eps else False)
# weighted_cost = cost * weights
# sum_cost = layers.reduce_sum(weighted_cost)
# token_num = layers.reduce_sum(weights)
# token_num.stop_gradient = True
# avg_cost = sum_cost / token_num
# return sum_cost, avg_cost, token_num
class CrossEntropyCriterion(Loss):
def __init__(self, label_smooth_eps):
super(CrossEntropyCriterion, self).__init__()
self.label_smooth_eps = label_smooth_eps
def forward(self, outputs, labels):
predict = outputs[0]
label, weights = labels
predict, (label, weights) = outputs[0], labels
if self.label_smooth_eps:
label = layers.label_smooth(label=layers.one_hot(
input=label, depth=predict.shape[-1]),
epsilon=self.label_smooth_eps)
label = layers.label_smooth(
label=layers.one_hot(
input=label, depth=predict.shape[-1]),
epsilon=self.label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
......@@ -565,17 +545,12 @@ class CrossEntropyCriterion(Loss):
avg_cost = sum_cost / token_num
return avg_cost
def infer_shape(self, _):
return [[None, 1], [None, 1]]
def infer_dtype(self, _):
return ["int64", "float32"]
class Transformer(Model):
"""
model
"""
def __init__(self,
src_vocab_size,
trg_vocab_size,
......@@ -595,29 +570,25 @@ class Transformer(Model):
bos_id=0,
eos_id=1):
super(Transformer, self).__init__()
src_word_embedder = Embedder(vocab_size=src_vocab_size,
emb_dim=d_model,
bos_idx=bos_id)
self.encoder = WrapEncoder(src_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd, src_word_embedder)
src_word_embedder = Embedder(
vocab_size=src_vocab_size, emb_dim=d_model, bos_idx=bos_id)
self.encoder = WrapEncoder(
src_vocab_size, max_length, n_layer, n_head, d_key, d_value,
d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd, src_word_embedder)
if weight_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
trg_word_embedder = src_word_embedder
else:
trg_word_embedder = Embedder(vocab_size=trg_vocab_size,
emb_dim=d_model,
bos_idx=bos_id)
self.decoder = WrapDecoder(trg_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing,
trg_word_embedder)
trg_word_embedder = Embedder(
vocab_size=trg_vocab_size, emb_dim=d_model, bos_idx=bos_id)
self.decoder = WrapDecoder(
trg_vocab_size, max_length, n_layer, n_head, d_key, d_value,
d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd, weight_sharing,
trg_word_embedder)
self.trg_vocab_size = trg_vocab_size
self.n_layer = n_layer
......@@ -625,13 +596,6 @@ class Transformer(Model):
self.d_key = d_key
self.d_value = d_value
@shape_hints(src_word=[None, None],
src_pos=[None, None],
src_slf_attn_bias=[None, 8, None, None],
trg_word=[None, None],
trg_pos=[None, None],
trg_slf_attn_bias=[None, 8, None, None],
trg_src_attn_bias=[None, 8, None, None])
def forward(self, src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias):
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
......@@ -648,6 +612,7 @@ class TransfomerCell(object):
Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
used as RNNCell
"""
def __init__(self, decoder):
self.decoder = decoder
......@@ -666,6 +631,7 @@ class InferTransformer(Transformer):
"""
model for prediction
"""
def __init__(self,
src_vocab_size,
trg_vocab_size,
......@@ -693,29 +659,21 @@ class InferTransformer(Transformer):
super(InferTransformer, self).__init__(**args)
cell = TransfomerCell(self.decoder)
self.beam_search_decoder = DynamicDecode(
TransformerBeamSearchDecoder(cell,
bos_id,
eos_id,
beam_size,
var_dim_in_state=2), max_out_len)
TransformerBeamSearchDecoder(
cell, bos_id, eos_id, beam_size, var_dim_in_state=2),
max_out_len,
is_test=True)
@shape_hints(src_word=[None, None],
src_pos=[None, None],
src_slf_attn_bias=[None, 8, None, None],
trg_src_attn_bias=[None, 8, None, None])
def forward(self, src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias):
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
## init states (caches) for transformer, need to be updated according to selected beam
caches = [{
"k":
layers.fill_constant_batch_size_like(
"k": layers.fill_constant_batch_size_like(
input=enc_output,
shape=[-1, self.n_head, 0, self.d_key],
dtype=enc_output.dtype,
value=0),
"v":
layers.fill_constant_batch_size_like(
"v": layers.fill_constant_batch_size_like(
input=enc_output,
shape=[-1, self.n_head, 0, self.d_value],
dtype=enc_output.dtype,
......@@ -725,10 +683,10 @@ class InferTransformer(Transformer):
enc_output, self.beam_size)
trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
trg_src_attn_bias, self.beam_size)
static_caches = self.decoder.decoder.prepare_static_cache(
enc_output)
rs, _ = self.beam_search_decoder(inits=caches,
enc_output=enc_output,
trg_src_attn_bias=trg_src_attn_bias,
static_caches=static_caches)
static_caches = self.decoder.decoder.prepare_static_cache(enc_output)
rs, _ = self.beam_search_decoder(
inits=caches,
enc_output=enc_output,
trg_src_attn_bias=trg_src_attn_bias,
static_caches=static_caches)
return rs
......@@ -2,7 +2,6 @@
enable_ce: False
eager_run: False
num_devices: 1
# The frequency to save trained models when training.
save_step: 10000
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册