提交 02cc92c4 编写于 作者: G Guo Sheng 提交者: Cheerego

Rewrite machine_translation with attention model. (#740)

* Rewrite machine_translation with attention model.

* Replace formula in machine_translation with picture
上级 35d6107f
此差异已折叠。
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.layers as pd
from paddle.fluid.executor import Executor
import os
# Hyper-parameters shared by the encoder/decoder builders below.
# Vocabulary size; used for the embedding tables and the wmt14 dictionaries.
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
# Base layer size; the encoder uses hidden_dim * 4 for its fc and LSTM layers.
hidden_dim = 32
# Second dimension of the embedding tables.
word_dim = 32
# Number of test samples per batch, and of initial beam-search candidates.
batch_size = 2
# Upper bound on the number of decoding steps.
max_length = 8
# NOTE(review): topk_size is not referenced by any function in view.
topk_size = 50
# Passed as `k` to topk and as the beam size to beam_search/beam_search_decode.
beam_size = 2
# Passed as `is_sparse` to both embedding layers.
is_sparse = True
# Output size of the decoder's recurrent fc layer.
decoder_size = hidden_dim
# Directory the persistable variables are loaded from before decoding.
model_save_dir = "machine_translation.inference.model"
def encoder():
    """Build the source-sentence encoder and return its last-step output.

    The source word ids are embedded (table named ``vemb``), projected by a
    tanh fully-connected layer, and fed to a dynamic LSTM; the last step of
    the LSTM hidden sequence is returned.
    """
    source_ids = pd.data(
        name="src_word_id", shape=[1], dtype='int64', lod_level=1)
    embedded = pd.embedding(
        input=source_ids,
        size=[dict_size, word_dim],
        dtype='float32',
        is_sparse=is_sparse,
        param_attr=fluid.ParamAttr(name='vemb'))
    # The fc projection and the LSTM are both sized hidden_dim * 4.
    lstm_width = hidden_dim * 4
    projected = pd.fc(input=embedded, size=lstm_width, act='tanh')
    hidden_seq, _ = pd.dynamic_lstm(input=projected, size=lstm_width)
    return pd.sequence_last_step(input=hidden_seq)
def decode(context):
    """Build the beam-search decoding loop on top of the encoder output.

    Args:
        context: Encoder output; written to slot 0 of the state array as the
            initial decoder state.

    Returns:
        The ``(translation_ids, translation_scores)`` pair produced by
        ``pd.beam_search_decode`` from the ids and scores stored per step.
    """
    step_limit = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
    step = pd.zeros(shape=[1], dtype='int64', force_cpu=True)

    # Decoder state history; slot 0 holds the encoder context.
    states = pd.create_array('float32')
    pd.array_write(context, array=states, i=step)

    # Selected ids and their scores are kept per step as beam-search memory.
    ids_history = pd.create_array('int64')
    score_history = pd.create_array('float32')

    start_ids = pd.data(name="init_ids", shape=[1], dtype="int64", lod_level=2)
    start_scores = pd.data(
        name="init_scores", shape=[1], dtype="float32", lod_level=2)

    pd.array_write(start_ids, array=ids_history, i=step)
    pd.array_write(start_scores, array=score_history, i=step)

    keep_going = pd.less_than(x=step, y=step_limit)
    loop = pd.While(cond=keep_going)
    with loop.block():
        prev_ids = pd.array_read(array=ids_history, i=step)
        prev_state = pd.array_read(array=states, i=step)
        prev_score = pd.array_read(array=score_history, i=step)

        # Expand prev_state so that its LoD matches prev_score.
        prev_state_expanded = pd.sequence_expand(prev_state, prev_score)

        prev_ids_emb = pd.embedding(
            input=prev_ids,
            size=[dict_size, word_dim],
            dtype='float32',
            is_sparse=is_sparse,
            param_attr=fluid.ParamAttr(name='vemb'))

        # One recurrent step: combine the previous state with the embedding.
        new_state = pd.fc(
            input=[prev_state_expanded, prev_ids_emb],
            size=decoder_size,
            act='tanh')
        new_state_with_lod = pd.lod_reset(x=new_state, y=prev_score)

        # Softmax over the target vocabulary, then keep the beam_size best.
        word_probs = pd.fc(
            input=new_state_with_lod, size=target_dict_dim, act='softmax')
        topk_scores, topk_indices = pd.topk(word_probs, k=beam_size)

        # Accumulate scores only after top-k to reduce the computation cost.
        log_topk = pd.log(topk_scores)
        flat_prev_score = pd.reshape(prev_score, shape=[-1])
        accu_scores = pd.elementwise_add(
            x=log_topk, y=flat_prev_score, axis=0)

        selected_ids, selected_scores = pd.beam_search(
            prev_ids,
            prev_score,
            topk_indices,
            accu_scores,
            beam_size,
            end_id=10,
            level=0)

        with pd.Switch() as switch:
            with switch.case(pd.is_empty(selected_ids)):
                # Nothing was selected: clear the loop condition.
                pd.fill_constant(
                    shape=[1],
                    value=0,
                    dtype='bool',
                    force_cpu=True,
                    out=keep_going)
            with switch.default():
                pd.increment(x=step, value=1, in_place=True)

                # Store this step's state, ids and scores in the memories.
                pd.array_write(new_state, array=states, i=step)
                pd.array_write(selected_ids, array=ids_history, i=step)
                pd.array_write(selected_scores, array=score_history, i=step)

                # Continue only while below the step limit and while some
                # candidate is still selected.
                within_limit = pd.less_than(x=step, y=step_limit)
                not_finished = pd.logical_not(pd.is_empty(x=selected_ids))
                pd.logical_and(x=within_limit, y=not_finished, out=keep_going)

    translation_ids, translation_scores = pd.beam_search_decode(
        ids=ids_history, scores=score_history, beam_size=beam_size, end_id=10)
    return translation_ids, translation_scores
def decode_main(use_cuda):
    """Load the saved model and print beam-search translations of one batch.

    Builds the encoder/decoder graph in the default main program, loads the
    persistable variables from ``model_save_dir``, decodes the first batch of
    the shuffled wmt14 test reader, and prints the first source sentence with
    ``beam_size`` scored candidate translations.

    Args:
        use_cuda: Run on ``CUDAPlace(0)`` when true, otherwise on the CPU. The
            function returns immediately if CUDA is requested but Paddle was
            not compiled with it.
    """
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        return
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

    exe = Executor(place)
    exe.run(framework.default_startup_program())

    context = encoder()
    translation_ids, translation_scores = decode(context)
    fluid.io.load_persistables(executor=exe, dirname=model_save_dir)

    # Each source sentence starts with a single candidate (id 1, score 1.0)
    # wrapped in a two-level LoD of unit-length sequences.
    init_ids_data = np.ones((batch_size, 1), dtype='int64')
    init_scores_data = np.ones((batch_size, 1), dtype='float32')
    unit_lens = [1] * batch_size
    init_lod = [unit_lens, unit_lens]

    init_ids = fluid.create_lod_tensor(init_ids_data, init_lod, place)
    init_scores = fluid.create_lod_tensor(init_scores_data, init_lod, place)

    test_data = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.test(dict_size), buf_size=1000),
        batch_size=batch_size)

    feed_order = ['src_word_id']
    feed_list = [
        framework.default_main_program().global_block().var(var_name)
        for var_name in feed_order
    ]
    feeder = fluid.DataFeeder(feed_list, place)

    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)

    for data in test_data():
        # Keep only the source ids of each sample. A real list is required
        # because feed_data is indexed below; on Python 3 map() would return
        # a one-shot iterator that cannot be subscripted.
        feed_data = [[sample[0]] for sample in data]
        feed_dict = feeder.feed(feed_data)
        feed_dict['init_ids'] = init_ids
        feed_dict['init_scores'] = init_scores

        results = exe.run(
            framework.default_main_program(),
            feed=feed_dict,
            fetch_list=[translation_ids, translation_scores],
            return_numpy=False)

        result_ids = np.array(results[0])
        result_ids_lod = results[0].lod()
        result_scores = np.array(results[1])

        print("Original sentence:")
        # [1:-1] drops the first and last id of the source sentence.
        print(" ".join([src_dict[w] for w in feed_data[0][0][1:-1]]))
        print("Translated score and sentence:")
        # range() instead of xrange(): xrange does not exist on Python 3.
        for i in range(beam_size):
            # Skip the first id of each candidate span.
            start_pos = result_ids_lod[1][i] + 1
            end_pos = result_ids_lod[1][i + 1]
            print("%d\t%.4f\t%s\n" % (
                i + 1, result_scores[end_pos - 1],
                " ".join([trg_dict[w] for w in result_ids[start_pos:end_pos]])))

        # Only the first batch is decoded.
        break
