Commit 504c7f15 authored by G guosheng

Support python3 in Transformer

Parent 6966c992
@@ -115,11 +115,11 @@ seq_len = ModelHyperParams.max_length
 # compile time.
 input_descs = {
     # The actual data shape of src_word is:
-    # [batch_size * max_src_len_in_batch, 1]
-    "src_word": [(batch_size, seq_len, 1L), "int64", 2],
+    # [batch_size, max_src_len_in_batch, 1]
+    "src_word": [(batch_size, seq_len, 1), "int64", 2],
     # The actual data shape of src_pos is:
-    # [batch_size * max_src_len_in_batch, 1]
-    "src_pos": [(batch_size, seq_len, 1L), "int64"],
+    # [batch_size, max_src_len_in_batch, 1]
+    "src_pos": [(batch_size, seq_len, 1), "int64"],
     # This input is used to remove attention weights on paddings in the
     # encoder.
     # The actual data shape of src_slf_attn_bias is:
@@ -127,12 +127,12 @@ input_descs = {
     "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                            seq_len), "float32"],
     # The actual data shape of trg_word is:
-    # [batch_size * max_trg_len_in_batch, 1]
-    "trg_word": [(batch_size, seq_len, 1L), "int64",
+    # [batch_size, max_trg_len_in_batch, 1]
+    "trg_word": [(batch_size, seq_len, 1), "int64",
                  2],  # lod_level is only used in fast decoder.
     # The actual data shape of trg_pos is:
-    # [batch_size * max_trg_len_in_batch, 1]
-    "trg_pos": [(batch_size, seq_len, 1L), "int64"],
+    # [batch_size, max_trg_len_in_batch, 1]
+    "trg_pos": [(batch_size, seq_len, 1), "int64"],
     # This input is used to remove attention weights on paddings and
     # subsequent words in the decoder.
     # The actual data shape of trg_slf_attn_bias is:
@@ -151,15 +151,13 @@ input_descs = {
     "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
     # The actual data shape of label_word is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_word": [(batch_size * seq_len, 1L), "int64"],
+    "lbl_word": [(batch_size * seq_len, 1), "int64"],
     # This input is used to mask out the loss of paddding tokens.
     # The actual data shape of label_weight is:
     # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_weight": [(batch_size * seq_len, 1L), "float32"],
-    # These inputs are used to change the shape tensor in beam-search decoder.
-    "trg_slf_attn_pre_softmax_shape_delta": [(2L, ), "int32"],
-    "trg_slf_attn_post_softmax_shape_delta": [(4L, ), "int32"],
-    "init_score": [(batch_size, 1L), "float32"],
+    "lbl_weight": [(batch_size * seq_len, 1), "float32"],
+    # This input is used in beam-search decoder.
+    "init_score": [(batch_size, 1), "float32"],
 }
 # Names of word embedding table which might be reused for weight sharing.
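Note: the `1L` literals dropped throughout `input_descs` reflect Python 3's removal of the separate long type; the `L` suffix is a `SyntaxError` there, while a plain integer literal behaves identically on both interpreters. A minimal illustration (hypothetical sizes, not the repo's configuration):

```python
# Python 2 accepted (batch_size, seq_len, 1L); on Python 3 int and long are
# unified, so the suffix-free literal is the portable spelling.
batch_size, seq_len = 32, 256
shape = (batch_size, seq_len, 1)
print(shape)  # (32, 256, 1)
```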
@@ -190,6 +188,3 @@ fast_decoder_data_input_fields = (
     "trg_word",
     "init_score",
     "trg_src_attn_bias", )
-# fast_decoder_util_input_fields = (
-#     "trg_slf_attn_pre_softmax_shape_delta",
-#     "trg_slf_attn_post_softmax_shape_delta", )
@@ -59,8 +59,7 @@ def parse_args():
         "provided in util.py to do this.")
     parser.add_argument(
         "--token_delimiter",
-        type=partial(
-            str.decode, encoding="string-escape"),
+        type=lambda x: str(x.encode().decode("unicode-escape")),
         default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
         "For EN-DE BPE data we provided, use spaces as token delimiter.; "
@@ -99,11 +98,11 @@ def post_process_seq(seq,
         if idx == eos_idx:
             eos_pos = i
             break
-    seq = seq[:eos_pos + 1]
-    return filter(
-        lambda idx: (output_bos or idx != bos_idx) and \
-            (output_eos or idx != eos_idx),
-        seq)
+    seq = [
+        idx for idx in seq[:eos_pos + 1]
+        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
+    ]
+    return seq


 def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
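For context: on Python 3 `filter` returns a lazy iterator rather than a list, so callers that index or slice the result would break; the list comprehension keeps eager list semantics under both interpreters. A toy illustration (values are made up):

```python
# The comprehension mirrors the filter predicate above but always yields a list.
seq, bos_idx, eos_idx = [0, 5, 7, 1], 0, 1
output_bos = output_eos = False
kept = [idx for idx in seq
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)]
print(kept)  # [5, 7]
```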
@@ -164,8 +163,10 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
     fluid.io.load_vars(
         exe,
         InferTaskConfig.model_path,
-        vars=filter(lambda var: isinstance(var, fluid.framework.Parameter),
-                    fluid.default_main_program().list_vars()))
+        vars=[
+            var for var in fluid.default_main_program().list_vars()
+            if isinstance(var, fluid.framework.Parameter)
+        ])
     # This is used here to set dropout to the test mode.
     infer_program = fluid.default_main_program().inference_optimize()
@@ -203,7 +204,7 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
                         post_process_seq(np.array(seq_ids)[sub_start:sub_end]),
                         trg_idx2word))
                 scores[i].append(np.array(seq_scores)[sub_end - 1])
-                print hyps[i][-1]
+                print(hyps[i][-1])
                 if len(hyps[i]) >= InferTaskConfig.n_best:
                     break
...
@@ -12,7 +12,7 @@ def position_encoding_init(n_position, d_pos_vec):
     Generate the initial values for the sinusoid position encoding table.
     """
     position_enc = np.array([[
-        pos / np.power(10000, 2 * (j // 2) / d_pos_vec)
+        pos / np.power(10000, 2. * (j // 2) / d_pos_vec)
         for j in range(d_pos_vec)
     ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
     position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
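Note: the new `2.` literal forces true division for the exponent under Python 2 as well as Python 3, matching the sinusoid table PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)) from the Transformer paper. A quick numeric check (arbitrary small values):

```python
import numpy as np

d_pos_vec, pos, j = 8, 3, 5
exponent = 2. * (j // 2) / d_pos_vec     # 0.5 - a float on either interpreter
angle = pos / np.power(10000, exponent)  # 3 / 100.0 = 0.03
print(exponent, angle)
```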
@@ -90,8 +90,7 @@ def multi_head_attention(queries,
         # The value 0 in shape attr means copying the corresponding dimension
         # size of the input as the output dimension size.
         return layers.reshape(
-            x=trans_x,
-            shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
+            x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]])


 def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
     """
...
@@ -2,7 +2,6 @@ import glob
 import os
 import random
 import tarfile
-import cPickle


 class SortType(object):
...
@@ -2,8 +2,8 @@ import argparse
 import ast
 import multiprocessing
 import os
+import six
 import time
-from functools import partial

 import numpy as np
 import paddle.fluid as fluid
@@ -78,8 +78,7 @@ def parse_args():
         help="The <bos>, <eos> and <unk> tokens in the dictionary.")
     parser.add_argument(
         "--token_delimiter",
-        type=partial(
-            str.decode, encoding="string-escape"),
+        type=lambda x: str(x.encode().decode("unicode-escape")),
         default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
         "For EN-DE BPE data we provided, use spaces as token delimiter. "
@@ -138,8 +137,6 @@ def pad_batch_data(insts,
     """
     return_list = []
     max_len = max(len(inst) for inst in insts)
-    num_token = reduce(lambda x, y: x + y,
-                       [len(inst) for inst in insts]) if return_num_token else 0
     # Any token included in dict can be used to pad, since the paddings' loss
     # will be masked out by weights and make no effect on parameter gradients.
     inst_data = np.array(
@@ -151,7 +148,7 @@ def pad_batch_data(insts,
         return_list += [inst_weight.astype("float32").reshape([-1, 1])]
     else:  # position data
         inst_pos = np.array([
-            range(1, len(inst) + 1) + [0] * (max_len - len(inst))
+            list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
            for inst in insts
         ])
         return_list += [inst_pos.astype("int64").reshape([-1, 1])]
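Note: Python 3's `range` returns a view object that cannot be concatenated to a list with `+`, hence the explicit `list(...)` wrapper. For example:

```python
# Build a 1-based position list padded with zeros, as the code above does.
max_len, inst = 5, [11, 12, 13]
positions = list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
print(positions)  # [1, 2, 3, 0, 0]
```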
@@ -176,6 +173,9 @@ def pad_batch_data(insts,
     if return_max_len:
         return_list += [max_len]
     if return_num_token:
+        num_token = 0
+        for inst in insts:
+            num_token += len(inst)
         return_list += [num_token]
     return return_list if len(return_list) > 1 else return_list[0]
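This loop replaces the `reduce` call deleted above; `reduce` is no longer a builtin in Python 3 (it moved to `functools`). A built-in `sum` over a generator would be an equivalent one-liner:

```python
# Equivalent to the inserted loop; insts is assumed to be a list of
# token-id lists, as in pad_batch_data.
insts = [[2, 7, 9], [4, 5], [8]]
num_token = sum(len(inst) for inst in insts)
print(num_token)  # 6
```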
@@ -258,7 +258,7 @@ def split_data(data, num_part):
 def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
-                 util_input_names, sum_cost, token_num):
+                 sum_cost, token_num):
     # Context to do validation.
     test_program = train_progm.clone()
     with fluid.program_guard(test_program):
@@ -299,9 +299,9 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
                 split_data(
                     data, num_part=dev_count)):
             data_input_dict, _ = prepare_batch_input(
-                data_buffer, data_input_names, util_input_names,
-                ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
-                ModelHyperParams.n_head, ModelHyperParams.d_model)
+                data_buffer, data_input_names, ModelHyperParams.eos_idx,
+                ModelHyperParams.eos_idx, ModelHyperParams.n_head,
+                ModelHyperParams.d_model)
             feed_list.append(data_input_dict)

         outs = exe.run(feed=feed_list,
@@ -323,7 +323,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
         fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
         lr_scheduler.current_steps = TrainTaskConfig.start_step
     else:
-        print "init fluid.framework.default_startup_program"
+        print("init fluid.framework.default_startup_program")
         exe.run(fluid.framework.default_startup_program())

     train_data = reader.DataReader(
@@ -363,8 +363,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
     if args.val_file_pattern is not None:
         test = test_context(train_progm, avg_cost, train_exe, dev_count,
-                            data_input_names, util_input_names, sum_cost,
-                            token_num)
+                            data_input_names, sum_cost, token_num)

     # the best cross-entropy value with label smoothing
     loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log(
@@ -372,8 +371,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
     )) + TrainTaskConfig.label_smooth_eps *
                         np.log(TrainTaskConfig.label_smooth_eps / (
                             ModelHyperParams.trg_vocab_size - 1) + 1e-20))
+    step_idx = 0
+    inst_num = 0
     init = False
-    for pass_id in xrange(TrainTaskConfig.pass_num):
+    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
             feed_list = []
@@ -388,11 +390,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                     ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                     ModelHyperParams.d_model)
                 total_num_token += num_token
-                feed_kv_pairs = data_input_dict.items()
+                inst_num += len(data_buffer)
+                feed_kv_pairs = list(data_input_dict.items())
                 if args.local:
-                    feed_kv_pairs += {
-                        lr_scheduler.learning_rate.name: lr_rate
-                    }.items()
+                    feed_kv_pairs += list({
+                        lr_scheduler.learning_rate.name: lr_rate
+                    }.items())
                 feed_list.append(dict(feed_kv_pairs))

             if not init:
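Note: Python 3's `dict.items()` returns a view that does not support `+` concatenation, which is why both operands are wrapped in `list(...)`. In isolation:

```python
# items() views must be materialized before they can be concatenated.
a = {"x": 1}.items()
b = {"y": 2}.items()
pairs = list(a) + list(b)  # a + b raises TypeError on Python 3
print(dict(pairs))         # {'x': 1, 'y': 2}
```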
@@ -410,14 +413,17 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
             )  # sum the cost from multi-devices
             total_token_num = token_num_val.sum()
             total_avg_cost = total_sum_cost / total_token_num
-            print("epoch: %d, batch: %d, avg loss: %f, normalized loss: %f,"
-                  " ppl: %f" % (pass_id, batch_id, total_avg_cost,
-                                total_avg_cost - loss_normalizer,
-                                np.exp([min(total_avg_cost, 100)])))
+            print(
+                "step_idx: %d, total samples: %d, epoch: %d, batch: %d, avg loss: %f, "
+                "normalized loss: %f, ppl: %f" %
+                (step_idx, inst_num, pass_id, batch_id, total_avg_cost,
+                 total_avg_cost - loss_normalizer,
+                 np.exp([min(total_avg_cost, 100)])))
             if batch_id > 0 and batch_id % 1000 == 0:
                 fluid.io.save_persistables(
                     exe,
                     os.path.join(TrainTaskConfig.ckpt_dir, "latest.checkpoint"))
+            step_idx += 1
             init = True
         time_consumed = time.time() - pass_start_time
@@ -450,7 +456,7 @@ def train(args):
     is_local = os.getenv("PADDLE_IS_LOCAL", "1")
     if is_local == '0':
         args.local = False
-    print args
+    print(args)

     if args.device == 'CPU':
         TrainTaskConfig.use_gpu = False
@@ -531,7 +537,7 @@ def train(args):
         pserver_startup = t.get_startup_program(current_endpoint,
                                                 pserver_prog)
-        print "psserver begin run"
+        print("psserver begin run")
         with open('pserver_startup.desc', 'w') as f:
             f.write(str(pserver_startup))
         with open('pserver_prog.desc', 'w') as f:
...
@@ -17,6 +17,35 @@ _ALPHANUMERIC_CHAR_SET = set(
     unicodedata.category(six.unichr(i)).startswith("N")))


+# Unicode utility functions that work with Python 2 and 3
+def native_to_unicode(s):
+    return s if is_unicode(s) else to_unicode(s)
+
+
+def unicode_to_native(s):
+    if six.PY2:
+        return s.encode("utf-8") if is_unicode(s) else s
+    else:
+        return s
+
+
+def is_unicode(s):
+    if six.PY2:
+        if isinstance(s, unicode):
+            return True
+    else:
+        if isinstance(s, str):
+            return True
+    return False
+
+
+def to_unicode(s, ignore_errors=False):
+    if is_unicode(s):
+        return s
+    error_mode = "ignore" if ignore_errors else "strict"
+    return s.decode("utf-8", errors=error_mode)
+
+
 def unescape_token(escaped_token):
     """
     Inverse of encoding escaping.
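A sketch of how the new helpers behave, assuming they are importable from util.py (example strings are my own): on Python 3 a native `str` is already unicode, so the round trip is a no-op, while on Python 2 UTF-8 byte strings are decoded on the way in and re-encoded on the way out.

```python
import six

s = native_to_unicode("caf\xc3\xa9" if six.PY2 else "café")
print(is_unicode(s))         # True under both interpreters
print(unicode_to_native(s))  # native str: unicode on py3, utf-8 bytes on py2
```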
@@ -44,9 +73,7 @@ def subtoken_ids_to_str(subtoken_ids, vocabs):
     subtokens = [vocabs.get(subtoken_id, u"") for subtoken_id in subtoken_ids]

     # Convert a list of subtokens to a list of tokens.
-    concatenated = "".join([
-        t if isinstance(t, unicode) else t.decode("utf-8") for t in subtokens
-    ])
+    concatenated = "".join([native_to_unicode(t) for t in subtokens])
     split = concatenated.split("_")
     tokens = []
     for t in split:
@@ -65,4 +92,4 @@ def subtoken_ids_to_str(subtoken_ids, vocabs):
             ret.append(token)
     seq = "".join(ret)
-    return seq.encode("utf-8")
+    return unicode_to_native(seq)