diff --git a/fluid/DeepASR/data_utils/async_data_reader.py b/fluid/DeepASR/data_utils/async_data_reader.py index d6949257b6d4f142d9e5040ffab39f0236814de3..731c55de71e8d4b7db156f1ae72172c36eb1be7a 100644 --- a/fluid/DeepASR/data_utils/async_data_reader.py +++ b/fluid/DeepASR/data_utils/async_data_reader.py @@ -30,11 +30,12 @@ class SampleInfo(object): label_bin_path (str): File containing the label data. label_size (int): Byte count of the sample's label data. label_frame_num (int): Label number of the sample. + sample_name (str): Key of the sample """ def __init__(self, feature_bin_path, feature_start, feature_size, feature_frame_num, feature_dim, label_bin_path, label_start, - label_size, label_frame_num): + label_size, label_frame_num, sample_name): self.feature_bin_path = feature_bin_path self.feature_start = feature_start self.feature_size = feature_size @@ -45,6 +46,7 @@ class SampleInfo(object): self.label_start = label_start self.label_size = label_size self.label_frame_num = label_frame_num + self.sample_name = sample_name class SampleInfoBucket(object): @@ -102,24 +104,33 @@ class SampleInfoBucket(object): feature_bin_path = self._feature_bin_paths[block_idx] feature_desc_path = self._feature_desc_paths[block_idx] - label_desc_lines = open(label_desc_path).readlines() feature_desc_lines = open(feature_desc_path).readlines() - sample_num = int(label_desc_lines[0].split()[1]) - assert sample_num == int(feature_desc_lines[0].split()[1]) + label_desc_lines = [] + if label_desc_path != "": + label_desc_lines = open(label_desc_path).readlines() + sample_num = int(feature_desc_lines[0].split()[1]) + + if label_desc_path != "": + assert sample_num == int(label_desc_lines[0].split()[1]) for i in xrange(sample_num): feature_desc_split = feature_desc_lines[i + 1].split() + sample_name = feature_desc_split[0] feature_start = int(feature_desc_split[2]) feature_size = int(feature_desc_split[3]) feature_frame_num = int(feature_desc_split[4]) feature_dim = int(feature_desc_split[5]) - label_desc_split = label_desc_lines[i + 1].split() - label_start = int(label_desc_split[2]) - label_size = int(label_desc_split[3]) - label_frame_num = int(label_desc_split[4]) - assert feature_frame_num == label_frame_num + label_start = -1 + label_size = -1 + label_frame_num = feature_frame_num + if label_desc_path != "": + label_desc_split = label_desc_lines[i + 1].split() + label_start = int(label_desc_split[2]) + label_size = int(label_desc_split[3]) + label_frame_num = int(label_desc_split[4]) + assert feature_frame_num == label_frame_num if self._split_sentence_threshold == -1 or \ self._split_perturb == -1 or \ @@ -129,7 +140,7 @@ class SampleInfoBucket(object): SampleInfo(feature_bin_path, feature_start, feature_size, feature_frame_num, feature_dim, label_bin_path, label_start, label_size, - label_frame_num)) + label_frame_num, sample_name)) #split sentence else: cur_frame_pos = 0 @@ -150,13 +161,12 @@ class SampleInfoBucket(object): * feature_dim * 4, cur_frame_len * feature_dim * 4, cur_frame_len, feature_dim, label_bin_path, label_start + cur_frame_pos * 4, cur_frame_len * - 4, cur_frame_len)) + 4, cur_frame_len, sample_name)) remain_frame_num -= cur_frame_len cur_frame_pos += cur_frame_len if remain_frame_num <= 0: break - return sample_info_list @@ -192,7 +202,7 @@ class AsyncDataReader(object): def __init__(self, feature_file_list, - label_file_list, + label_file_list="", drop_frame_len=512, proc_num=10, sample_buffer_size=1024, @@ -221,16 +231,24 @@ class AsyncDataReader(object): def generate_bucket_list(self, is_shuffle): if self._block_info_list is None: block_feature_info_lines = open(self._feature_file_list).readlines() - block_label_info_lines = open(self._label_file_list).readlines() - assert len(block_feature_info_lines) == len(block_label_info_lines) self._block_info_list = [] - for i in xrange(0, len(block_feature_info_lines), 2): - block_info = (block_feature_info_lines[i], - block_feature_info_lines[i + 1], - block_label_info_lines[i], - block_label_info_lines[i + 1]) - self._block_info_list.append( - map(lambda line: line.strip(), block_info)) + if self._label_file_list != "": + block_label_info_lines = open(self._label_file_list).readlines() + assert len(block_feature_info_lines) == len( + block_label_info_lines) + for i in xrange(0, len(block_feature_info_lines), 2): + block_info = (block_feature_info_lines[i], + block_feature_info_lines[i + 1], + block_label_info_lines[i], + block_label_info_lines[i + 1]) + self._block_info_list.append( + map(lambda line: line.strip(), block_info)) + else: + for i in xrange(0, len(block_feature_info_lines), 2): + block_info = (block_feature_info_lines[i], + block_feature_info_lines[i + 1], "", "") + self._block_info_list.append( + map(lambda line: line.strip(), block_info)) if is_shuffle: self._rng.shuffle(self._block_info_list) @@ -310,19 +328,25 @@ class AsyncDataReader(object): sample_info.feature_dim, len(feature_bytes)) - label_bytes = read_bytes(sample_info.label_bin_path, - sample_info.label_start, - sample_info.label_size) - - assert sample_info.label_frame_num * 4 == len(label_bytes), ( - sample_info.label_bin_path, sample_info.label_array, - len(label_bytes)) - - label_array = struct.unpack('I' * sample_info.label_frame_num, - label_bytes) - label_data = np.array( - label_array, dtype='int64').reshape( - (sample_info.label_frame_num, 1)) + label_data = None + if sample_info.label_bin_path != "": + label_bytes = read_bytes(sample_info.label_bin_path, + sample_info.label_start, + sample_info.label_size) + + assert sample_info.label_frame_num * 4 == len( + label_bytes), (sample_info.label_bin_path, + sample_info.label_array, + len(label_bytes)) + + label_array = struct.unpack( + 'I' * sample_info.label_frame_num, label_bytes) + label_data = np.array( + label_array, dtype='int64').reshape( + (sample_info.label_frame_num, 1)) + else: + label_data = np.zeros( + (sample_info.label_frame_num, 1), dtype='int64') feature_frame_num = sample_info.feature_frame_num feature_dim = sample_info.feature_dim @@ -332,12 +356,11 @@ class AsyncDataReader(object): feature_data = np.array( feature_array, dtype='float32').reshape(( sample_info.feature_frame_num, sample_info.feature_dim)) - - sample_data = (feature_data, label_data) + sample_data = (feature_data, label_data, + sample_info.sample_name) for transformer in self._transformers: # @TODO(pkuyym) to make transfomer only accept feature_data sample_data = transformer.perform_trans(sample_data) - while order_id != out_order[0]: time.sleep(0.001) @@ -387,12 +410,14 @@ class AsyncDataReader(object): batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32") batch_label = np.zeros((lod[-1], 1), dtype="int64") start = 0 + name_lst = [] for sample in batch_samples: frame_num = sample[0].shape[0] batch_feature[start:start + frame_num, :] = sample[0] batch_label[start:start + frame_num, :] = sample[1] start += frame_num - return (batch_feature, batch_label) + name_lst.append(sample[2]) + return (batch_feature, batch_label, name_lst) @suppress_complaints(verbose=self._verbose, notify=self._force_exit) def batch_assembling_task(sample_generator, batch_queue): @@ -402,16 +427,16 @@ class AsyncDataReader(object): batch_samples.append(sample) lod.append(lod[-1] + sample[0].shape[0]) if len(batch_samples) == batch_size: - (batch_feature, batch_label) = batch_to_ndarray( + (batch_feature, batch_label, name_lst) = batch_to_ndarray( batch_samples, lod) - batch_queue.put((batch_feature, batch_label, lod)) + batch_queue.put((batch_feature, batch_label, lod, name_lst)) batch_samples = [] lod = [0] if len(batch_samples) >= minimum_batch_size: - (batch_feature, batch_label) = batch_to_ndarray(batch_samples, - lod) - batch_queue.put((batch_feature, batch_label, lod)) + (batch_feature, batch_label, name_lst) = batch_to_ndarray( + batch_samples, lod) + batch_queue.put((batch_feature, batch_label, lod, name_lst)) batch_queue.put(EpochEndSignal()) diff --git a/fluid/DeepASR/data_utils/augmentor/tests/test_data_trans.py b/fluid/DeepASR/data_utils/augmentor/tests/test_data_trans.py index 157ab02eee0093fe5d683e642b3d18d842cb4e19..9f76a9f8590d5f148398c4ffaff77dc95421df83 100644 --- a/fluid/DeepASR/data_utils/augmentor/tests/test_data_trans.py +++ b/fluid/DeepASR/data_utils/augmentor/tests/test_data_trans.py @@ -22,7 +22,7 @@ class TestTransMeanVarianceNorm(unittest.TestCase): feature = np.zeros((2, 120), dtype="float32") feature.fill(1) trans = trans_mean_variance_norm.TransMeanVarianceNorm(self._file_path) - (feature1, label1) = trans.perform_trans((feature, None)) + (feature1, label1, name) = trans.perform_trans((feature, None, None)) (mean, var) = trans.get_mean_var() feature_flat1 = feature1.flatten() feature_flat = feature.flatten() @@ -70,7 +70,7 @@ class TestTransAddDelta(unittest.TestCase): feature[2, 0:40].fill(3) feature[3, 0:40].fill(4) trans = trans_add_delta.TransAddDelta() - (feature, label) = trans.perform_trans((feature, None)) + (feature, label, name) = trans.perform_trans((feature, None, None)) self.assertAlmostEqual(feature.shape[0], 4) self.assertAlmostEqual(feature.shape[1], 120) self.assertAlmostEqual(1.0, feature[0][0]) @@ -93,7 +93,7 @@ class TestTransSplict(unittest.TestCase): feature[i, :].fill(i) trans = trans_splice.TransSplice() - (feature, label) = trans.perform_trans((feature, None)) + (feature, label, name) = trans.perform_trans((feature, None, None)) self.assertEqual(feature.shape[1], 110) for i in xrange(8): diff --git a/fluid/DeepASR/data_utils/augmentor/trans_add_delta.py b/fluid/DeepASR/data_utils/augmentor/trans_add_delta.py index dc1a4fa45be38152eba773c35e67d0ad3e4a13cb..aa8062f87c932b76dd8a79db825d07e8be273857 100644 --- a/fluid/DeepASR/data_utils/augmentor/trans_add_delta.py +++ b/fluid/DeepASR/data_utils/augmentor/trans_add_delta.py @@ -32,9 +32,9 @@ class TransAddDelta(object): Args: sample(object,tuple): contain feature numpy and label numpy Returns: - (feature, label) + (feature, label, name) """ - (feature, label) = sample + (feature, label, name) = sample frame_dim = feature.shape[1] d_frame_dim = frame_dim * 3 head_filled = 5 @@ -64,7 +64,7 @@ class TransAddDelta(object): start * d_frame_dim + 2 * frame_dim, frame_dim, nframe, d_frame_dim) mat.shape = tmp_shape - return (mat[head_filled:mat.shape[0] - tail_filled, :], label) + return (mat[head_filled:mat.shape[0] - tail_filled, :], label, name) def _regress(self, data_in, start_in, data_out, start_out, size, n, step): """ regress diff --git a/fluid/DeepASR/data_utils/augmentor/trans_mean_variance_norm.py b/fluid/DeepASR/data_utils/augmentor/trans_mean_variance_norm.py index 5b541d426c61364639f7a9d9f50bd51a2c06efa5..9f91b726ea2bcd432340cd06a3cb9006cd5f83f4 100644 --- a/fluid/DeepASR/data_utils/augmentor/trans_mean_variance_norm.py +++ b/fluid/DeepASR/data_utils/augmentor/trans_mean_variance_norm.py @@ -53,9 +53,9 @@ class TransMeanVarianceNorm(object): Args: sample(object):input sample, contain feature numpy and label numpy Returns: - (feature, label) + (feature, label, name) """ - (feature, label) = sample + (feature, label, name) = sample shape = feature.shape assert len(shape) == 2 nfeature_len = shape[0] * shape[1] @@ -68,4 +68,4 @@ class TransMeanVarianceNorm(object): feature[ncur_idx:ncur_idx + self._nLen] = block ncur_idx += self._nLen feature = feature.reshape(shape) - return (feature, label) + return (feature, label, name) diff --git a/fluid/DeepASR/data_utils/augmentor/trans_splice.py b/fluid/DeepASR/data_utils/augmentor/trans_splice.py index 94f5258de316045d41999b26c6963f8487e9c55a..1fab3d6b442c1613f18d16fd0b0ee89464dbeb2c 100644 --- a/fluid/DeepASR/data_utils/augmentor/trans_splice.py +++ b/fluid/DeepASR/data_utils/augmentor/trans_splice.py @@ -30,9 +30,9 @@ class TransSplice(object): Args: sample(object): input sample(feature, label) Return: - (feature, label) + (feature, label, name) """ - (feature, label) = sample + (feature, label, name) = sample nframe_num = feature.shape[0] nframe_dim = feature.shape[1] nnew_frame_dim = nframe_dim * ( @@ -61,4 +61,4 @@ class TransSplice(object): np.copyto(ret[i * nnew_frame_dim:(i + 1) * nnew_frame_dim], mat[i * nframe_dim:i * nframe_dim + nnew_frame_dim]) ret = ret.reshape((nframe_num, nnew_frame_dim)) - return (ret, label) + return (ret, label, name) diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py index 917807987f3a5fa79254f84c99309ef7bc1b4f1a..3908a550cdcf095057ea6ab0b89e07dcecda51f9 100644 --- a/fluid/DeepASR/train.py +++ b/fluid/DeepASR/train.py @@ -210,6 +210,7 @@ def train(args): # train data reader train_data_reader = reader.AsyncDataReader(args.train_feature_lst, args.train_label_lst, -1) + train_data_reader.set_transformers(ltrans) # train for pass_id in xrange(args.pass_num): @@ -218,7 +219,7 @@ def train(args): train_data_reader.batch_iterator(args.batch_size, args.minimum_batch_size)): # load_data - (features, labels, lod) = batch_data + (features, labels, lod, name_lst) = batch_data feature_t.set(features, place) feature_t.set_lod([lod]) label_t.set(labels, place) diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py index 71e4314953383b8f89b40fdfd8cc4274f954fed1..8bfdf6461bdbfae92afe36520b3b056dddb4836c 100644 --- a/fluid/neural_machine_translation/transformer/config.py +++ b/fluid/neural_machine_translation/transformer/config.py @@ -92,7 +92,9 @@ pos_enc_param_names = ( encoder_input_data_names = ( "src_word", "src_pos", - "src_slf_attn_bias", ) + "src_slf_attn_bias", + "src_slf_attn_pre_softmax_shape", + "src_slf_attn_post_softmax_shape", ) # Names of all data layers in decoder listed in order. decoder_input_data_names = ( @@ -100,6 +102,10 @@ decoder_input_data_names = ( "trg_pos", "trg_slf_attn_bias", "trg_src_attn_bias", + "trg_slf_attn_pre_softmax_shape", + "trg_slf_attn_post_softmax_shape", + "trg_src_attn_pre_softmax_shape", + "trg_src_attn_post_softmax_shape", "enc_output", ) # Names of label related data layers listed in order. diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py index e4dee220cedf856633ee626b762804e49a10cfe8..b8b002dc0757481137d452400f276af4342a8af9 100644 --- a/fluid/neural_machine_translation/transformer/infer.py +++ b/fluid/neural_machine_translation/transformer/infer.py @@ -27,7 +27,14 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, is_target=False, return_pos=True, return_attn_bias=True, - return_max_len=True) + return_max_len=False) + # Append the shape inputs to reshape before and after softmax in encoder + # self attention. + enc_in_data = enc_in_data + [ + np.array( + [-1, enc_in_data[2].shape[-1]], dtype="int32"), np.array( + enc_in_data[2].shape, dtype="int32") + ] enc_output = exe.run(encoder, feed=dict(zip(enc_in_names, enc_in_data)), fetch_list=enc_out_names)[0] @@ -35,8 +42,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, # Beam Search. # To store the beam info. scores = np.zeros((batch_size, beam_size), dtype="float32") - prev_branchs = [[]] * batch_size - next_ids = [[]] * batch_size + prev_branchs = [[] for i in range(batch_size)] + next_ids = [[] for i in range(batch_size)] # Use beam_map to map the instance idx in batch to beam idx, since the # size of feeded batch is changing. beam_map = range(batch_size) @@ -64,8 +71,8 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, trg_words = np.array( [[bos_idx]] * batch_size * beam_size, dtype="int64") trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64") - src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[ - -1], enc_in_data[-2], 1 + src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[2].shape[ + -1], enc_in_data[2], 1 # This is used to remove attention on subsequent words. trg_slf_attn_bias = np.ones((batch_size * beam_size, trg_max_len, trg_max_len)) @@ -77,15 +84,33 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, trg_src_attn_bias = np.tile( src_slf_attn_bias[:, :, ::src_max_length, :], [beam_size, 1, trg_max_len, 1]) + # Append the shape inputs to reshape before and after softmax in + # decoder self attention. + trg_slf_attn_pre_softmax_shape = np.array( + [-1, trg_slf_attn_bias.shape[-1]], dtype="int32") + trg_slf_attn_post_softmax_shape = np.array( + trg_slf_attn_bias.shape, dtype="int32") + # Append the shape inputs to reshape before and after softmax in + # encoder-decoder attention. + trg_src_attn_pre_softmax_shape = np.array( + [-1, trg_src_attn_bias.shape[-1]], dtype="int32") + trg_src_attn_post_softmax_shape = np.array( + trg_src_attn_bias.shape, dtype="int32") enc_output = np.tile(enc_output, [beam_size, 1, 1]) - return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output + return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \ + trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \ + trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \ + enc_output def update_dec_in_data(dec_in_data, next_ids, active_beams): """ Update the input data of decoder mainly by slicing from the previous input data and dropping the finished instance beams. """ - trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data + trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \ + trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \ + trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \ + enc_output = dec_in_data trg_cur_len = len(next_ids[0]) + 1 # include the trg_words = np.array( [ @@ -112,8 +137,23 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names, trg_src_attn_bias = np.tile(trg_src_attn_bias[ active_beams_indice, :, ::trg_src_attn_bias.shape[2], :], [1, 1, trg_cur_len, 1]) + # Append the shape inputs to reshape before and after softmax in + # decoder self attention. + trg_slf_attn_pre_softmax_shape = np.array( + [-1, trg_slf_attn_bias.shape[-1]], dtype="int32") + trg_slf_attn_post_softmax_shape = np.array( + trg_slf_attn_bias.shape, dtype="int32") + # Append the shape inputs to reshape before and after softmax in + # encoder-decoder attention. + trg_src_attn_pre_softmax_shape = np.array( + [-1, trg_src_attn_bias.shape[-1]], dtype="int32") + trg_src_attn_post_softmax_shape = np.array( + trg_src_attn_bias.shape, dtype="int32") enc_output = enc_output[active_beams_indice, :, :] - return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output + return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \ + trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, \ + trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, \ + enc_output dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output) diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py index ba5ba4470759da5fd2c6dd3b3d61b88c3468bd27..ffc07e91421dbaf3ed6e370f04ec6f1d7439fcf8 100644 --- a/fluid/neural_machine_translation/transformer/model.py +++ b/fluid/neural_machine_translation/transformer/model.py @@ -32,7 +32,9 @@ def multi_head_attention(queries, d_value, d_model, n_head=1, - dropout_rate=0.): + dropout_rate=0., + pre_softmax_shape=None, + post_softmax_shape=None): """ Multi-Head Attention. Note that attn_bias is added to the logit before computing softmax activiation to mask certain selected positions so that @@ -111,26 +113,16 @@ def multi_head_attention(queries, """ Scaled Dot-Product Attention """ - - # FIXME(guosheng): Optimize the shape in reshape_op or softmax_op. - - # The current implementation of softmax_op only supports 2D tensor, - # consequently it cannot be directly used here. - # If to use the reshape_op, Besides, the shape of product inferred in - # compile-time is not the actual shape in run-time. It cann't be used - # to set the attribute of reshape_op. - # So, here define the softmax for temporary solution. - - def __softmax(x, eps=1e-9): - exp_out = layers.exp(x=x) - sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False) - return layers.elementwise_div(x=exp_out, y=sum_out, axis=0) - scaled_q = layers.scale(x=q, scale=d_model**-0.5) product = layers.matmul(x=scaled_q, y=k, transpose_y=True) - weights = __softmax( - layers.elementwise_add( - x=product, y=attn_bias) if attn_bias else product) + weights = layers.reshape( + x=layers.elementwise_add( + x=product, y=attn_bias) if attn_bias else product, + shape=[-1, product.shape[-1]], + actual_shape=pre_softmax_shape, + act="softmax") + weights = layers.reshape( + x=weights, shape=product.shape, actual_shape=post_softmax_shape) if dropout_rate: weights = layers.dropout( weights, dropout_prob=dropout_rate, is_test=False) @@ -177,7 +169,7 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid): return out -def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.): +def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.): """ Add residual connection, layer normalization and droput to the out tensor optionally according to the value of process_cmd. @@ -195,8 +187,9 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.): param_attr=fluid.initializer.Constant(1.), bias_attr=fluid.initializer.Constant(0.)) elif cmd == "d": # add dropout - if dropout: - out = layers.dropout(out, dropout_prob=dropout, is_test=False) + if dropout_rate: + out = layers.dropout( + out, dropout_prob=dropout_rate, is_test=False) return out @@ -210,7 +203,7 @@ def prepare_encoder(src_word, src_emb_dim, src_pad_idx, src_max_len, - dropout=0., + dropout_rate=0., pos_pad_idx=0, pos_enc_param_name=None): """Add word embeddings and position encodings. @@ -235,8 +228,8 @@ def prepare_encoder(src_word, # FIXME(guosheng): Decouple the program desc with batch_size. enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim]) return layers.dropout( - enc_input, dropout_prob=dropout, - is_test=False) if dropout else enc_input + enc_input, dropout_prob=dropout_rate, + is_test=False) if dropout_rate else enc_input prepare_encoder = partial( @@ -252,7 +245,9 @@ def encoder_layer(enc_input, d_value, d_model, d_inner_hid, - dropout_rate=0.): + dropout_rate=0., + pre_softmax_shape=None, + post_softmax_shape=None): """The encoder layers that can be stacked to form a deep encoder. This module consits of a multi-head (self) attention followed by @@ -260,9 +255,9 @@ def encoder_layer(enc_input, with the post_process_layer to add residual connection, layer normalization and droput. """ - attn_output = multi_head_attention(enc_input, enc_input, enc_input, - attn_bias, d_key, d_value, d_model, - n_head, dropout_rate) + attn_output = multi_head_attention( + enc_input, enc_input, enc_input, attn_bias, d_key, d_value, d_model, + n_head, dropout_rate, pre_softmax_shape, post_softmax_shape) attn_output = post_process_layer(enc_input, attn_output, "dan", dropout_rate) ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model) @@ -277,7 +272,9 @@ def encoder(enc_input, d_value, d_model, d_inner_hid, - dropout_rate=0.): + dropout_rate=0., + pre_softmax_shape=None, + post_softmax_shape=None): """ The encoder is composed of a stack of identical layers returned by calling encoder_layer. @@ -291,7 +288,9 @@ def encoder(enc_input, d_value, d_model, d_inner_hid, - dropout_rate, ) + dropout_rate, + pre_softmax_shape, + post_softmax_shape, ) enc_input = enc_output return enc_output @@ -305,7 +304,11 @@ def decoder_layer(dec_input, d_value, d_model, d_inner_hid, - dropout_rate=0.): + dropout_rate=0., + slf_attn_pre_softmax_shape=None, + slf_attn_post_softmax_shape=None, + src_attn_pre_softmax_shape=None, + src_attn_post_softmax_shape=None): """ The layer to be stacked in decoder part. The structure of this module is similar to that in the encoder part except @@ -320,7 +323,9 @@ def decoder_layer(dec_input, d_value, d_model, n_head, - dropout_rate, ) + dropout_rate, + slf_attn_pre_softmax_shape, + slf_attn_post_softmax_shape, ) slf_attn_output = post_process_layer( dec_input, slf_attn_output, @@ -335,7 +340,9 @@ def decoder_layer(dec_input, d_value, d_model, n_head, - dropout_rate, ) + dropout_rate, + src_attn_pre_softmax_shape, + src_attn_post_softmax_shape, ) enc_attn_output = post_process_layer( slf_attn_output, enc_attn_output, @@ -363,7 +370,11 @@ def decoder(dec_input, d_value, d_model, d_inner_hid, - dropout_rate=0.): + dropout_rate=0., + slf_attn_pre_softmax_shape=None, + slf_attn_post_softmax_shape=None, + src_attn_pre_softmax_shape=None, + src_attn_post_softmax_shape=None): """ The decoder is composed of a stack of identical decoder_layer layers. """ @@ -378,7 +389,11 @@ def decoder(dec_input, d_value, d_model, d_inner_hid, - dropout_rate, ) + dropout_rate, + slf_attn_pre_softmax_shape, + slf_attn_post_softmax_shape, + src_attn_pre_softmax_shape, + src_attn_post_softmax_shape, ) dec_input = dec_output return dec_output @@ -391,7 +406,9 @@ def make_inputs(input_data_names, is_pos, slf_attn_bias_flag, src_attn_bias_flag, - enc_output_flag=False): + enc_output_flag=False, + slf_attn_shape_flag=True, + src_attn_shape_flag=True): """ Define the input data layers for the transformer model. """ @@ -429,6 +446,32 @@ def make_inputs(input_data_names, dtype="float32", append_batch_size=False) input_layers += [src_attn_bias] + if slf_attn_shape_flag: + slf_attn_pre_softmax_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [slf_attn_pre_softmax_shape] + slf_attn_post_softmax_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [slf_attn_post_softmax_shape] + if src_attn_shape_flag: + src_attn_pre_softmax_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [src_attn_pre_softmax_shape] + src_attn_post_softmax_shape = layers.data( + name=input_data_names[len(input_layers)], + shape=[3], + dtype="int32", + append_batch_size=False) + input_layers += [src_attn_post_softmax_shape] if enc_output_flag: enc_output = layers.data( name=input_data_names[len(input_layers)], @@ -436,6 +479,7 @@ def make_inputs(input_data_names, dtype="float32", append_batch_size=False) input_layers += [enc_output] + return input_layers @@ -453,8 +497,18 @@ def transformer( src_pad_idx, trg_pad_idx, pos_pad_idx, ): - enc_input_layers = make_inputs(encoder_input_data_names, n_head, d_model, - batch_size, max_length, True, True, False) + enc_input_layers = make_inputs( + encoder_input_data_names, + n_head, + d_model, + batch_size, + max_length, + is_pos=True, + slf_attn_bias_flag=True, + src_attn_bias_flag=False, + enc_output_flag=False, + slf_attn_shape_flag=True, + src_attn_shape_flag=False) enc_output = wrap_encoder( src_vocab_size, @@ -470,8 +524,18 @@ def transformer( pos_pad_idx, enc_input_layers, ) - dec_input_layers = make_inputs(decoder_input_data_names, n_head, d_model, - batch_size, max_length, True, True, True) + dec_input_layers = make_inputs( + decoder_input_data_names, + n_head, + d_model, + batch_size, + max_length, + is_pos=True, + slf_attn_bias_flag=True, + src_attn_bias_flag=True, + enc_output_flag=False, + slf_attn_shape_flag=True, + src_attn_shape_flag=True) predict = wrap_decoder( trg_vocab_size, @@ -490,9 +554,19 @@ def transformer( # Padding index do not contribute to the total loss. The weights is used to # cancel padding index in calculating the loss. - gold, weights = make_inputs(label_data_names, n_head, d_model, batch_size, - max_length, False, False, False) - cost = layers.cross_entropy(input=predict, label=gold) + gold, weights = make_inputs( + label_data_names, + n_head, + d_model, + batch_size, + max_length, + is_pos=False, + slf_attn_bias_flag=False, + src_attn_bias_flag=False, + enc_output_flag=False, + slf_attn_shape_flag=False, + src_attn_shape_flag=False) + cost = layers.softmax_with_cross_entropy(logits=predict, label=gold) weighted_cost = cost * weights return layers.reduce_sum(weighted_cost), predict @@ -514,11 +588,22 @@ def wrap_encoder(src_vocab_size, """ if enc_input_layers is None: # This is used to implement independent encoder program in inference. - src_word, src_pos, src_slf_attn_bias = make_inputs( - encoder_input_data_names, n_head, d_model, batch_size, max_length, - True, True, False) + src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \ + slf_attn_post_softmax_shape = make_inputs( + encoder_input_data_names, + n_head, + d_model, + batch_size, + max_length, + is_pos=True, + slf_attn_bias_flag=True, + src_attn_bias_flag=False, + enc_output_flag=False, + slf_attn_shape_flag=True, + src_attn_shape_flag=False) else: - src_word, src_pos, src_slf_attn_bias = enc_input_layers + src_word, src_pos, src_slf_attn_bias, slf_attn_pre_softmax_shape, \ + slf_attn_post_softmax_shape = enc_input_layers enc_input = prepare_encoder( src_word, src_pos, @@ -536,7 +621,9 @@ def wrap_encoder(src_vocab_size, d_value, d_model, d_inner_hid, - dropout_rate, ) + dropout_rate, + slf_attn_pre_softmax_shape, + slf_attn_post_softmax_shape, ) return enc_output @@ -558,11 +645,26 @@ def wrap_decoder(trg_vocab_size, """ if dec_input_layers is None: # This is used to implement independent decoder program in inference. - trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = make_inputs( - decoder_input_data_names, n_head, d_model, batch_size, max_length, - True, True, True, True) + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \ + slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \ + src_attn_pre_softmax_shape, src_attn_post_softmax_shape, \ + enc_output = make_inputs( + decoder_input_data_names, + n_head, + d_model, + batch_size, + max_length, + is_pos=True, + slf_attn_bias_flag=True, + src_attn_bias_flag=True, + enc_output_flag=True, + slf_attn_shape_flag=True, + src_attn_shape_flag=True) else: - trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_input_layers + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, \ + slf_attn_pre_softmax_shape, slf_attn_post_softmax_shape, \ + src_attn_pre_softmax_shape, src_attn_post_softmax_shape = \ + dec_input_layers dec_input = prepare_decoder( trg_word, @@ -583,13 +685,17 @@ def wrap_decoder(trg_vocab_size, d_value, d_model, d_inner_hid, - dropout_rate, ) - + dropout_rate, + slf_attn_pre_softmax_shape, + slf_attn_post_softmax_shape, + src_attn_pre_softmax_shape, + src_attn_post_softmax_shape, ) + # Return logits for training and probs for inference. predict = layers.reshape( x=layers.fc(input=dec_output, size=trg_vocab_size, bias_attr=False, num_flatten_dims=2), shape=[-1, trg_vocab_size], - act="softmax") + act="softmax" if dec_input_layers is None else None) return predict diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index 65de8ef7fa8421bd72175175f1cf421a4237ddd5..13e4fe7a4aa787f6e59ceb15d40dbd1f1477c86c 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -66,13 +66,29 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], [1, 1, trg_max_len, 1]).astype("float32") + src_slf_attn_pre_softmax_shape = np.array( + [-1, src_slf_attn_bias.shape[-1]], dtype="int32") + src_slf_attn_post_softmax_shape = np.array( + src_slf_attn_bias.shape, dtype="int32") + trg_slf_attn_pre_softmax_shape = np.array( + [-1, trg_slf_attn_bias.shape[-1]], dtype="int32") + trg_slf_attn_post_softmax_shape = np.array( + trg_slf_attn_bias.shape, dtype="int32") + trg_src_attn_pre_softmax_shape = np.array( + [-1, trg_src_attn_bias.shape[-1]], dtype="int32") + trg_src_attn_post_softmax_shape = np.array( + trg_src_attn_bias.shape, dtype="int32") lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, n_head, False, False, False, False) lbl_weight = (lbl_word != trg_pad_idx).astype("float32").reshape([-1, 1]) input_dict = dict( zip(input_data_names, [ - src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, - trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight + src_word, src_pos, src_slf_attn_bias, + src_slf_attn_pre_softmax_shape, src_slf_attn_post_softmax_shape, + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, + trg_slf_attn_pre_softmax_shape, trg_slf_attn_post_softmax_shape, + trg_src_attn_pre_softmax_shape, trg_src_attn_post_softmax_shape, + lbl_word, lbl_weight ])) return input_dict