Unverified commit 84738545 authored by Aurelius84, committed by GitHub

Add dygraph_to_static training unit test of transformer model (#23316)

Parent 420944e5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import time
import os
import unittest
import paddle.fluid as fluid
import transformer_util as util
from transformer_dygraph_model import Transformer
from transformer_dygraph_model import CrossEntropyCriterion
trainer_count = 1
place = (fluid.CUDAPlace(0)
         if fluid.is_compiled_with_cuda() else fluid.CPUPlace())
SEED = 10
def train_static(args, batch_generator):
train_prog = fluid.default_main_program()
startup_prog = fluid.default_startup_program()
train_prog.random_seed = SEED
startup_prog.random_seed = SEED
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
# define input and reader
input_field_names = util.encoder_data_input_fields + \
util.decoder_data_input_fields[:-1] + util.label_data_input_fields
input_descs = util.get_input_descs(args)
input_slots = [{
"name": name,
"shape": input_descs[name][0],
"dtype": input_descs[name][1]
} for name in input_field_names]
input_field = util.InputField(input_slots)
# Define DataLoader
data_loader = fluid.io.DataLoader.from_generator(
input_field.feed_list, capacity=60)
data_loader.set_batch_generator(batch_generator, places=place)
# define model
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value,
args.d_model, args.d_inner_hid, args.prepostprocess_dropout,
args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
args.postprocess_cmd, args.weight_sharing, args.bos_idx,
args.eos_idx)
logits = transformer(*input_field.feed_list[:7])
# define loss
criterion = CrossEntropyCriterion(args.label_smooth_eps)
lbl_word, lbl_weight = input_field.feed_list[7:]
sum_cost, avg_cost, token_num = criterion(logits, lbl_word,
lbl_weight)
# define optimizer
learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
args.d_model, args.warmup_steps, args.learning_rate)
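            # Sketch of the schedule produced above (the usual Noam formulation,
            # stated here as an assumption about noam_decay rather than read
            # from this file):
            #   lr(step) = learning_rate * d_model**-0.5
            #              * min(step**-0.5, step * warmup_steps**-1.5)
            # i.e. linear warm-up for warmup_steps steps, then
            # inverse-square-root decay.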
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate,
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps))
optimizer.minimize(avg_cost)
# the best cross-entropy value with label smoothing
loss_normalizer = -((1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps * np.log(
args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
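    # Worked example of the formula above with label_smooth_eps=0.1 and
    # trg_vocab_size=10000 (the ModelHyperParams defaults): the smoothed target
    # puts 0.9 on the gold token and 0.1/9999 on each other token, so
    #   loss_normalizer = -(0.9*log(0.9) + 0.1*log(0.1/9999 + 1e-20)) ≈ 1.246
    # This is the value subtracted from the average loss when reporting the
    # "normalized loss" below.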
step_idx = 0
total_batch_num = 0
avg_loss = []
exe = fluid.Executor(place)
exe.run(startup_prog)
for pass_id in range(args.epoch):
batch_id = 0
for feed_dict in data_loader:
outs = exe.run(program=train_prog,
feed=feed_dict,
fetch_list=[sum_cost.name, token_num.name])
if step_idx % args.print_step == 0:
                sum_cost_val = np.array(outs[0])
                token_num_val = np.array(outs[1])
total_sum_cost = sum_cost_val.sum()
total_token_num = token_num_val.sum()
total_avg_cost = total_sum_cost / total_token_num
avg_loss.append(total_avg_cost)
if step_idx == 0:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
avg_batch_time = time.time()
batch_id += 1
step_idx += 1
total_batch_num = total_batch_num + 1
if step_idx == 10:
if args.save_model:
model_path = os.path.join(
args.save_model, "step_" + str(step_idx), "transformer")
fluid.save(train_prog, model_path)
break
return np.array(avg_loss)
def train_dygraph(args, batch_generator):
with fluid.dygraph.guard(place):
if SEED is not None:
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
# define data loader
train_loader = fluid.io.DataLoader.from_generator(capacity=10)
train_loader.set_batch_generator(batch_generator, places=place)
# define model
transformer = Transformer(
args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
args.d_inner_hid, args.prepostprocess_dropout,
args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
args.postprocess_cmd, args.weight_sharing, args.bos_idx,
args.eos_idx)
# define loss
criterion = CrossEntropyCriterion(args.label_smooth_eps)
        # define learning rate scheduler
learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
args.d_model, args.warmup_steps, args.learning_rate)
# define optimizer
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate,
beta1=args.beta1,
beta2=args.beta2,
epsilon=float(args.eps),
parameter_list=transformer.parameters())
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
ce_time = []
ce_ppl = []
avg_loss = []
step_idx = 0
for pass_id in range(args.epoch):
pass_start_time = time.time()
batch_id = 0
for input_data in train_loader():
(src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
lbl_weight) = input_data
logits = transformer(src_word, src_pos, src_slf_attn_bias,
trg_word, trg_pos, trg_slf_attn_bias,
trg_src_attn_bias)
sum_cost, avg_cost, token_num = criterion(logits, lbl_word,
lbl_weight)
avg_cost.backward()
optimizer.minimize(avg_cost)
transformer.clear_gradients()
if step_idx % args.print_step == 0:
total_avg_cost = avg_cost.numpy() * trainer_count
avg_loss.append(total_avg_cost[0])
if step_idx == 0:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)])))
avg_batch_time = time.time()
else:
logging.info(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" %
(step_idx, pass_id, batch_id, total_avg_cost,
total_avg_cost - loss_normalizer,
np.exp([min(total_avg_cost, 100)]),
args.print_step / (time.time() - avg_batch_time)))
ce_ppl.append(np.exp([min(total_avg_cost, 100)]))
avg_batch_time = time.time()
batch_id += 1
step_idx += 1
if step_idx == 10:
if args.save_model:
model_dir = os.path.join(args.save_model + '_dygraph',
"step_" + str(step_idx))
if not os.path.exists(model_dir):
os.makedirs(model_dir)
fluid.save_dygraph(
transformer.state_dict(),
os.path.join(model_dir, "transformer"))
fluid.save_dygraph(
optimizer.state_dict(),
os.path.join(model_dir, "transformer"))
break
time_consumed = time.time() - pass_start_time
ce_time.append(time_consumed)
return np.array(avg_loss)
class TestTransformer(unittest.TestCase):
def prepare(self, mode='train'):
args = util.ModelHyperParams()
batch_generator = util.get_feed_data_reader(args, mode)
return args, batch_generator
def test_train(self):
args, batch_generator = self.prepare(mode='train')
static_avg_loss = train_static(args, batch_generator)
dygraph_avg_loss = train_dygraph(args, batch_generator)
self.assertTrue(np.allclose(static_avg_loss, dygraph_avg_loss))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer
from paddle.fluid.dygraph.jit import dygraph_to_static_func
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels = d_pos_vec
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(np.arange(
num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
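# Illustrative sanity check (a sketch, not used by the model code below):
#   table = position_encoding_init(n_position=4, d_pos_vec=8)
#   table.shape    # (4, 8): one row per position, one column per channel
#   table[0]       # position 0: the sine half is all 0, the cosine half all 1
# The first d_pos_vec // 2 columns carry the sine components and the remaining
# columns the cosine components, one pair per timescale.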
class PrePostProcessLayer(Layer):
def __init__(self, process_cmd, d_model, dropout_rate):
super(PrePostProcessLayer, self).__init__()
self.process_cmd = process_cmd
self.functors = []
for cmd in self.process_cmd:
if cmd == "a": # add residual connection
self.functors.append(lambda x, y: x + y if y else x)
elif cmd == "n": # add layer normalization
self.functors.append(
self.add_sublayer(
"layer_norm_%d" % len(
self.sublayers(include_sublayers=False)),
LayerNorm(
normalized_shape=d_model,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))))
elif cmd == "d": # add dropout
if dropout_rate:
self.functors.append(lambda x: layers.dropout(
x, dropout_prob=dropout_rate, is_test=False))
@dygraph_to_static_func
def forward(self, x, residual=None):
for i, cmd in enumerate(self.process_cmd):
if cmd == "a":
x = self.functors[i](x, residual)
else:
x = self.functors[i](x)
return x
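# process_cmd is a tiny per-character DSL interpreted in forward(): "a" adds
# the residual, "n" applies LayerNorm, "d" applies dropout. A hedged usage
# sketch matching the commands passed elsewhere in this file:
#   PrePostProcessLayer("n", d_model, 0.1)    # pre-process: LayerNorm only
#   PrePostProcessLayer("da", d_model, 0.1)   # post-process: dropout, then residual add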
class MultiHeadAttention(Layer):
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_key = d_key
self.d_value = d_value
self.d_model = d_model
self.dropout_rate = dropout_rate
self.q_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
self.k_fc = Linear(
input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
self.v_fc = Linear(
input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
self.proj_fc = Linear(
input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
@dygraph_to_static_func
def forward(self, queries, keys, values, attn_bias, cache=None):
        # compute q, k, v
keys = queries if keys is None else keys
values = keys if values is None else values
q = self.q_fc(queries)
k = self.k_fc(keys)
v = self.v_fc(values)
# split head
q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
q = layers.transpose(x=q, perm=[0, 2, 1, 3])
k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
k = layers.transpose(x=k, perm=[0, 2, 1, 3])
v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
v = layers.transpose(x=v, perm=[0, 2, 1, 3])
if cache is not None:
cache_k, cache_v = cache["k"], cache["v"]
k = layers.concat([cache_k, k], axis=2)
v = layers.concat([cache_v, v], axis=2)
cache["k"], cache["v"] = k, v
        # scaled dot-product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if self.dropout_rate:
weights = layers.dropout(
weights, dropout_prob=self.dropout_rate, is_test=False)
out = layers.matmul(weights, v)
out = layers.transpose(out, perm=[0, 2, 1, 3])
out = layers.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
out = self.proj_fc(out)
return out
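# Shape walk-through of MultiHeadAttention.forward (illustrative; assumes batch
# size B, sequence length T and the ModelHyperParams defaults n_head=8,
# d_key=d_value=64, d_model=512):
#   queries [B, T, 512] -> q_fc -> [B, T, 512] -> split heads -> [B, 8, T, 64]
#   product = q @ k^T scaled by d_model**-0.5        -> [B, 8, T, T]
#   out = softmax(product + attn_bias) @ v           -> [B, 8, T, 64]
#   merge heads -> [B, T, 512] -> proj_fc -> [B, T, 512]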
class FFN(Layer):
def __init__(self, d_inner_hid, d_model, dropout_rate):
super(FFN, self).__init__()
self.dropout_rate = dropout_rate
self.fc1 = Linear(input_dim=d_model, output_dim=d_inner_hid, act="relu")
self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
@dygraph_to_static_func
def forward(self, x):
hidden = self.fc1(x)
if self.dropout_rate:
hidden = layers.dropout(
hidden, dropout_prob=self.dropout_rate, is_test=False)
out = self.fc2(hidden)
return out
class EncoderLayer(Layer):
def __init__(self,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
@dygraph_to_static_func
def forward(self, enc_input, attn_bias):
attn_output = self.self_attn(
self.preprocesser1(enc_input), None, None, attn_bias)
attn_output = self.postprocesser1(attn_output, enc_input)
ffn_output = self.ffn(self.preprocesser2(attn_output))
ffn_output = self.postprocesser2(ffn_output, attn_output)
return ffn_output
class Encoder(Layer):
def __init__(self,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(Encoder, self).__init__()
self.encoder_layers = list()
for i in range(n_layer):
self.encoder_layers.append(
self.add_sublayer(
"layer_%d" % i,
EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd)))
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
@dygraph_to_static_func
def forward(self, enc_input, attn_bias):
for encoder_layer in self.encoder_layers:
enc_output = encoder_layer(enc_input, attn_bias)
enc_input = enc_output
return self.processer(enc_output)
class Embedder(Layer):
def __init__(self, vocab_size, emb_dim, bos_idx=0):
super(Embedder, self).__init__()
self.word_embedder = Embedding(
size=[vocab_size, emb_dim],
padding_idx=bos_idx,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Normal(0., emb_dim**-0.5)))
@dygraph_to_static_func
def forward(self, word):
word_emb = self.word_embedder(word)
return word_emb
class WrapEncoder(Layer):
def __init__(self, src_vocab_size, max_length, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, word_embedder):
super(WrapEncoder, self).__init__()
self.emb_dropout = prepostprocess_dropout
self.emb_dim = d_model
self.word_embedder = word_embedder
self.pos_encoder = Embedding(
size=[max_length, self.emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(
position_encoding_init(max_length, self.emb_dim)),
trainable=False))
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd)
@dygraph_to_static_func
def forward(self, src_word, src_pos, src_slf_attn_bias):
word_emb = self.word_embedder(src_word)
word_emb = layers.scale(x=word_emb, scale=self.emb_dim**0.5)
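        # Note: with the default d_model=512 this scales the word embeddings by
        # 512**0.5 ≈ 22.63 so they are on a comparable scale to the
        # (non-trainable) positional encodings added below.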
pos_enc = self.pos_encoder(src_pos)
pos_enc.stop_gradient = True
emb = word_emb + pos_enc
enc_input = layers.dropout(
emb, dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb
enc_output = self.encoder(enc_input, src_slf_attn_bias)
return enc_output
class DecoderLayer(Layer):
def __init__(self,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(DecoderLayer, self).__init__()
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.cross_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
attention_dropout)
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
prepostprocess_dropout)
@dygraph_to_static_func
def forward(self,
dec_input,
enc_output,
self_attn_bias,
cross_attn_bias,
cache=None):
self_attn_output = self.self_attn(
self.preprocesser1(dec_input), None, None, self_attn_bias, cache)
self_attn_output = self.postprocesser1(self_attn_output, dec_input)
cross_attn_output = self.cross_attn(
self.preprocesser2(self_attn_output), enc_output, enc_output,
cross_attn_bias)
cross_attn_output = self.postprocesser2(cross_attn_output,
self_attn_output)
ffn_output = self.ffn(self.preprocesser3(cross_attn_output))
ffn_output = self.postprocesser3(ffn_output, cross_attn_output)
return ffn_output
class Decoder(Layer):
def __init__(self, n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd):
super(Decoder, self).__init__()
self.decoder_layers = list()
for i in range(n_layer):
self.decoder_layers.append(
self.add_sublayer(
"layer_%d" % i,
DecoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd)))
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
prepostprocess_dropout)
@dygraph_to_static_func
def forward(self,
dec_input,
enc_output,
self_attn_bias,
cross_attn_bias,
caches=None):
for i, decoder_layer in enumerate(self.decoder_layers):
dec_output = decoder_layer(dec_input, enc_output, self_attn_bias,
cross_attn_bias, None
if caches is None else caches[i])
dec_input = dec_output
return self.processer(dec_output)
class WrapDecoder(Layer):
def __init__(self, trg_vocab_size, max_length, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, share_input_output_embed, word_embedder):
super(WrapDecoder, self).__init__()
self.emb_dropout = prepostprocess_dropout
self.emb_dim = d_model
self.word_embedder = word_embedder
self.pos_encoder = Embedding(
size=[max_length, self.emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(
position_encoding_init(max_length, self.emb_dim)),
trainable=False))
self.decoder = Decoder(n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd)
if share_input_output_embed:
            self.linear = lambda x: layers.matmul(
                x=x,
                y=self.word_embedder.word_embedder.weight,
                transpose_y=True)
else:
self.linear = Linear(
input_dim=d_model, output_dim=trg_vocab_size, bias_attr=False)
@dygraph_to_static_func
def forward(self,
trg_word,
trg_pos,
trg_slf_attn_bias,
trg_src_attn_bias,
enc_output,
caches=None):
word_emb = self.word_embedder(trg_word)
word_emb = layers.scale(x=word_emb, scale=self.emb_dim**0.5)
pos_enc = self.pos_encoder(trg_pos)
pos_enc.stop_gradient = True
emb = word_emb + pos_enc
dec_input = layers.dropout(
emb, dropout_prob=self.emb_dropout,
is_test=False) if self.emb_dropout else emb
dec_output = self.decoder(dec_input, enc_output, trg_slf_attn_bias,
trg_src_attn_bias, caches)
        dec_output = layers.reshape(
            dec_output, shape=[-1, dec_output.shape[-1]])
logits = self.linear(dec_output)
return logits
class CrossEntropyCriterion(object):
def __init__(self, label_smooth_eps):
self.label_smooth_eps = label_smooth_eps
@dygraph_to_static_func
def __call__(self, predict, label, weights):
if self.label_smooth_eps:
label_out = layers.label_smooth(
label=layers.one_hot(
input=label, depth=predict.shape[-1]),
epsilon=self.label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label_out,
soft_label=True if self.label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, token_num
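# Worked example of the smoothing step (a sketch; assumes layers.label_smooth's
# default uniform prior): with depth=4 classes and epsilon=0.1, a one-hot label
# [0, 1, 0, 0] becomes
#   (1 - 0.1) * [0, 1, 0, 0] + 0.1 / 4 = [0.025, 0.925, 0.025, 0.025]
# and softmax_with_cross_entropy is then evaluated with soft_label=True.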
class Transformer(Layer):
def __init__(self,
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
bos_id=0,
eos_id=1):
super(Transformer, self).__init__()
src_word_embedder = Embedder(
vocab_size=src_vocab_size, emb_dim=d_model, bos_idx=bos_id)
self.encoder = WrapEncoder(
src_vocab_size, max_length, n_layer, n_head, d_key, d_value,
d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd, src_word_embedder)
if weight_sharing:
            assert src_vocab_size == trg_vocab_size, (
                "Vocabularies in source and target should be the same "
                "for weight sharing.")
trg_word_embedder = src_word_embedder
else:
trg_word_embedder = Embedder(
vocab_size=trg_vocab_size, emb_dim=d_model, bos_idx=bos_id)
self.decoder = WrapDecoder(
trg_vocab_size, max_length, n_layer, n_head, d_key, d_value,
d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd, weight_sharing,
trg_word_embedder)
self.trg_vocab_size = trg_vocab_size
self.n_layer = n_layer
self.n_head = n_head
self.d_key = d_key
self.d_value = d_value
@dygraph_to_static_func
def forward(self, src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias):
enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)
predict = self.decoder(trg_word, trg_pos, trg_slf_attn_bias,
trg_src_attn_bias, enc_output)
return predict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import warnings
import six
from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.dataset.wmt16 as wmt16
def get_input_descs(args):
    batch_size = args.batch_size  # TODO: this was previously None
seq_len = None
n_head = getattr(args, "n_head", 8)
d_model = getattr(args, "d_model", 512)
input_descs = {
"src_word": [(batch_size, seq_len), "int64", 2],
"src_pos": [(batch_size, seq_len), "int64"],
"src_slf_attn_bias":
[(batch_size, n_head, seq_len, seq_len), "float32"],
"trg_word": [(batch_size, seq_len), "int64", 2],
"trg_pos": [(batch_size, seq_len), "int64"],
"trg_slf_attn_bias":
[(batch_size, n_head, seq_len, seq_len), "float32"],
"trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"
], # TODO: 1 for predict, seq_len for train
"enc_output": [(batch_size, seq_len, d_model), "float32"],
"lbl_word": [(None, 1), "int64"],
"lbl_weight": [(None, 1), "float32"],
"init_score": [(batch_size, 1), "float32", 2],
"init_idx": [(batch_size, ), "int32"],
}
return input_descs
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
fast_decoder_data_input_fields = (
"trg_word",
"trg_src_attn_bias", )
class ModelHyperParams(object):
print_step = 2
init_from_params = "trained_models/step_10/"
save_model = "trained_models"
inference_model_dir = "infer_model"
output_file = "predict.txt"
batch_size = 5
epoch = 1
learning_rate = 2.0
beta1 = 0.9
beta2 = 0.997
eps = 1e-9
warmup_steps = 8000
label_smooth_eps = 0.1
beam_size = 5
max_out_len = 256
n_best = 1
src_vocab_size = 10000
trg_vocab_size = 10000
bos_idx = 0 # index for <bos> token
eos_idx = 1 # index for <eos> token
unk_idx = 2 # index for <unk> token
max_length = 256
d_model = 512
d_inner_hid = 2048
d_key = 64
d_value = 64
n_head = 8
n_layer = 6
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
preprocess_cmd = "n" # layer normalization
postprocess_cmd = "da" # dropout + residual connection
weight_sharing = True
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
return_list = []
max_len = max(len(inst) for inst in insts)
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array([[1.] * len(inst) + [0.] * (max_len - len(inst))
for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
list(range(0, len(inst))) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data,
1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
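# Illustrative example (a sketch; the token ids are made up): for
#   insts = [[3, 5, 7], [4, 6]]  with pad_idx=1 and n_head=2
# max_len is 3 and the ids are padded to [[3, 5, 7], [4, 6, 1]]. Then:
#   - source mode (is_target=False) builds a [2, 2, 3, 3] bias that is 0 over
#     real tokens and -1e9 over the padded position of the second sentence;
#   - target mode (is_target=True) builds an upper-triangular [2, 2, 3, 3]
#     bias of -1e9 that blocks attention to future positions.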
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
lbl_word = lbl_word.reshape(-1, 1)
lbl_weight = lbl_weight.reshape(-1, 1)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
]
return data_inputs
def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
]
return data_inputs
def get_feed_data_reader(args, mode='train'):
def __for_train__():
train_reader = paddle.batch(
wmt16.train(args.src_vocab_size, args.trg_vocab_size),
batch_size=args.batch_size)
for batch in train_reader():
tensors = prepare_train_input(batch, args.eos_idx, args.eos_idx,
args.n_head)
yield tensors
def __for_test__():
test_reader = paddle.batch(
wmt16.train(args.src_vocab_size, args.trg_vocab_size),
batch_size=args.batch_size)
for batch in test_reader():
tensors = prepare_infer_input(batch, args.eos_idx, args.eos_idx,
args.n_head)
yield tensors
return __for_train__ if mode == 'train' else __for_test__
class InputField(object):
def __init__(self, input_slots):
self.feed_list = []
for slot in input_slots:
self.feed_list.append(
fluid.layers.data(
name=slot['name'],
shape=slot['shape'],
dtype=slot['dtype'],
lod_level=slot.get('lod_level', 0),
append_batch_size=False))
def load(program, model_path, executor=None, var_list=None):
"""
To load python2 saved models in python3.
"""
try:
fluid.load(program, model_path, executor, var_list)
except UnicodeDecodeError:
        warnings.warn(
            "A UnicodeDecodeError was caught, which might be caused by loading "
            "a python2 saved model. The pickle.load encoding will be set and "
            "loading will be retried automatically.")
if six.PY3:
load_bak = pickle.load
pickle.load = partial(load_bak, encoding="latin1")
fluid.load(program, model_path, executor, var_list)
pickle.load = load_bak
def load_dygraph(model_path, keep_name_table=False):
"""
To load python2 saved models in python3.
"""
try:
para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table)
return para_dict, opti_dict
except UnicodeDecodeError:
        warnings.warn(
            "A UnicodeDecodeError was caught, which might be caused by loading "
            "a python2 saved model. The pickle.load encoding will be set and "
            "loading will be retried automatically.")
if six.PY3:
load_bak = pickle.load
pickle.load = partial(load_bak, encoding="latin1")
para_dict, opti_dict = fluid.load_dygraph(model_path,
keep_name_table)
pickle.load = load_bak
return para_dict, opti_dict
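# Usage sketch for the two loaders above (paths follow the checkpoints written
# by the training code in this change; exe and transformer are placeholders
# assumed to exist in the caller):
#   load(fluid.default_main_program(),
#        "trained_models/step_10/transformer", executor=exe)
#   state_dict, opt_dict = load_dygraph(
#       "trained_models_dygraph/step_10/transformer")
#   transformer.set_dict(state_dict)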