def main(use_cuda):
    """Run decoding on the CPU; the ``use_cuda`` argument is ignored."""
    # Beam Search does not support CUDA
    decode_main(False)
if __name__ == '__main__':
    # WITH_GPU unset or equal to "0" selects the CPU; any other value asks
    # for the GPU.
    use_cuda = os.environ.get('WITH_GPU', '0') != '0'
    main(use_cuda)
......@@ -12,130 +12,315 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import six
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as pd
import os
import sys
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
# NOTE(review): this span comes from an unmarked diff hunk, so removed and
# added assignments are interleaved. hidden_dim, word_dim, batch_size,
# max_length, beam_size and decoder_size are each assigned twice; if the lines
# run in order, the later assignment is the one that takes effect.
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
hidden_dim = 32
word_dim = 32
batch_size = 2
max_length = 8
topk_size = 50
beam_size = 2
source_dict_size = target_dict_size = dict_size
word_dim = 512
hidden_dim = 512
decoder_size = hidden_dim
max_length = 256
beam_size = 4
batch_size = 64
is_sparse = True
decoder_size = hidden_dim
model_save_dir = "machine_translation.inference.model"
# Encoder definition.
# NOTE(review): removed and added diff lines are interleaved here (both a
# `pd.data(` and a `fluid.layers.data(` opener, two `size=` arguments, and two
# `return` statements), so this span is not valid Python as shown.
# The `fluid.layers` lines embed the source word ids, apply two separate
# bias-free fc projections of size hidden_dim * 3, run a forward and a reverse
# dynamic GRU over them, and return both GRU outputs concatenated on axis 1.
# The trailing `pd.` lines build an fc + dynamic LSTM encoder that returns the
# last LSTM step.
def encoder():
src_word_id = pd.data(
src_word_id = fluid.layers.data(
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
src_embedding = pd.embedding(
src_embedding = fluid.layers.embedding(
input=src_word_id,
size=[dict_size, word_dim],
size=[source_dict_size, word_dim],
dtype='float32',
is_sparse=is_sparse,
param_attr=fluid.ParamAttr(name='vemb'))
is_sparse=is_sparse)
fc_forward = fluid.layers.fc(
input=src_embedding, size=hidden_dim * 3, bias_attr=False)
src_forward = fluid.layers.dynamic_gru(input=fc_forward, size=hidden_dim)
fc_backward = fluid.layers.fc(
input=src_embedding, size=hidden_dim * 3, bias_attr=False)
src_backward = fluid.layers.dynamic_gru(
input=fc_backward, size=hidden_dim, is_reverse=True)
encoded_vector = fluid.layers.concat(
input=[src_forward, src_backward], axis=1)
return encoded_vector
fc1 = pd.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
lstm_hidden0, lstm_0 = pd.dynamic_lstm(input=fc1, size=hidden_dim * 4)
encoder_out = pd.sequence_last_step(input=lstm_hidden0)
return encoder_out
def cell(x, hidden, encoder_out, encoder_out_proj):
    """Compute one decoder step: attention over the encoder, then a GRU unit.

    Args:
        x: Input of the current step (callers pass a target-word embedding).
        hidden: Decoder hidden state of the previous step.
        encoder_out: Encoder output sequence that attention is applied to.
        encoder_out_proj: Projection of ``encoder_out`` used to score the
            attention weights.

    Returns:
        ``(out, out)``: the new hidden state, returned once as the step output
        and once as the state to carry forward.
    """

    def simple_attention(encoder_vec, encoder_proj, decoder_state):
        """Return the attention-weighted sum of ``encoder_vec``."""
        state_proj = fluid.layers.fc(
            input=decoder_state, size=decoder_size, bias_attr=False)
        # Expand the projected state against encoder_proj before adding.
        state_expanded = fluid.layers.sequence_expand(
            x=state_proj, y=encoder_proj)
        mixed = fluid.layers.elementwise_add(encoder_proj, state_expanded)
        # Size-1 fc gives one raw score per row; softmax is per sequence.
        raw_scores = fluid.layers.fc(input=mixed, size=1, bias_attr=False)
        weights = fluid.layers.sequence_softmax(input=raw_scores)
        flat_weights = fluid.layers.reshape(x=weights, shape=[-1])
        weighted = fluid.layers.elementwise_mul(
            x=encoder_vec, y=flat_weights, axis=0)
        return fluid.layers.sequence_pool(input=weighted, pool_type='sum')

    context = simple_attention(encoder_out, encoder_out_proj, hidden)
    gru_input = fluid.layers.fc(
        input=[x, context], size=decoder_size * 3, bias_attr=False)
    new_hidden = fluid.layers.gru_unit(
        input=gru_input, hidden=hidden, size=decoder_size * 3)[0]
    return new_hidden, new_hidden
# Training-time decoder.
# NOTE(review): this span interleaves removed and added diff lines (two
# `def train_decoder` headers, duplicated `rnn = ...` and `rnn.output(...)`
# lines), so it is not valid Python as shown.
# In the `fluid.layers` lines: the last encoder step is projected by a tanh fc
# and used to initialise the decoder memory, the encoder output gets a
# bias-free fc projection, and a DynamicRNN steps over the target-word
# embedding, calling `cell(...)` and emitting a softmax of size
# `target_dict_size` at every step.
def train_decoder(context):
trg_language_word = pd.data(
def train_decoder(encoder_out):
encoder_last = fluid.layers.sequence_last_step(input=encoder_out)
encoder_last_proj = fluid.layers.fc(
input=encoder_last, size=decoder_size, act='tanh')
# cache the encoder_out's computed result in attention
encoder_out_proj = fluid.layers.fc(
input=encoder_out, size=decoder_size, bias_attr=False)
trg_language_word = fluid.layers.data(
name="target_language_word", shape=[1], dtype='int64', lod_level=1)
trg_embedding = pd.embedding(
trg_embedding = fluid.layers.embedding(
input=trg_language_word,
size=[dict_size, word_dim],
size=[target_dict_size, word_dim],
dtype='float32',
is_sparse=is_sparse,
param_attr=fluid.ParamAttr(name='vemb'))
is_sparse=is_sparse)
rnn = pd.DynamicRNN()
rnn = fluid.layers.DynamicRNN()
with rnn.block():
current_word = rnn.step_input(trg_embedding)
pre_state = rnn.memory(init=context, need_reorder=True)
current_state = pd.fc(
input=[current_word, pre_state], size=decoder_size, act='tanh')
x = rnn.step_input(trg_embedding)
pre_state = rnn.memory(init=encoder_last_proj, need_reorder=True)
encoder_out = rnn.static_input(encoder_out)
encoder_out_proj = rnn.static_input(encoder_out_proj)
out, current_state = cell(x, pre_state, encoder_out, encoder_out_proj)
prob = fluid.layers.fc(input=out, size=target_dict_size, act='softmax')
current_score = pd.fc(
input=current_state, size=target_dict_dim, act='softmax')
rnn.update_memory(pre_state, current_state)
rnn.output(current_score)
rnn.output(prob)
return rnn()
# Training loss definition.
# NOTE(review): removed (`train_program`, `pd.`) and added (`train_model`,
# `fluid.layers`) diff lines are interleaved here.
# Both versions pass the encoder output to train_decoder, feed its result and
# the `target_language_next_word` labels to cross_entropy, and return the mean
# cost.
def train_program():
context = encoder()
rnn_out = train_decoder(context)
label = pd.data(
def train_model():
encoder_out = encoder()
rnn_out = train_decoder(encoder_out)
label = fluid.layers.data(
name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
cost = pd.cross_entropy(input=rnn_out, label=label)
avg_cost = pd.mean(cost)
cost = fluid.layers.cross_entropy(input=rnn_out, label=label)
avg_cost = fluid.layers.mean(cost)
return avg_cost
# Optimizer factory.
# NOTE(review): removed and added diff lines are interleaved (an Adagrad call
# with learning_rate=1e-4 and regularization_coeff=0.1 next to an Adam call).
# The `fluid.clip` / `noam_decay` / `Adam` lines install global-norm gradient
# clipping with clip_norm=5.0, build a learning-rate schedule via
# noam_decay(hidden_dim, 1000), and return Adam with an L2 decay
# regularization_coeff of 1e-4.
def optimizer_func():
return fluid.optimizer.Adagrad(
learning_rate=1e-4,
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(hidden_dim, 1000)
return fluid.optimizer.Adam(
learning_rate=lr_decay,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.1))
regularization_coeff=1e-4))
# Training entry point.
# NOTE(review): removed and added diff lines are interleaved (`EPOCH_NUM` is
# assigned 1 and later 20, `paddle.batch(` is opened twice, and both a wmt14
# and a wmt16 reader appear), so this span is not valid Python as shown.
# The added lines build the loss and optimizer inside fresh train/startup
# programs, return early if CUDA is requested but unavailable, batch a
# shuffled wmt16 training reader, run the startup program, then for each of
# EPOCH_NUM passes run every batch and print its loss. Parameters are saved to
# `model_save_dir`; whether that happens per pass or once at the end cannot be
# told from the flattened indentation.
def train(use_cuda):
EPOCH_NUM = 1
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
avg_cost = train_model()
optimizer = optimizer_func()
optimizer.minimize(avg_cost)
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
train_reader = paddle.batch(
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
paddle.dataset.wmt16.train(source_dict_size, target_dict_size),
buf_size=10000),
batch_size=batch_size)
feed_order = [
'src_word_id', 'target_language_word', 'target_language_next_word'
]
feeder = fluid.DataFeeder(
feed_list=[
'src_word_id', 'target_language_word', 'target_language_next_word'
],
place=place,
program=train_prog)
exe.run(startup_prog)
EPOCH_NUM = 20
for pass_id in six.moves.xrange(EPOCH_NUM):
batch_id = 0
for data in train_data():
cost = exe.run(
train_prog, feed=feeder.feed(data), fetch_list=[avg_cost])[0]
print('pass_id: %d, batch_id: %d, loss: %f' % (pass_id, batch_id,
cost))
batch_id += 1
fluid.io.save_params(exe, model_save_dir, main_program=train_prog)
# Inference-time beam-search decoder.
# NOTE(review): lines of a removed `event_handler` callback (the nested
# `def event_handler`, the two `isinstance(event, ...)` checks and
# `trainer.save_params`) are interleaved with the added decoder lines; they
# look like diff residue rather than part of the decoder.
# The decoder lines project the encoder output the same way train_decoder
# does, seed the ids/scores/state arrays at step 0, then loop while the step
# counter is below `max_length` and the selected ids are non-empty. Each step
# embeds the previous ids, runs `cell(...)`, takes the top `beam_size` softmax
# entries, adds their logs to the previous scores, runs beam_search with
# end_id=1, and writes the selected ids, scores and expanded state for the
# next step. beam_search_decode then assembles the returned ids and scores.
def infer_decoder(encoder_out):
encoder_last = fluid.layers.sequence_last_step(input=encoder_out)
encoder_last_proj = fluid.layers.fc(
input=encoder_last, size=decoder_size, act='tanh')
encoder_out_proj = fluid.layers.fc(
input=encoder_out, size=decoder_size, bias_attr=False)
max_len = fluid.layers.fill_constant(
shape=[1], dtype='int64', value=max_length)
counter = fluid.layers.zeros(shape=[1], dtype='int64', force_cpu=True)
def event_handler(event):
if isinstance(event, EndStepEvent):
if event.step % 10 == 0:
print('pass_id=' + str(event.epoch) + ' batch=' + str(
event.step))
init_ids = fluid.layers.data(
name="init_ids", shape=[1], dtype="int64", lod_level=2)
init_scores = fluid.layers.data(
name="init_scores", shape=[1], dtype="float32", lod_level=2)
# create and init arrays to save selected ids, scores and states for each step
ids_array = fluid.layers.array_write(init_ids, i=counter)
scores_array = fluid.layers.array_write(init_scores, i=counter)
state_array = fluid.layers.array_write(encoder_last_proj, i=counter)
if isinstance(event, EndEpochEvent):
trainer.save_params(model_save_dir)
cond = fluid.layers.less_than(x=counter, y=max_len)
while_op = fluid.layers.While(cond=cond)
with while_op.block():
pre_ids = fluid.layers.array_read(array=ids_array, i=counter)
pre_score = fluid.layers.array_read(array=scores_array, i=counter)
pre_state = fluid.layers.array_read(array=state_array, i=counter)
pre_ids_emb = fluid.layers.embedding(
input=pre_ids,
size=[target_dict_size, word_dim],
dtype='float32',
is_sparse=is_sparse)
out, current_state = cell(pre_ids_emb, pre_state, encoder_out,
encoder_out_proj)
prob = fluid.layers.fc(
input=current_state, size=target_dict_size, act='softmax')
# beam search
topk_scores, topk_indices = fluid.layers.topk(prob, k=beam_size)
accu_scores = fluid.layers.elementwise_add(
x=fluid.layers.log(topk_scores),
y=fluid.layers.reshape(pre_score, shape=[-1]),
axis=0)
accu_scores = fluid.layers.lod_reset(x=accu_scores, y=pre_ids)
selected_ids, selected_scores = fluid.layers.beam_search(
pre_ids, pre_score, topk_indices, accu_scores, beam_size, end_id=1)
fluid.layers.increment(x=counter, value=1, in_place=True)
# save selected ids and corresponding scores of each step
fluid.layers.array_write(selected_ids, array=ids_array, i=counter)
fluid.layers.array_write(selected_scores, array=scores_array, i=counter)
# update rnn state by sequence_expand acting as gather
current_state = fluid.layers.sequence_expand(current_state,
selected_ids)
fluid.layers.array_write(current_state, array=state_array, i=counter)
current_enc_out = fluid.layers.sequence_expand(encoder_out,
selected_ids)
fluid.layers.assign(current_enc_out, encoder_out)
current_enc_out_proj = fluid.layers.sequence_expand(encoder_out_proj,
selected_ids)
fluid.layers.assign(current_enc_out_proj, encoder_out_proj)
# update conditional variable
length_cond = fluid.layers.less_than(x=counter, y=max_len)
finish_cond = fluid.layers.logical_not(
fluid.layers.is_empty(x=selected_ids))
fluid.layers.logical_and(x=length_cond, y=finish_cond, out=cond)
translation_ids, translation_scores = fluid.layers.beam_search_decode(
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=1)
return translation_ids, translation_scores
def infer_model():
    """Assemble the inference network: the encoder followed by infer_decoder.

    Returns:
        The ``(translation_ids, translation_scores)`` pair from
        ``infer_decoder``.
    """
    ids, scores = infer_decoder(encoder())
    return ids, scores
# Inference entry point.
# NOTE(review): the removed `trainer = Trainer(...)` and `trainer.train(...)`
# lines are interleaved with the added code; they look like diff residue.
# The added lines build the inference program, load parameters from
# `model_save_dir`, and for every wmt16 test batch feed the source ids plus
# initial ids/scores, run the program, and print each source sentence followed
# by the score and text of every candidate translation.
# NOTE(review): `map(...)` is passed as `data` to create_lod_tensor; on
# Python 3 that is an iterator rather than a list -- confirm it is accepted.
# NOTE(review): `.encode('utf8')` makes print show a bytes repr on Python 3.
def infer(use_cuda):
infer_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
translation_ids, translation_scores = infer_model()
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
test_data = paddle.batch(
paddle.dataset.wmt16.test(source_dict_size, target_dict_size),
batch_size=batch_size)
src_idx2word = paddle.dataset.wmt16.get_dict(
"en", source_dict_size, reverse=True)
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", target_dict_size, reverse=True)
trainer = Trainer(
train_func=train_program, place=place, optimizer_func=optimizer_func)
fluid.io.load_params(exe, model_save_dir, main_program=infer_prog)
trainer.train(
reader=train_reader,
num_epochs=EPOCH_NUM,
event_handler=event_handler,
feed_order=feed_order)
for data in test_data():
src_word_id = fluid.create_lod_tensor(
data=map(lambda x: x[0], data),
recursive_seq_lens=[[len(x[0]) for x in data]],
place=place)
init_ids = fluid.create_lod_tensor(
data=np.array([[0]] * len(data), dtype='int64'),
recursive_seq_lens=[[1] * len(data)] * 2,
place=place)
init_scores = fluid.create_lod_tensor(
data=np.array([[0.]] * len(data), dtype='float32'),
recursive_seq_lens=[[1] * len(data)] * 2,
place=place)
seq_ids, seq_scores = exe.run(
infer_prog,
feed={
'src_word_id': src_word_id,
'init_ids': init_ids,
'init_scores': init_scores
},
fetch_list=[translation_ids, translation_scores],
return_numpy=False)
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps = [[] for i in range(len(seq_ids.lod()[0]) - 1)]
scores = [[] for i in range(len(seq_scores.lod()[0]) - 1)]
for i in range(len(seq_ids.lod()[0]) - 1): # for each source sentence
start = seq_ids.lod()[0][i]
end = seq_ids.lod()[0][i + 1]
print("Original sentence:")
print(" ".join([src_idx2word[idx] for idx in data[i][0][1:-1]]))
print("Translated score and sentence:")
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
trg_idx2word[idx]
for idx in np.array(seq_ids)[sub_start:sub_end][1:-1]
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(scores[i][-1], hyps[i][-1].encode('utf8'))
def main(use_cuda):
    """Train the model and then run inference, both with ``use_cuda``."""
    for stage in (train, infer):
        stage(use_cuda)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册