Commit 51ef7e15 authored by G guosheng

Update transformer to adapt to latest code

Parent a28a6e3b
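Note: the hunks below move the dygraph Transformer scripts onto the repo's high-level Model API: static `Input` specs replace the `@shape_hints` decorators, the optimizer, loss and specs are bound once through `prepare(...)`, and the per-batch `train` / `eval` / `test` calls drop the `device` / `device_ids` arguments. What follows is only a minimal sketch of that calling pattern, assuming this repo's `model.Input`, `Transformer` and `CrossEntropyCriterion`; the concrete hyper-parameter values are illustrative placeholders, not settings from this commit.

    import paddle.fluid as fluid
    from model import Input
    from transformer import Transformer, CrossEntropyCriterion

    with fluid.dygraph.guard():
        # static input/label specs (shapes and dtypes as declared in the training script below)
        inputs = [
            Input([None, None], "int64", name="src_word"),
            Input([None, None], "int64", name="src_pos"),
            Input([None, 8, None, None], "float32", name="src_slf_attn_bias"),
            Input([None, None], "int64", name="trg_word"),
            Input([None, None], "int64", name="trg_pos"),
            Input([None, 8, None, None], "float32", name="trg_slf_attn_bias"),
            Input([None, 8, None, None], "float32", name="trg_src_attn_bias"),
        ]
        labels = [
            Input([None, 1], "int64", name="label"),
            Input([None, 1], "float32", name="weight"),
        ]

        # constructor arguments in the same positional order as in the training script;
        # the values here are placeholders
        transformer = Transformer(10000, 10000, 256 + 1, 6, 8, 64, 64, 512, 2048,
                                  0.1, 0.1, 0.1, "n", "da", True)

        # bind optimizer, loss and the static specs once
        transformer.prepare(
            fluid.optimizer.Adam(
                learning_rate=0.001,
                parameter_list=transformer.parameters()),
            CrossEntropyCriterion(0.1),
            inputs=inputs,
            labels=labels)

        # per-batch calls no longer take device / device_ids:
        #   losses = transformer.train(input_data[:-2], input_data[-2:])
        #   losses = transformer.eval(input_data[:-2], input_data[-2:])
        # for inference, InferTransformer uses prepare(inputs=inputs) and test(...)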
@@ -30,6 +30,7 @@ from utils.check import check_gpu, check_version
 # include task-specific libs
 import reader
 from transformer import InferTransformer, position_encoding_init
+from model import Input


 def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,
@@ -50,8 +51,6 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False,

 def do_predict(args):
-    device_ids = list(range(args.num_devices))
-
     @contextlib.contextmanager
     def null_guard():
         yield
@@ -59,22 +58,23 @@ def do_predict(args):
     guard = fluid.dygraph.guard() if args.eager_run else null_guard()

     # define the data generator
-    processor = reader.DataProcessor(fpattern=args.predict_file,
-                                     src_vocab_fpath=args.src_vocab_fpath,
-                                     trg_vocab_fpath=args.trg_vocab_fpath,
-                                     token_delimiter=args.token_delimiter,
-                                     use_token_batch=False,
-                                     batch_size=args.batch_size,
-                                     device_count=1,
-                                     pool_size=args.pool_size,
-                                     sort_type=reader.SortType.NONE,
-                                     shuffle=False,
-                                     shuffle_batch=False,
-                                     start_mark=args.special_token[0],
-                                     end_mark=args.special_token[1],
-                                     unk_mark=args.special_token[2],
-                                     max_length=args.max_length,
-                                     n_head=args.n_head)
+    processor = reader.DataProcessor(
+        fpattern=args.predict_file,
+        src_vocab_fpath=args.src_vocab_fpath,
+        trg_vocab_fpath=args.trg_vocab_fpath,
+        token_delimiter=args.token_delimiter,
+        use_token_batch=False,
+        batch_size=args.batch_size,
+        device_count=1,
+        pool_size=args.pool_size,
+        sort_type=reader.SortType.NONE,
+        shuffle=False,
+        shuffle_batch=False,
+        start_mark=args.special_token[0],
+        end_mark=args.special_token[1],
+        unk_mark=args.special_token[2],
+        max_length=args.max_length,
+        n_head=args.n_head)
     batch_generator = processor.data_generator(phase="predict")
     args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
         args.unk_idx = processor.get_vocab_summary()
@@ -86,25 +86,38 @@ def do_predict(args):
         test_loader = batch_generator

         # define model
+        inputs = [
+            Input(
+                [None, None], "int64", name="src_word"), Input(
+                    [None, None], "int64", name="src_pos"), Input(
+                        [None, args.n_head, None, None],
+                        "float32",
+                        name="src_slf_attn_bias"), Input(
+                            [None, args.n_head, None, None],
+                            "float32",
+                            name="trg_src_attn_bias")
+        ]
-        transformer = InferTransformer(args.src_vocab_size,
-                                       args.trg_vocab_size,
-                                       args.max_length + 1,
-                                       args.n_layer,
-                                       args.n_head,
-                                       args.d_key,
-                                       args.d_value,
-                                       args.d_model,
-                                       args.d_inner_hid,
-                                       args.prepostprocess_dropout,
-                                       args.attention_dropout,
-                                       args.relu_dropout,
-                                       args.preprocess_cmd,
-                                       args.postprocess_cmd,
-                                       args.weight_sharing,
-                                       args.bos_idx,
-                                       args.eos_idx,
-                                       beam_size=args.beam_size,
-                                       max_out_len=args.max_out_len)
+        transformer = InferTransformer(
+            args.src_vocab_size,
+            args.trg_vocab_size,
+            args.max_length + 1,
+            args.n_layer,
+            args.n_head,
+            args.d_key,
+            args.d_value,
+            args.d_model,
+            args.d_inner_hid,
+            args.prepostprocess_dropout,
+            args.attention_dropout,
+            args.relu_dropout,
+            args.preprocess_cmd,
+            args.postprocess_cmd,
+            args.weight_sharing,
+            args.bos_idx,
+            args.eos_idx,
+            beam_size=args.beam_size,
+            max_out_len=args.max_out_len)
+        transformer.prepare(inputs=inputs)

         # load the trained model
         assert args.init_from_params, (
@@ -115,11 +128,8 @@ def do_predict(args):
         for input_data in test_loader():
             (src_word, src_pos, src_slf_attn_bias, trg_word,
              trg_src_attn_bias) = input_data
-            finished_seq = transformer.test(inputs=(src_word, src_pos,
-                                                    src_slf_attn_bias,
-                                                    trg_src_attn_bias),
-                                            device='gpu',
-                                            device_ids=device_ids)[0]
+            finished_seq = transformer.test(inputs=(
+                src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias))[0]
             finished_seq = np.transpose(finished_seq, [0, 2, 1])
             for ins in finished_seq:
                 for beam_idx, beam in enumerate(ins):
......
@@ -54,8 +54,7 @@ class RNNUnit(Layer):
             if sys.version_info < (3, ):
                 integer_types = (
                     int,
-                    long,
-                )
+                    long, )
             else:
                 integer_types = (int, )
             """For shape, list/tuple of integer is the finest-grained objection"""
@@ -67,8 +66,8 @@ class RNNUnit(Layer):
             # TODO: Add check for the illegal
             if isinstance(seq, dict):
                 return True
-            return (isinstance(seq, collections.Sequence)
-                    and not isinstance(seq, six.string_types))
+            return (isinstance(seq, collections.Sequence) and
+                    not isinstance(seq, six.string_types))

         class Shape(object):
             def __init__(self, shape):
@@ -174,6 +173,7 @@ class BasicLSTMUnit(RNNUnit):
         forget_bias(float|1.0): forget bias used when computing forget gate
         dtype(string): data type used in this unit
     """
+
     def __init__(self,
                  hidden_size,
                  input_size,
@@ -190,9 +190,8 @@ class BasicLSTMUnit(RNNUnit):
         self._bias_attr = bias_attr
         self._gate_activation = gate_activation or layers.sigmoid
         self._activation = activation or layers.tanh
-        self._forget_bias = layers.fill_constant([1],
-                                                 dtype=dtype,
-                                                 value=forget_bias)
+        self._forget_bias = layers.fill_constant(
+            [1], dtype=dtype, value=forget_bias)
         self._forget_bias.stop_gradient = False
         self._dtype = dtype
         self._input_size = input_size
@@ -204,10 +203,11 @@ class BasicLSTMUnit(RNNUnit):
             ],
             dtype=self._dtype)
-        self._bias = self.create_parameter(attr=self._bias_attr,
-                                           shape=[4 * self._hidden_size],
-                                           dtype=self._dtype,
-                                           is_bias=True)
+        self._bias = self.create_parameter(
+            attr=self._bias_attr,
+            shape=[4 * self._hidden_size],
+            dtype=self._dtype,
+            is_bias=True)

     def forward(self, input, state):
         pre_hidden, pre_cell = state
@@ -260,9 +260,8 @@ class RNN(fluid.dygraph.Layer):
                 # TODO: use where_op
                 new_state = fluid.layers.elementwise_mul(
                     new_state, step_mask,
-                    axis=0) - fluid.layers.elementwise_mul(state,
-                                                           (step_mask - 1),
-                                                           axis=0)
+                    axis=0) - fluid.layers.elementwise_mul(
+                        state, (step_mask - 1), axis=0)
                 return new_state

             flat_inputs = flatten(inputs)
@@ -300,7 +299,9 @@ class RNN(fluid.dygraph.Layer):
                                                      **kwargs)
                 if sequence_length:
                     new_states = map_structure(
-                        partial(_maybe_copy, step_mask=mask[i]), states,
-                        new_states)
+                        partial(
+                            _maybe_copy, step_mask=mask[i]),
+                        states,
+                        new_states)
                     states = new_states
                 outputs = map_structure(
@@ -347,10 +348,9 @@ class EncoderCell(RNNUnit):
         self.lstm_cells = list()
         for i in range(self.num_layers):
             self.lstm_cells.append(
-                self.add_sublayer(
-                    "layer_%d" % i,
-                    BasicLSTMUnit(input_size if i == 0 else hidden_size,
-                                  hidden_size)))
+                self.add_sublayer("layer_%d" % i,
+                                  BasicLSTMUnit(input_size if i == 0 else
+                                                hidden_size, hidden_size)))

     def forward(self, step_input, states):
         new_states = []
@@ -384,18 +384,14 @@ class MultiHeadAttention(Layer):
         self.d_value = d_value
         self.d_model = d_model
         self.dropout_rate = dropout_rate
-        self.q_fc = Linear(input_dim=d_model,
-                           output_dim=d_key * n_head,
-                           bias_attr=False)
-        self.k_fc = Linear(input_dim=d_model,
-                           output_dim=d_key * n_head,
-                           bias_attr=False)
-        self.v_fc = Linear(input_dim=d_model,
-                           output_dim=d_value * n_head,
-                           bias_attr=False)
-        self.proj_fc = Linear(input_dim=d_value * n_head,
-                              output_dim=d_model,
-                              bias_attr=False)
+        self.q_fc = Linear(
+            input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+        self.k_fc = Linear(
+            input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+        self.v_fc = Linear(
+            input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
+        self.proj_fc = Linear(
+            input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)

     def forward(self, queries, keys, values, attn_bias, cache=None):
         # compute q ,k ,v
@@ -421,17 +417,14 @@ class MultiHeadAttention(Layer):
             cache["k"], cache["v"] = k, v

         # scale dot product attention
-        product = layers.matmul(x=q,
-                                y=k,
-                                transpose_y=True,
-                                alpha=self.d_model**-0.5)
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
         if attn_bias:
             product += attn_bias
         weights = layers.softmax(product)
         if self.dropout_rate:
-            weights = layers.dropout(weights,
-                                     dropout_prob=self.dropout_rate,
-                                     is_test=False)
+            weights = layers.dropout(
+                weights, dropout_prob=self.dropout_rate, is_test=False)

         out = layers.matmul(weights, v)
@@ -497,14 +490,13 @@ class DynamicDecode(Layer):
             inputs, states, finished = (initial_inputs, initial_states,
                                         initial_finished)
             cond = layers.logical_not((layers.reduce_all(initial_finished)))
-            sequence_lengths = layers.cast(layers.zeros_like(initial_finished),
-                                           "int64")
+            sequence_lengths = layers.cast(
+                layers.zeros_like(initial_finished), "int64")
             outputs = None
             step_idx = 0
-            step_idx_tensor = layers.fill_constant(shape=[1],
-                                                   dtype="int64",
-                                                   value=step_idx)
+            step_idx_tensor = layers.fill_constant(
+                shape=[1], dtype="int64", value=step_idx)
             while cond.numpy():
                 (step_outputs, next_states, next_inputs,
                  next_finished) = self.decoder.step(step_idx_tensor, inputs,
@@ -512,8 +504,8 @@ class DynamicDecode(Layer):
                 next_finished = layers.logical_or(next_finished, finished)
                 next_sequence_lengths = layers.elementwise_add(
                     sequence_lengths,
-                    layers.cast(layers.logical_not(finished),
-                                sequence_lengths.dtype))
+                    layers.cast(
+                        layers.logical_not(finished), sequence_lengths.dtype))
                 if self.impute_finished:  # rectify the states for the finished.
                     next_states = map_structure(
@@ -570,6 +562,7 @@ class TransfomerCell(object):
     Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
     used as RNNCell
     """
+
     def __init__(self, decoder):
         self.decoder = decoder
@@ -593,20 +586,16 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
         self.var_dim_in_state = var_dim_in_state

     def _merge_batch_beams_with_var_dim(self, x):
-        if not hasattr(self, "batch_size"):
-            self.batch_size = layers.shape(x)[0]
-        if not hasattr(self, "batch_beam_size"):
-            self.batch_beam_size = self.batch_size * self.beam_size
         # init length of cache is 0, and it increases with decoding carrying on,
         # thus need to reshape elaborately
         var_dim_in_state = self.var_dim_in_state + 1  # count in beam dim
-        x = layers.transpose(
-            x,
-            list(range(var_dim_in_state, len(x.shape))) +
-            list(range(0, var_dim_in_state)))
-        x = layers.reshape(x, [0] * (len(x.shape) - var_dim_in_state) +
-                           [self.batch_beam_size] +
-                           list(x.shape[-var_dim_in_state + 2:]))
+        x = layers.transpose(x,
+                             list(range(var_dim_in_state, len(x.shape))) +
+                             list(range(0, var_dim_in_state)))
+        x = layers.reshape(
+            x, [0] * (len(x.shape) - var_dim_in_state
+                      ) + [self.batch_size * self.beam_size] +
+            [int(size) for size in x.shape[-var_dim_in_state + 2:]])
         x = layers.transpose(
             x,
             list(range((len(x.shape) + 1 - var_dim_in_state), len(x.shape))) +
@@ -616,8 +605,10 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
     def _split_batch_beams_with_var_dim(self, x):
         var_dim_size = layers.shape(x)[self.var_dim_in_state]
         x = layers.reshape(
-            x, [-1, self.beam_size] + list(x.shape[1:self.var_dim_in_state]) +
-            [var_dim_size] + list(x.shape[self.var_dim_in_state + 1:]))
+            x, [-1, self.beam_size] +
+            [int(size)
+             for size in x.shape[1:self.var_dim_in_state]] + [var_dim_size] +
+            [int(size) for size in x.shape[self.var_dim_in_state + 1:]])
         return x

     def step(self, time, inputs, states, **kwargs):
@@ -642,137 +633,3 @@ class TransformerBeamSearchDecoder(layers.BeamSearchDecoder):
                                       beam_search_state.finished)
         return (beam_search_output, beam_search_state, next_inputs, finished)
-'''
-@contextlib.contextmanager
-def eager_guard(is_eager):
-    if is_eager:
-        with fluid.dygraph.guard():
-            yield
-    else:
-        yield
-
-# print(flatten(np.random.rand(2,8,8)))
-random_seed = 123
-np.random.seed(random_seed)
-# print np.random.rand(2, 8)
-batch_size = 2
-seq_len = 8
-hidden_size = 8
-vocab_size, embed_dim, num_layers, hidden_size = 100, 8, 2, 8
-bos_id, eos_id, beam_size, max_step_num = 0, 1, 5, 10
-time_major = False
-eagar_run = False
-
-import torch
-with eager_guard(eagar_run):
-    fluid.default_main_program().random_seed = random_seed
-    fluid.default_startup_program().random_seed = random_seed
-    inputs_data = np.random.rand(batch_size, seq_len,
-                                 hidden_size).astype("float32")
-    states_data = np.random.rand(batch_size, hidden_size).astype("float32")
-    lstm_cell = BasicLSTMUnit(hidden_size=8, input_size=8)
-    lstm = RNN(cell=lstm_cell, time_major=time_major)
-    inputs = to_variable(inputs_data) if eagar_run else fluid.data(
-        name="x", shape=[None, None, hidden_size], dtype="float32")
-    states = lstm_cell.get_initial_states(batch_ref=inputs,
-                                          batch_dim_idx=1 if time_major else 0)
-    out, _ = lstm(inputs, states)
-    # print states
-    # print layers.BeamSearchDecoder.tile_beam_merge_with_batch(out, 5)
-    # embedder = Embedding(size=(vocab_size, embed_dim))
-    # output_layer = Linear(hidden_size, vocab_size)
-    # decoder = layers.BeamSearchDecoder(lstm_cell,
-    #                                    bos_id,
-    #                                    eos_id,
-    #                                    beam_size,
-    #                                    embedding_fn=embedder,
-    #                                    output_fn=output_layer)
-    # dynamic_decoder = DynamicDecode(decoder, max_step_num)
-    # out,_ = dynamic_decoder(inits=states)
-    # caches = [{
-    #     "k":
-    #     layers.fill_constant_batch_size_like(out,
-    #                                          shape=[-1, 8, 0, 64],
-    #                                          dtype="float32",
-    #                                          value=0),
-    #     "v":
-    #     layers.fill_constant_batch_size_like(out,
-    #                                          shape=[-1, 8, 0, 64],
-    #                                          dtype="float32",
-    #                                          value=0)
-    # } for i in range(6)]
-    cache = layers.fill_constant_batch_size_like(out,
-                                                 shape=[-1, 8, 0, 64],
-                                                 dtype="float32",
-                                                 value=0)
-    print cache
-    # out = layers.BeamSearchDecoder.tile_beam_merge_with_batch(cache, 5)
-    # out = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(cache, 5)
-    # batch_beam_size = layers.shape(out)[0] * 5
-    # print out
-    cell = TransfomerCell(None)
-    decoder = TransformerBeamSearchDecoder(cell, 0, 1, 5, 2)
-    cache = decoder._expand_to_beam_size(cache)
-    print cache
-    cache = decoder._merge_batch_beams_with_var_dim(cache)
-    print cache
-    cache1 = layers.fill_constant_batch_size_like(cache,
-                                                  shape=[-1, 8, 1, 64],
-                                                  dtype="float32",
-                                                  value=0)
-    print cache1.shape
-    cache = layers.concat([cache, cache1], axis=2)
-    out = decoder._split_batch_beams_with_var_dim(cache)
-    # out = layers.transpose(out,
-    #                        list(range(3, len(out.shape))) + list(range(0, 3)))
-    # print out
-    # out = layers.reshape(out, list(out.shape[:2]) + [batch_beam_size, 8])
-    # print out
-    # out = layers.transpose(out, [2,3,0,1])
-    print out.shape
-    if eagar_run:
-        print "hehe" #out #.numpy()
-    else:
-        executor.run(fluid.default_startup_program())
-        inputs = fluid.data(name="x",
-                            shape=[None, None, hidden_size],
-                            dtype="float32")
-        out_np = executor.run(feed={"x": inputs_data},
-                              fetch_list=[out.name])[0]
-        print np.array(out_np).shape
-exit(0)
-# dygraph
-# inputs = to_variable(inputs_data)
-# states = lstm_cell.get_initial_states(batch_ref=inputs,
-#                                       batch_dim_idx=1 if time_major else 0)
-# print lstm(inputs, states)[0].numpy()
-# graph
-executor.run(fluid.default_startup_program())
-inputs = fluid.data(name="x",
-                    shape=[None, None, hidden_size],
-                    dtype="float32")
-states = lstm_cell.get_initial_states(batch_ref=inputs,
-                                      batch_dim_idx=1 if time_major else 0)
-out, _ = lstm(inputs, states)
-out_np = executor.run(feed={"x": inputs_data}, fetch_list=[out.name])[0]
-print np.array(out_np)
-#print fluid.io.save_inference_model(dirname="test_model", feeded_var_names=["x"], target_vars=[out], executor=executor, model_filename="model.pdmodel", params_filename="params.pdparams")
-# test_program, feed_target_names, fetch_targets = fluid.io.load_inference_model(dirname="test_model", executor=executor, model_filename="model.pdmodel", params_filename="params.pdparams")
-# out = executor.run(program=test_program, feed={"x": np.random.rand(2, 8, 8).astype("float32")}, fetch_list=fetch_targets)[0]
-'''
\ No newline at end of file
@@ -31,10 +31,11 @@ from utils.check import check_gpu, check_version
 # include task-specific libs
 import reader
 from transformer import Transformer, CrossEntropyCriterion, NoamDecay
+from model import Input


 def do_train(args):
-    device_ids = list(range(args.num_devices))
+    trainer_count = 1  #get_nranks()

     @contextlib.contextmanager
     def null_guard():
@@ -43,23 +44,27 @@ def do_train(args):
     guard = fluid.dygraph.guard() if args.eager_run else null_guard()

     # define the data generator
-    processor = reader.DataProcessor(fpattern=args.training_file,
-                                     src_vocab_fpath=args.src_vocab_fpath,
-                                     trg_vocab_fpath=args.trg_vocab_fpath,
-                                     token_delimiter=args.token_delimiter,
-                                     use_token_batch=args.use_token_batch,
-                                     batch_size=args.batch_size,
-                                     device_count=args.num_devices,
-                                     pool_size=args.pool_size,
-                                     sort_type=args.sort_type,
-                                     shuffle=args.shuffle,
-                                     shuffle_batch=args.shuffle_batch,
-                                     start_mark=args.special_token[0],
-                                     end_mark=args.special_token[1],
-                                     unk_mark=args.special_token[2],
-                                     max_length=args.max_length,
-                                     n_head=args.n_head)
+    processor = reader.DataProcessor(
+        fpattern=args.training_file,
+        src_vocab_fpath=args.src_vocab_fpath,
+        trg_vocab_fpath=args.trg_vocab_fpath,
+        token_delimiter=args.token_delimiter,
+        use_token_batch=args.use_token_batch,
+        batch_size=args.batch_size,
+        device_count=trainer_count,
+        pool_size=args.pool_size,
+        sort_type=args.sort_type,
+        shuffle=args.shuffle,
+        shuffle_batch=args.shuffle_batch,
+        start_mark=args.special_token[0],
+        end_mark=args.special_token[1],
+        unk_mark=args.special_token[2],
+        max_length=args.max_length,
+        n_head=args.n_head)
     batch_generator = processor.data_generator(phase="train")
+    if trainer_count > 1:  # for multi-process gpu training
+        batch_generator = fluid.contrib.reader.distributed_batch_reader(
+            batch_generator)
     if args.validation_file:
         val_processor = reader.DataProcessor(
             fpattern=args.validation_file,
@@ -68,7 +73,7 @@ def do_train(args):
             token_delimiter=args.token_delimiter,
             use_token_batch=args.use_token_batch,
             batch_size=args.batch_size,
-            device_count=args.num_devices,
+            device_count=trainer_count,
             pool_size=args.pool_size,
             sort_type=args.sort_type,
             shuffle=False,
@@ -82,7 +87,6 @@ def do_train(args):
     args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
         args.unk_idx = processor.get_vocab_summary()

     with guard:
-
         # set seed for CE
         random_seed = eval(str(args.random_seed))
@@ -96,6 +100,28 @@ def do_train(args):
             val_loader = val_batch_generator

         # define model
+        inputs = [
+            Input(
+                [None, None], "int64", name="src_word"), Input(
+                    [None, None], "int64", name="src_pos"), Input(
+                        [None, args.n_head, None, None],
+                        "float32",
+                        name="src_slf_attn_bias"), Input(
+                            [None, None], "int64", name="trg_word"), Input(
+                                [None, None], "int64", name="trg_pos"), Input(
+                                    [None, args.n_head, None, None],
+                                    "float32",
+                                    name="trg_slf_attn_bias"), Input(
+                                        [None, args.n_head, None, None],
+                                        "float32",
+                                        name="trg_src_attn_bias")
+        ]
+        labels = [
+            Input(
+                [None, 1], "int64", name="label"), Input(
+                    [None, 1], "float32", name="weight")
+        ]
+
         transformer = Transformer(
             args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
             args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
@@ -112,7 +138,9 @@ def do_train(args):
                 beta2=args.beta2,
                 epsilon=float(args.eps),
                 parameter_list=transformer.parameters()),
-            CrossEntropyCriterion(args.label_smooth_eps))
+            CrossEntropyCriterion(args.label_smooth_eps),
+            inputs=inputs,
+            labels=labels)

         ## init from some checkpoint, to resume the previous training
         if args.init_from_checkpoint:
@@ -126,9 +154,8 @@ def do_train(args):
         # the best cross-entropy value with label smoothing
         loss_normalizer = -(
             (1. - args.label_smooth_eps) * np.log(
-                (1. - args.label_smooth_eps)) +
-            args.label_smooth_eps * np.log(args.label_smooth_eps /
-                                           (args.trg_vocab_size - 1) + 1e-20))
+                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
+            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

         step_idx = 0
         # train loop
@@ -136,10 +163,7 @@ def do_train(args):
             pass_start_time = time.time()
             batch_id = 0
             for input_data in train_loader():
-                outputs, losses = transformer.train(input_data[:-2],
-                                                    input_data[-2:],
-                                                    device='gpu',
-                                                    device_ids=device_ids)
+                losses = transformer.train(input_data[:-2], input_data[-2:])

                 if step_idx % args.print_step == 0:
                     total_avg_cost = np.sum(losses)
@@ -149,30 +173,27 @@ def do_train(args):
                             "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                             "normalized loss: %f, ppl: %f" %
                             (step_idx, pass_id, batch_id, total_avg_cost,
                              total_avg_cost - loss_normalizer,
                              np.exp([min(total_avg_cost, 100)])))
                         avg_batch_time = time.time()
                     else:
                         logging.info(
                             "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
-                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
+                            "normalized loss: %f, ppl: %f, speed: %.2f step/s"
+                            %
                             (step_idx, pass_id, batch_id, total_avg_cost,
                              total_avg_cost - loss_normalizer,
                              np.exp([min(total_avg_cost, 100)]),
                              args.print_step / (time.time() - avg_batch_time)))
                         avg_batch_time = time.time()

                 if step_idx % args.save_step == 0 and step_idx != 0:
                     # validation: how to accumulate with Model loss
                     if args.validation_file:
                         total_avg_cost = 0
                         for idx, input_data in enumerate(val_loader()):
-                            outputs, losses = transformer.eval(
-                                input_data[:-2],
-                                input_data[-2:],
-                                device='gpu',
-                                device_ids=device_ids)
+                            losses = transformer.eval(input_data[:-2],
+                                                      input_data[-2:])
                             total_avg_cost += np.sum(losses)
                         total_avg_cost /= idx + 1
                         logging.info("validation, step_idx: %d, avg loss: %f, "
@@ -181,10 +202,9 @@ def do_train(args):
                                      total_avg_cost - loss_normalizer,
                                      np.exp([min(total_avg_cost, 100)])))
                     transformer.save(
-                        os.path.join(args.save_model,
-                                     "step_" + str(step_idx),
-                                     "transformer"))
+                        os.path.join(args.save_model, "step_" + str(step_idx),
+                                     "transformer"))

                 batch_id += 1
                 step_idx += 1
......
@@ -20,7 +20,7 @@ import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
 from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
-from model import Model, shape_hints, CrossEntropy, Loss
+from model import Model, CrossEntropy, Loss


 def position_encoding_init(n_position, d_pos_vec):
@@ -32,10 +32,10 @@ def position_encoding_init(n_position, d_pos_vec):
     num_timescales = channels // 2
     log_timescale_increment = (np.log(float(1e4) / float(1)) /
                                (num_timescales - 1))
-    inv_timescales = np.exp(
-        np.arange(num_timescales)) * -log_timescale_increment
-    scaled_time = np.expand_dims(position, 1) * np.expand_dims(
-        inv_timescales, 0)
+    inv_timescales = np.exp(np.arange(
+        num_timescales)) * -log_timescale_increment
+    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
+                                                               0)
     signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
     signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
     position_enc = signal
@@ -46,6 +46,7 @@ class NoamDecay(LearningRateDecay):
     """
     learning rate scheduler
     """
+
     def __init__(self,
                  d_model,
                  warmup_steps,
@@ -70,6 +71,7 @@ class PrePostProcessLayer(Layer):
     """
     PrePostProcessLayer
     """
+
     def __init__(self, process_cmd, d_model, dropout_rate):
         super(PrePostProcessLayer, self).__init__()
         self.process_cmd = process_cmd
@@ -80,8 +82,8 @@ class PrePostProcessLayer(Layer):
             elif cmd == "n":  # add layer normalization
                 self.functors.append(
                     self.add_sublayer(
-                        "layer_norm_%d" %
-                        len(self.sublayers(include_sublayers=False)),
+                        "layer_norm_%d" % len(
+                            self.sublayers(include_sublayers=False)),
                         LayerNorm(
                             normalized_shape=d_model,
                             param_attr=fluid.ParamAttr(
@@ -106,6 +108,7 @@ class MultiHeadAttention(Layer):
     """
     Multi-Head Attention
     """
+
     def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
         super(MultiHeadAttention, self).__init__()
         self.n_head = n_head
@@ -113,18 +116,14 @@ class MultiHeadAttention(Layer):
         self.d_value = d_value
         self.d_model = d_model
         self.dropout_rate = dropout_rate
-        self.q_fc = Linear(input_dim=d_model,
-                           output_dim=d_key * n_head,
-                           bias_attr=False)
-        self.k_fc = Linear(input_dim=d_model,
-                           output_dim=d_key * n_head,
-                           bias_attr=False)
-        self.v_fc = Linear(input_dim=d_model,
-                           output_dim=d_value * n_head,
-                           bias_attr=False)
-        self.proj_fc = Linear(input_dim=d_value * n_head,
-                              output_dim=d_model,
-                              bias_attr=False)
+        self.q_fc = Linear(
+            input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+        self.k_fc = Linear(
+            input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+        self.v_fc = Linear(
+            input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
+        self.proj_fc = Linear(
+            input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)

     def _prepare_qkv(self, queries, keys, values, cache=None):
         if keys is None:  # self-attention
@@ -167,17 +166,14 @@ class MultiHeadAttention(Layer):
         q, k, v = self._prepare_qkv(queries, keys, values, cache)

         # scale dot product attention
-        product = layers.matmul(x=q,
-                                y=k,
-                                transpose_y=True,
-                                alpha=self.d_model**-0.5)
+        product = layers.matmul(
+            x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
         if attn_bias:
             product += attn_bias
         weights = layers.softmax(product)
         if self.dropout_rate:
-            weights = layers.dropout(weights,
-                                     dropout_prob=self.dropout_rate,
-                                     is_test=False)
+            weights = layers.dropout(
+                weights, dropout_prob=self.dropout_rate, is_test=False)

         out = layers.matmul(weights, v)
@@ -203,18 +199,19 @@ class FFN(Layer):
     """
     Feed-Forward Network
     """
+
     def __init__(self, d_inner_hid, d_model, dropout_rate):
         super(FFN, self).__init__()
         self.dropout_rate = dropout_rate
-        self.fc1 = Linear(input_dim=d_model, output_dim=d_inner_hid, act="relu")
+        self.fc1 = Linear(
+            input_dim=d_model, output_dim=d_inner_hid, act="relu")
         self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)

     def forward(self, x):
         hidden = self.fc1(x)
         if self.dropout_rate:
-            hidden = layers.dropout(hidden,
-                                    dropout_prob=self.dropout_rate,
-                                    is_test=False)
+            hidden = layers.dropout(
+                hidden, dropout_prob=self.dropout_rate, is_test=False)
         out = self.fc2(hidden)
         return out
@@ -223,6 +220,7 @@ class EncoderLayer(Layer):
     """
     EncoderLayer
     """
+
     def __init__(self,
                  n_head,
                  d_key,
@@ -251,8 +249,8 @@ class EncoderLayer(Layer):
                                                    prepostprocess_dropout)

     def forward(self, enc_input, attn_bias):
-        attn_output = self.self_attn(self.preprocesser1(enc_input), None, None,
-                                     attn_bias)
+        attn_output = self.self_attn(
+            self.preprocesser1(enc_input), None, None, attn_bias)
         attn_output = self.postprocesser1(attn_output, enc_input)

         ffn_output = self.ffn(self.preprocesser2(attn_output))
@@ -264,6 +262,7 @@ class Encoder(Layer):
     """
     encoder
     """
+
     def __init__(self,
                  n_layer,
                  n_head,
@@ -303,6 +302,7 @@ class Embedder(Layer):
     """
     Word Embedding + Position Encoding
     """
+
     def __init__(self, vocab_size, emb_dim, bos_idx=0):
         super(Embedder, self).__init__()
@@ -321,6 +321,7 @@ class WrapEncoder(Layer):
     """
     embedder + encoder
     """
+
     def __init__(self, src_vocab_size, max_length, n_layer, n_head, d_key,
                  d_value, d_model, d_inner_hid, prepostprocess_dropout,
                  attention_dropout, relu_dropout, preprocess_cmd,
@@ -348,9 +349,9 @@ class WrapEncoder(Layer):
         pos_enc = self.pos_encoder(src_pos)
         pos_enc.stop_gradient = True
         emb = word_emb + pos_enc
-        enc_input = layers.dropout(emb,
-                                   dropout_prob=self.emb_dropout,
-                                   is_test=False) if self.emb_dropout else emb
+        enc_input = layers.dropout(
+            emb, dropout_prob=self.emb_dropout,
+            is_test=False) if self.emb_dropout else emb
         enc_output = self.encoder(enc_input, src_slf_attn_bias)
         return enc_output
@@ -360,6 +361,7 @@ class DecoderLayer(Layer):
     """
     decoder
     """
+
     def __init__(self,
                  n_head,
                  d_key,
@@ -399,8 +401,8 @@ class DecoderLayer(Layer):
                 self_attn_bias,
                 cross_attn_bias,
                 cache=None):
-        self_attn_output = self.self_attn(self.preprocesser1(dec_input), None,
-                                          None, self_attn_bias, cache)
+        self_attn_output = self.self_attn(
+            self.preprocesser1(dec_input), None, None, self_attn_bias, cache)
         self_attn_output = self.postprocesser1(self_attn_output, dec_input)

         cross_attn_output = self.cross_attn(
@@ -419,6 +421,7 @@ class Decoder(Layer):
     """
     decoder
     """
+
     def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
                  prepostprocess_dropout, attention_dropout, relu_dropout,
                  preprocess_cmd, postprocess_cmd):
@@ -444,8 +447,8 @@ class Decoder(Layer):
                 caches=None):
         for i, decoder_layer in enumerate(self.decoder_layers):
             dec_output = decoder_layer(dec_input, enc_output, self_attn_bias,
-                                       cross_attn_bias,
-                                       None if caches is None else caches[i])
+                                       cross_attn_bias, None
+                                       if caches is None else caches[i])
             dec_input = dec_output

         return self.processer(dec_output)
@@ -463,6 +466,7 @@ class WrapDecoder(Layer):
     """
     embedder + decoder
     """
+
     def __init__(self, trg_vocab_size, max_length, n_layer, n_head, d_key,
                  d_value, d_model, d_inner_hid, prepostprocess_dropout,
                  attention_dropout, relu_dropout, preprocess_cmd,
@@ -490,9 +494,8 @@ class WrapDecoder(Layer):
                 word_embedder.weight,
                 transpose_y=True)
         else:
-            self.linear = Linear(input_dim=d_model,
-                                 output_dim=trg_vocab_size,
-                                 bias_attr=False)
+            self.linear = Linear(
+                input_dim=d_model, output_dim=trg_vocab_size, bias_attr=False)

     def forward(self,
                 trg_word,
@@ -506,53 +509,30 @@ class WrapDecoder(Layer):
         pos_enc = self.pos_encoder(trg_pos)
         pos_enc.stop_gradient = True
         emb = word_emb + pos_enc
-        dec_input = layers.dropout(emb,
-                                   dropout_prob=self.emb_dropout,
-                                   is_test=False) if self.emb_dropout else emb
+        dec_input = layers.dropout(
+            emb, dropout_prob=self.emb_dropout,
+            is_test=False) if self.emb_dropout else emb
         dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias,
                                   trg_src_attn_bias, caches)
         dec_output = layers.reshape(
             dec_output,
-            shape=[-1, dec_output.shape[-1]],
-        )
+            shape=[-1, dec_output.shape[-1]], )
         logits = self.linear(dec_output)
         return logits


-# class CrossEntropyCriterion(object):
-#     def __init__(self, label_smooth_eps):
-#         self.label_smooth_eps = label_smooth_eps
-
-#     def __call__(self, predict, label, weights):
-#         if self.label_smooth_eps:
-#             label_out = layers.label_smooth(label=layers.one_hot(
-#                 input=label, depth=predict.shape[-1]),
-#                                             epsilon=self.label_smooth_eps)
-
-#         cost = layers.softmax_with_cross_entropy(
-#             logits=predict,
-#             label=label_out,
-#             soft_label=True if self.label_smooth_eps else False)
-#         weighted_cost = cost * weights
-#         sum_cost = layers.reduce_sum(weighted_cost)
-#         token_num = layers.reduce_sum(weights)
-#         token_num.stop_gradient = True
-#         avg_cost = sum_cost / token_num
-#         return sum_cost, avg_cost, token_num
-
-
 class CrossEntropyCriterion(Loss):
     def __init__(self, label_smooth_eps):
         super(CrossEntropyCriterion, self).__init__()
         self.label_smooth_eps = label_smooth_eps

     def forward(self, outputs, labels):
-        predict = outputs[0]
-        label, weights = labels
+        predict, (label, weights) = outputs[0], labels
         if self.label_smooth_eps:
-            label = layers.label_smooth(label=layers.one_hot(
-                input=label, depth=predict.shape[-1]),
-                                        epsilon=self.label_smooth_eps)
+            label = layers.label_smooth(
+                label=layers.one_hot(
+                    input=label, depth=predict.shape[-1]),
+                epsilon=self.label_smooth_eps)

         cost = layers.softmax_with_cross_entropy(
             logits=predict,
@@ -565,17 +545,12 @@ class CrossEntropyCriterion(Loss):
         avg_cost = sum_cost / token_num
         return avg_cost

-    def infer_shape(self, _):
-        return [[None, 1], [None, 1]]
-
-    def infer_dtype(self, _):
-        return ["int64", "float32"]
-

 class Transformer(Model):
     """
     model
     """
+
     def __init__(self,
                  src_vocab_size,
                  trg_vocab_size,
@@ -595,29 +570,25 @@ class Transformer(Model):
                  bos_id=0,
                  eos_id=1):
         super(Transformer, self).__init__()
-        src_word_embedder = Embedder(vocab_size=src_vocab_size,
-                                     emb_dim=d_model,
-                                     bos_idx=bos_id)
-        self.encoder = WrapEncoder(src_vocab_size, max_length, n_layer, n_head,
-                                   d_key, d_value, d_model, d_inner_hid,
-                                   prepostprocess_dropout, attention_dropout,
-                                   relu_dropout, preprocess_cmd,
-                                   postprocess_cmd, src_word_embedder)
+        src_word_embedder = Embedder(
+            vocab_size=src_vocab_size, emb_dim=d_model, bos_idx=bos_id)
+        self.encoder = WrapEncoder(
+            src_vocab_size, max_length, n_layer, n_head, d_key, d_value,
+            d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
+            relu_dropout, preprocess_cmd, postprocess_cmd, src_word_embedder)
         if weight_sharing:
             assert src_vocab_size == trg_vocab_size, (
                 "Vocabularies in source and target should be same for weight sharing."
             )
             trg_word_embedder = src_word_embedder
         else:
-            trg_word_embedder = Embedder(vocab_size=trg_vocab_size,
-                                         emb_dim=d_model,
-                                         bos_idx=bos_id)
-        self.decoder = WrapDecoder(trg_vocab_size, max_length, n_layer, n_head,
-                                   d_key, d_value, d_model, d_inner_hid,
-                                   prepostprocess_dropout, attention_dropout,
-                                   relu_dropout, preprocess_cmd,
-                                   postprocess_cmd, weight_sharing,
-                                   trg_word_embedder)
+            trg_word_embedder = Embedder(
+                vocab_size=trg_vocab_size, emb_dim=d_model, bos_idx=bos_id)
+        self.decoder = WrapDecoder(
+            trg_vocab_size, max_length, n_layer, n_head, d_key, d_value,
+            d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
+            relu_dropout, preprocess_cmd, postprocess_cmd, weight_sharing,
+            trg_word_embedder)

         self.trg_vocab_size = trg_vocab_size
         self.n_layer = n_layer
@@ -625,13 +596,6 @@ class Transformer(Model):
         self.d_key = d_key
         self.d_value = d_value

-    @shape_hints(src_word=[None, None],
-                 src_pos=[None, None],
-                 src_slf_attn_bias=[None, 8, None, None],
-                 trg_word=[None, None],
-                 trg_pos=[None, None],
-                 trg_slf_attn_bias=[None, 8, None, None],
-                 trg_src_attn_bias=[None, 8, None, None])
     def forward(self, src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
                 trg_slf_attn_bias, trg_src_attn_bias):
         enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
@@ -648,6 +612,7 @@ class TransfomerCell(object):
     Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
     used as RNNCell
     """
+
     def __init__(self, decoder):
         self.decoder = decoder
@@ -666,6 +631,7 @@ class InferTransformer(Transformer):
     """
     model for prediction
     """
+
     def __init__(self,
                  src_vocab_size,
                  trg_vocab_size,
@@ -693,29 +659,21 @@ class InferTransformer(Transformer):
         super(InferTransformer, self).__init__(**args)
         cell = TransfomerCell(self.decoder)
         self.beam_search_decoder = DynamicDecode(
-            TransformerBeamSearchDecoder(cell,
-                                         bos_id,
-                                         eos_id,
-                                         beam_size,
-                                         var_dim_in_state=2), max_out_len)
+            TransformerBeamSearchDecoder(
+                cell, bos_id, eos_id, beam_size, var_dim_in_state=2),
+            max_out_len,
+            is_test=True)

-    @shape_hints(src_word=[None, None],
-                 src_pos=[None, None],
-                 src_slf_attn_bias=[None, 8, None, None],
-                 trg_src_attn_bias=[None, 8, None, None])
     def forward(self, src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias):
         enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
         ## init states (caches) for transformer, need to be updated according to selected beam
         caches = [{
-            "k":
-            layers.fill_constant_batch_size_like(
+            "k": layers.fill_constant_batch_size_like(
                 input=enc_output,
                 shape=[-1, self.n_head, 0, self.d_key],
                 dtype=enc_output.dtype,
                 value=0),
-            "v":
-            layers.fill_constant_batch_size_like(
+            "v": layers.fill_constant_batch_size_like(
                 input=enc_output,
                 shape=[-1, self.n_head, 0, self.d_value],
                 dtype=enc_output.dtype,
                 value=0),
@@ -725,10 +683,10 @@ class InferTransformer(Transformer):
             enc_output, self.beam_size)
         trg_src_attn_bias = TransformerBeamSearchDecoder.tile_beam_merge_with_batch(
             trg_src_attn_bias, self.beam_size)
-        static_caches = self.decoder.decoder.prepare_static_cache(
-            enc_output)
-        rs, _ = self.beam_search_decoder(inits=caches,
-                                         enc_output=enc_output,
-                                         trg_src_attn_bias=trg_src_attn_bias,
-                                         static_caches=static_caches)
+        static_caches = self.decoder.decoder.prepare_static_cache(enc_output)
+        rs, _ = self.beam_search_decoder(
+            inits=caches,
+            enc_output=enc_output,
+            trg_src_attn_bias=trg_src_attn_bias,
+            static_caches=static_caches)
         return rs
@@ -2,7 +2,6 @@
 enable_ce: False
 eager_run: False
-num_devices: 1

 # The frequency to save trained models when training.
 save_step: 10000
......