diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 8ab9efce1a275ea9539b05c0b959dee42d83c759..955e512a3de77c158c934e1f086036e928372e38 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -1,8 +1,10 @@
 class TrainTaskConfig(object):
+    # only support GPU currently
     use_gpu = True
     # the epoch number to train.
     pass_num = 30
     # the number of sequences contained in a mini-batch.
+    # deprecated, set batch_size in args.
     batch_size = 32
     # the hyper parameters for Adam optimizer.
     # This static learning_rate will be multiplied to the LearningRateScheduler
@@ -13,8 +15,6 @@ class TrainTaskConfig(object):
     eps = 1e-9
     # the parameters for learning rate scheduling.
     warmup_steps = 4000
-    # the flag indicating to use average loss or sum loss when training.
-    use_avg_cost = True
     # the weight used to mix up the ground-truth distribution and the fixed
     # uniform distribution in label smoothing when training.
     # Set this as zero if label smoothing is not wanted.
@@ -38,22 +38,20 @@ class InferTaskConfig(object):
     batch_size = 10
     # the parameters for beam search.
     beam_size = 5
-    max_length = 30
+    max_length = 256
     # the number of decoded sentences to output.
     n_best = 1
     # the flags indicating whether to output the special tokens.
     output_bos = False
     output_eos = False
-    output_unk = False
+    output_unk = True
     # the directory for loading the trained model.
     model_path = "trained_models/pass_1.infer.model"
 
 
 class ModelHyperParams(object):
-    # This model directly uses paddle.dataset.wmt16 in which <bos>, <eos> and
-    # <unk> token has alreay been added. As for the <pad> token, any token
-    # included in dict can be used to pad, since the paddings' loss will be
-    # masked out and make no effect on parameter gradients.
+    # These following five vocabularies related configurations will be set
+    # automatically according to the passed vocabulary path and special tokens.
     # size of source word dictionary.
     src_vocab_size = 10000
     # size of target word dictionay
@@ -68,13 +66,13 @@ class ModelHyperParams(object):
     # The size of position encoding table should at least plus 1, since the
     # sinusoid position encoding starts from 1 and 0 can be used as the padding
     # token for position encoding.
-    max_length = 50
+    max_length = 256
     # the dimension for word embeddings, which is also the last dimension of
     # the input and output of multi-head attention, position-wise feed-forward
     # networks, encoder and decoder.
     d_model = 512
     # size of the hidden layer in position-wise feed-forward networks.
-    d_inner_hid = 1024
+    d_inner_hid = 2048
     # the dimension that keys are projected to for dot-product attention.
     d_key = 64
     # the dimension that values are projected to for dot-product attention.
@@ -85,6 +83,9 @@ class ModelHyperParams(object):
     n_layer = 6
     # dropout rate used by all dropout layers.
     dropout = 0.1
+    # the flag indicating whether to share embedding and softmax weights.
+    # vocabularies in source and target should be same for weight sharing.
+    weight_sharing = True
 
 
 def merge_cfg_from_list(cfg_list, g_cfgs):
@@ -172,6 +173,10 @@ input_descs = {
     "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
 }
 
+# Names of word embedding table which might be reused for weight sharing.
+word_emb_param_names = (
+    "src_word_emb_table",
+    "trg_word_emb_table", )
 # Names of position encoding table which will be initialized externally.
 pos_enc_param_names = (
     "src_pos_enc_table",
diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index e8f7f47dd5c0dc4937b73bd1693b2fd14fb8d55c..33cc5553768307653d2c9b176fe4ceba884b3371 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -359,6 +359,7 @@ def infer(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
         clip_last_batch=False)
 
     trg_idx2word = test_data.load_dict(
diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 9c5d8adc312d48eb7c232789e590755e1b349d3a..7756d633fb05d27904f84dc9c41e25643c17eb04 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -46,26 +46,14 @@ def multi_head_attention(queries,
         """
         q = layers.fc(input=queries,
                       size=d_key * n_head,
-                      param_attr=fluid.initializer.Xavier(
-                          uniform=False,
-                          fan_in=d_model * d_key,
-                          fan_out=n_head * d_key),
                       bias_attr=False,
                       num_flatten_dims=2)
         k = layers.fc(input=keys,
                       size=d_key * n_head,
-                      param_attr=fluid.initializer.Xavier(
-                          uniform=False,
-                          fan_in=d_model * d_key,
-                          fan_out=n_head * d_key),
                       bias_attr=False,
                       num_flatten_dims=2)
         v = layers.fc(input=values,
                       size=d_value * n_head,
-                      param_attr=fluid.initializer.Xavier(
-                          uniform=False,
-                          fan_in=d_model * d_value,
-                          fan_out=n_head * d_value),
                       bias_attr=False,
                       num_flatten_dims=2)
         return q, k, v
@@ -84,7 +72,7 @@ def multi_head_attention(queries,
         # The value 0 in shape attr means copying the corresponding dimension
         # size of the input as the output dimension size.
         reshaped = layers.reshape(
-            x=x, shape=[0, -1, n_head, hidden_size // n_head])
+            x=x, shape=[0, 0, n_head, hidden_size // n_head])
 
         # permuate the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
@@ -104,7 +92,7 @@ def multi_head_attention(queries,
         # size of the input as the output dimension size.
         return layers.reshape(
             x=trans_x,
-            shape=map(int, [0, -1, trans_x.shape[2] * trans_x.shape[3]]))
+            shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
 
     def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
         """
@@ -140,7 +128,6 @@ def multi_head_attention(queries,
     # Project back to the model size.
     proj_out = layers.fc(input=out,
                          size=d_model,
-                         param_attr=fluid.initializer.Xavier(uniform=False),
                          bias_attr=False,
                          num_flatten_dims=2)
     return proj_out
@@ -155,14 +142,8 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
     hidden = layers.fc(input=x,
                        size=d_inner_hid,
                        num_flatten_dims=2,
-                       param_attr=fluid.initializer.Uniform(
-                           low=-(d_hid**-0.5), high=(d_hid**-0.5)),
                        act="relu")
-    out = layers.fc(input=hidden,
-                    size=d_hid,
-                    num_flatten_dims=2,
-                    param_attr=fluid.initializer.Uniform(
-                        low=-(d_inner_hid**-0.5), high=(d_inner_hid**-0.5)))
+    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
     return out
 
 
@@ -200,6 +181,7 @@ def prepare_encoder(src_word,
                     src_max_len,
                     dropout_rate=0.,
                     src_data_shape=None,
+                    word_emb_param_name=None,
                     pos_enc_param_name=None):
     """Add word embeddings and position encodings.
     The output tensor has a shape of:
@@ -209,7 +191,10 @@ def prepare_encoder(src_word,
     src_word_emb = layers.embedding(
         src_word,
         size=[src_vocab_size, src_emb_dim],
-        param_attr=fluid.initializer.Normal(0., 1.))
+        param_attr=fluid.ParamAttr(
+            name=word_emb_param_name,
+            initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
+    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
     src_pos_enc = layers.embedding(
         src_pos,
         size=[src_max_len, src_emb_dim],
@@ -415,7 +400,12 @@ def transformer(
         d_model,
         d_inner_hid,
         dropout_rate,
+        weight_sharing,
         label_smooth_eps, ):
+    if weight_sharing:
+        assert src_vocab_size == trg_vocab_size, (
+            "Vocabularies in source and target should be same for weight sharing."
+        )
     enc_inputs = make_all_inputs(encoder_data_input_fields +
                                  encoder_util_input_fields)
 
@@ -429,6 +419,7 @@ def transformer(
         d_model,
         d_inner_hid,
         dropout_rate,
+        weight_sharing,
         enc_inputs, )
 
     dec_inputs = make_all_inputs(decoder_data_input_fields[:-1] +
@@ -444,6 +435,7 @@ def transformer(
         d_model,
         d_inner_hid,
         dropout_rate,
+        weight_sharing,
         dec_inputs,
         enc_output, )
 
@@ -459,7 +451,6 @@ def transformer(
         logits=predict,
         label=label,
         soft_label=True if label_smooth_eps else False)
-    # cost = layers.softmax_with_cross_entropy(logits=predict, label=gold)
     weighted_cost = cost * weights
     sum_cost = layers.reduce_sum(weighted_cost)
     token_num = layers.reduce_sum(weights)
@@ -476,6 +467,7 @@ def wrap_encoder(src_vocab_size,
                  d_model,
                  d_inner_hid,
                  dropout_rate,
+                 weight_sharing,
                  enc_inputs=None):
     """
     The wrapper assembles together all needed layers for the encoder.
@@ -497,7 +489,8 @@ def wrap_encoder(src_vocab_size,
         d_model,
         max_length,
         dropout_rate,
-        src_data_shape, )
+        src_data_shape,
+        word_emb_param_name=word_emb_param_names[0])
     enc_output = encoder(
         enc_input,
         src_slf_attn_bias,
@@ -522,6 +515,7 @@ def wrap_decoder(trg_vocab_size,
                  d_model,
                  d_inner_hid,
                  dropout_rate,
+                 weight_sharing,
                  dec_inputs=None,
                  enc_output=None):
     """
@@ -547,7 +541,9 @@ def wrap_decoder(trg_vocab_size,
         d_model,
         max_length,
         dropout_rate,
-        trg_data_shape, )
+        trg_data_shape,
+        word_emb_param_name=word_emb_param_names[0]
+        if weight_sharing else word_emb_param_names[1])
     dec_output = decoder(
         dec_input,
         enc_output,
@@ -565,11 +561,20 @@ def wrap_decoder(trg_vocab_size,
         src_attn_pre_softmax_shape,
         src_attn_post_softmax_shape, )
     # Return logits for training and probs for inference.
-    predict = layers.reshape(
-        x=layers.fc(input=dec_output,
-                    size=trg_vocab_size,
-                    bias_attr=False,
-                    num_flatten_dims=2),
-        shape=[-1, trg_vocab_size],
-        act="softmax" if dec_inputs is None else None)
+    if weight_sharing:
+        predict = layers.reshape(
+            x=layers.matmul(
+                x=dec_output,
+                y=fluid.get_var(word_emb_param_names[0]),
+                transpose_y=True),
+            shape=[-1, trg_vocab_size],
+            act="softmax" if dec_inputs is None else None)
+    else:
+        predict = layers.reshape(
+            x=layers.fc(input=dec_output,
+                        size=trg_vocab_size,
+                        bias_attr=False,
+                        num_flatten_dims=2),
+            shape=[-1, trg_vocab_size],
+            act="softmax" if dec_inputs is None else None)
     return predict
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index bf9edb52bedf065242d4f49391302ba988d7dcac..58eadc1b46f7a09aefa9a595fb70781b76363e72 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -203,50 +203,50 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
         [num_token], dtype="float32")
 
 
-def train(args):
-    dev_count = fluid.core.get_cuda_device_count()
+def read_multiple(reader, count, clip_last=True):
+    """
+    Stack data from reader for multi-devices.
+    """
 
-    def read_multiple(reader,
-                      count=dev_count if args.use_token_batch else 1,
-                      clip_last=True):
-        """
-        Stack data from reader for multi-devices.
-        """
-
-        def __impl__():
-            res = []
-            for item in reader():
-                res.append(item)
-                if len(res) == count:
-                    yield res
-                    res = []
+    def __impl__():
+        res = []
+        for item in reader():
+            res.append(item)
             if len(res) == count:
                 yield res
-            elif not clip_last:
-                data = []
-                for item in res:
-                    data += item
-                if len(data) > count:
-                    inst_num_per_part = len(data) // count
-                    yield [
-                        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
-                        for i in range(count)
-                    ]
-
-        return __impl__
-
-    def split_data(data, num_part=dev_count):
-        """
-        Split data for each device.
-        """
-        if len(data) == num_part:
-            return data
-        data = data[0]
-        inst_num_per_part = len(data) // num_part
-        return [
-            data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
-            for i in range(num_part)
-        ]
+                res = []
+        if len(res) == count:
+            yield res
+        elif not clip_last:
+            data = []
+            for item in res:
+                data += item
+            if len(data) > count:
+                inst_num_per_part = len(data) // count
+                yield [
+                    data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
+                    for i in range(count)
+                ]
+
+    return __impl__
+
+
+def split_data(data, num_part):
+    """
+    Split data for each device.
+    """
+    if len(data) == num_part:
+        return data
+    data = data[0]
+    inst_num_per_part = len(data) // num_part
+    return [
+        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
+        for i in range(num_part)
+    ]
+
+
+def train(args):
+    dev_count = fluid.core.get_cuda_device_count()
 
     sum_cost, avg_cost, predict, token_num = transformer(
         ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
@@ -254,7 +254,7 @@ def train(args):
         ModelHyperParams.n_head, ModelHyperParams.d_key,
         ModelHyperParams.d_value, ModelHyperParams.d_model,
         ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
-        TrainTaskConfig.label_smooth_eps)
+        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
 
     lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                          TrainTaskConfig.warmup_steps,
@@ -288,9 +288,12 @@ def train(args):
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2],
+        max_length=ModelHyperParams.max_length,
         clip_last_batch=False)
+    train_data = read_multiple(
+        reader=train_data.batch_generator,
+        count=dev_count if args.use_token_batch else 1)
 
-    train_data = read_multiple(reader=train_data.batch_generator)
     build_strategy = fluid.BuildStrategy()
     # Since the token number differs among devices, customize gradient scale to
     # use token average cost among multi-devices. and the gradient scale is
@@ -303,9 +306,11 @@ def train(args):
 
     def test_context():
         # Context to do validation.
-        test_program = fluid.default_main_program().clone()
-        with fluid.program_guard(test_program):
-            test_program = fluid.io.get_inference_program([avg_cost])
+        test_program = fluid.default_main_program().clone(for_test=True)
+        test_exe = fluid.ParallelExecutor(
+            use_cuda=TrainTaskConfig.use_gpu,
+            main_program=test_program,
+            share_vars_from=train_exe)
 
         val_data = reader.DataReader(
             src_vocab_fpath=args.src_vocab_fpath,
@@ -319,22 +324,22 @@ def train(args):
             start_mark=args.special_token[0],
             end_mark=args.special_token[1],
             unk_mark=args.special_token[2],
+            max_length=ModelHyperParams.max_length,
             clip_last_batch=False,
             shuffle=False,
             shuffle_batch=False)
 
-        test_exe = fluid.ParallelExecutor(
-            use_cuda=TrainTaskConfig.use_gpu,
-            main_program=test_program,
-            share_vars_from=train_exe)
-
         def test(exe=test_exe):
             test_total_cost = 0
             test_total_token = 0
-            test_data = read_multiple(reader=val_data.batch_generator)
+            test_data = read_multiple(
+                reader=val_data.batch_generator,
+                count=dev_count if args.use_token_batch else 1)
             for batch_id, data in enumerate(test_data()):
                 feed_list = []
-                for place_id, data_buffer in enumerate(split_data(data)):
+                for place_id, data_buffer in enumerate(
+                        split_data(
+                            data, num_part=dev_count)):
                     data_input_dict, util_input_dict, _ = prepare_batch_input(
                         data_buffer, data_input_names, util_input_names,
                         ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
@@ -367,7 +372,9 @@ def train(args):
             feed_list = []
             total_num_token = 0
             lr_rate = lr_scheduler.update_learning_rate()
-            for place_id, data_buffer in enumerate(split_data(data)):
+            for place_id, data_buffer in enumerate(
+                    split_data(
+                        data, num_part=dev_count)):
                 data_input_dict, util_input_dict, num_token = prepare_batch_input(
                     data_buffer, data_input_names, util_input_names,
                     ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
@@ -377,17 +384,14 @@ def train(args):
                     dict(data_input_dict.items() + util_input_dict.items() +
                          {lr_scheduler.learning_rate.name: lr_rate}.items()))
 
-                if not init:
+                if not init:  # init the position encoding table
                     for pos_enc_param_name in pos_enc_param_names:
                         pos_enc = position_encoding_init(
                             ModelHyperParams.max_length + 1,
                             ModelHyperParams.d_model)
                         feed_list[place_id][pos_enc_param_name] = pos_enc
 
             for feed_dict in feed_list:
-                feed_dict[
-                    sum_cost.name +
-                    "@GRAD"] = 1. / total_num_token if TrainTaskConfig.use_avg_cost else np.asarray(
-                        [1.], dtype="float32")
+                feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token
             outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
                                  feed=feed_list)
             sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
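
Note (not part of the patch): the weight-sharing change above ties three matrices together -- the source embedding, the target embedding, and the pre-softmax projection -- by reusing the single "src_word_emb_table" parameter, scaling embeddings by sqrt(d_model) on lookup and multiplying decoder outputs by the transposed table to produce logits. The following framework-free NumPy sketch only illustrates that arithmetic; names and shapes are illustrative and are not taken from the patch.

    import numpy as np

    vocab_size, d_model = 10000, 512

    # One table serves source lookup, target lookup and the output projection,
    # initialized as in the patch: Normal(0., d_model**-0.5).
    emb_table = np.random.normal(0., d_model**-0.5, (vocab_size, d_model))

    def embed(token_ids):
        # Lookup then scale by sqrt(d_model), mirroring the added
        # layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) in prepare_encoder.
        return emb_table[token_ids] * d_model**0.5

    def output_logits(dec_output):
        # Pre-softmax projection reuses the same table transposed, mirroring the
        # layers.matmul(x=dec_output, y=..., transpose_y=True) branch in wrap_decoder.
        return dec_output.dot(emb_table.T)

    tokens = np.array([1, 5, 42])
    dec_hidden = np.random.randn(len(tokens), d_model)
    print(embed(tokens).shape, output_logits(dec_hidden).shape)  # (3, 512) (3, 10000)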