From 866d3e03ee60eb66ee9b6911942d520af94baf23 Mon Sep 17 00:00:00 2001 From: guosheng Date: Sat, 1 Feb 2020 01:14:43 +0800 Subject: [PATCH] Add validation for dygraph Transformer. Add cross-attention cache for dygraph Transformer. Add greedy search for dygraph Transformer. --- dygraph/transformer/README.md | 9 +- dygraph/transformer/config.py | 206 -------------------------- dygraph/transformer/model.py | 212 +++++++++++++++++++++------ dygraph/transformer/train.py | 60 +++++++- dygraph/transformer/transformer.yaml | 2 + 5 files changed, 229 insertions(+), 260 deletions(-) delete mode 100644 dygraph/transformer/config.py diff --git a/dygraph/transformer/README.md b/dygraph/transformer/README.md index 6776e618..0b03ca1f 100644 --- a/dygraph/transformer/README.md +++ b/dygraph/transformer/README.md @@ -76,6 +76,7 @@ python -u train.py \ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ --special_token '' '' '' \ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ + --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --batch_size 4096 ``` @@ -91,6 +92,7 @@ python -u train.py \ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ --special_token '' '' '' \ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ + --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --batch_size 4096 \ --n_head 16 \ --d_model 1024 \ @@ -121,10 +123,11 @@ Paddle动态图支持多进程多卡进行模型训练,启动训练的方式 ```sh python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \ --epoch 30 \ - --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \ - --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \ + --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ + --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ --special_token '' '' '' \ - --training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ + --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ + --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --batch_size 4096 \ --print_step 100 \ --use_cuda True \ diff --git a/dygraph/transformer/config.py b/dygraph/transformer/config.py deleted file mode 100644 index b6e1b2bb..00000000 --- a/dygraph/transformer/config.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class TrainTaskConfig(object): - """ - TrainTaskConfig - """ - # the epoch number to train. - pass_num = 20 - # the number of sequences contained in a mini-batch. - # deprecated, set batch_size in args. - batch_size = 32 - # the hyper parameters for Adam optimizer. - # This static learning_rate will be multiplied to the LearningRateScheduler - # derived learning rate the to get the final learning rate. - learning_rate = 2.0 - beta1 = 0.9 - beta2 = 0.997 - eps = 1e-9 - # the parameters for learning rate scheduling. - warmup_steps = 8000 - # the weight used to mix up the ground-truth distribution and the fixed - # uniform distribution in label smoothing when training. - # Set this as zero if label smoothing is not wanted. - label_smooth_eps = 0.1 - - -class InferTaskConfig(object): - # the number of examples in one run for sequence generation. - batch_size = 4 - # the parameters for beam search. - beam_size = 4 - alpha=0.6 - # max decoded length, should be less than ModelHyperParams.max_length - max_out_len = 30 - - - -class ModelHyperParams(object): - """ - ModelHyperParams - """ - # These following five vocabularies related configurations will be set - # automatically according to the passed vocabulary path and special tokens. - # size of source word dictionary. - src_vocab_size = 10000 - # size of target word dictionay - trg_vocab_size = 10000 - # index for token - bos_idx = 0 - # index for token - eos_idx = 1 - # index for token - unk_idx = 2 - - # max length of sequences deciding the size of position encoding table. - max_length = 50 - # the dimension for word embeddings, which is also the last dimension of - # the input and output of multi-head attention, position-wise feed-forward - # networks, encoder and decoder. - d_model = 512 - # size of the hidden layer in position-wise feed-forward networks. - d_inner_hid = 2048 - # the dimension that keys are projected to for dot-product attention. - d_key = 64 - # the dimension that values are projected to for dot-product attention. - d_value = 64 - # number of head used in multi-head attention. - n_head = 8 - # number of sub-layers to be stacked in the encoder and decoder. - n_layer = 6 - # dropout rates of different modules. - prepostprocess_dropout = 0.1 - attention_dropout = 0.1 - relu_dropout = 0.1 - # to process before each sub-layer - preprocess_cmd = "n" # layer normalization - # to process after each sub-layer - postprocess_cmd = "da" # dropout + residual connection - # the flag indicating whether to share embedding and softmax weights. - # vocabularies in source and target should be same for weight sharing. - weight_sharing = False - - -# The placeholder for batch_size in compile time. Must be -1 currently to be -# consistent with some ops' infer-shape output in compile time, such as the -# sequence_expand op used in beamsearch decoder. -batch_size = -1 -# The placeholder for squence length in compile time. -seq_len = ModelHyperParams.max_length -# Here list the data shapes and data types of all inputs. -# The shapes here act as placeholder and are set to pass the infer-shape in -# compile time. -input_descs = { - # The actual data shape of src_word is: - # [batch_size, max_src_len_in_batch, 1] - "src_word": [(batch_size, seq_len, 1), "int64", 2], - # The actual data shape of src_pos is: - # [batch_size, max_src_len_in_batch, 1] - "src_pos": [(batch_size, seq_len, 1), "int64"], - # This input is used to remove attention weights on paddings in the - # encoder. - # The actual data shape of src_slf_attn_bias is: - # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch] - "src_slf_attn_bias": - [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"], - # The actual data shape of trg_word is: - # [batch_size, max_trg_len_in_batch, 1] - "trg_word": [(batch_size, seq_len, 1), "int64", - 2], # lod_level is only used in fast decoder. - # The actual data shape of trg_pos is: - # [batch_size, max_trg_len_in_batch, 1] - "trg_pos": [(batch_size, seq_len, 1), "int64"], - # This input is used to remove attention weights on paddings and - # subsequent words in the decoder. - # The actual data shape of trg_slf_attn_bias is: - # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch] - "trg_slf_attn_bias": - [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"], - # This input is used to remove attention weights on paddings of the source - # input in the encoder-decoder attention. - # The actual data shape of trg_src_attn_bias is: - # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch] - "trg_src_attn_bias": - [(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"], - # This input is used in independent decoder program for inference. - # The actual data shape of enc_output is: - # [batch_size, max_src_len_in_batch, d_model] - "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"], - # The actual data shape of label_word is: - # [batch_size * max_trg_len_in_batch, 1] - "lbl_word": [(batch_size * seq_len, 1), "int64"], - # This input is used to mask out the loss of paddding tokens. - # The actual data shape of label_weight is: - # [batch_size * max_trg_len_in_batch, 1] - "lbl_weight": [(batch_size * seq_len, 1), "float32"], - # This input is used in beam-search decoder. - "init_score": [(batch_size, 1), "float32", 2], - # This input is used in beam-search decoder for the first gather - # (cell states updation) - "init_idx": [(batch_size, ), "int32"], -} - -# Names of word embedding table which might be reused for weight sharing. -word_emb_param_names = ( - "src_word_emb_table", - "trg_word_emb_table", -) -# Names of position encoding table which will be initialized externally. -pos_enc_param_names = ( - "src_pos_enc_table", - "trg_pos_enc_table", -) -# separated inputs for different usages. -encoder_data_input_fields = ( - "src_word", - "src_pos", - "src_slf_attn_bias", -) -decoder_data_input_fields = ( - "trg_word", - "trg_pos", - "trg_slf_attn_bias", - "trg_src_attn_bias", - "enc_output", -) -label_data_input_fields = ( - "lbl_word", - "lbl_weight", -) -# In fast decoder, trg_pos (only containing the current time step) is generated -# by ops and trg_slf_attn_bias is not needed. -fast_decoder_data_input_fields = ( - "trg_word", - # "init_score", - # "init_idx", - "trg_src_attn_bias", -) - - -def merge_cfg_from_list(cfg_list, g_cfgs): - """ - Set the above global configurations using the cfg_list. - """ - assert len(cfg_list) % 2 == 0 - for key, value in zip(cfg_list[0::2], cfg_list[1::2]): - for g_cfg in g_cfgs: - if hasattr(g_cfg, key): - try: - value = eval(value) - except Exception: # for file path - pass - setattr(g_cfg, key, value) - break diff --git a/dygraph/transformer/model.py b/dygraph/transformer/model.py index 326f96cc..25614f15 100644 --- a/dygraph/transformer/model.py +++ b/dygraph/transformer/model.py @@ -18,7 +18,6 @@ import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers -from paddle.fluid.layers.utils import map_structure from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay @@ -128,28 +127,45 @@ class MultiHeadAttention(Layer): output_dim=d_model, bias_attr=False) - def forward(self, queries, keys, values, attn_bias, cache=None): - # compute q ,k ,v - keys = queries if keys is None else keys - values = keys if values is None else values + def _prepare_qkv(self, queries, keys, values, cache=None): + if keys is None: # self-attention + keys, values = queries, queries + static_kv = False + else: # cross-attention + static_kv = True q = self.q_fc(queries) - k = self.k_fc(keys) - v = self.v_fc(values) - - # split head q = layers.reshape(x=q, shape=[0, 0, self.n_head, self.d_key]) q = layers.transpose(x=q, perm=[0, 2, 1, 3]) - k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key]) - k = layers.transpose(x=k, perm=[0, 2, 1, 3]) - v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value]) - v = layers.transpose(x=v, perm=[0, 2, 1, 3]) + + if cache is not None and static_kv and cache.has_key("static_k"): + # for encoder-decoder attention in inference and has cached + k = cache["static_k"] + v = cache["static_v"] + else: + k = self.k_fc(keys) + v = self.v_fc(values) + k = layers.reshape(x=k, shape=[0, 0, self.n_head, self.d_key]) + k = layers.transpose(x=k, perm=[0, 2, 1, 3]) + v = layers.reshape(x=v, shape=[0, 0, self.n_head, self.d_value]) + v = layers.transpose(x=v, perm=[0, 2, 1, 3]) if cache is not None: - cache_k, cache_v = cache["k"], cache["v"] - k = layers.concat([cache_k, k], axis=2) - v = layers.concat([cache_v, v], axis=2) - cache["k"], cache["v"] = k, v + if static_kv and not cache.has_key("static_k"): + # for encoder-decoder attention in inference and has not cached + cache["static_k"], cache["static_v"] = k, v + elif not static_kv: + # for decoder self-attention in inference + cache_k, cache_v = cache["k"], cache["v"] + k = layers.concat([cache_k, k], axis=2) + v = layers.concat([cache_v, v], axis=2) + cache["k"], cache["v"] = k, v + + return q, k, v + + def forward(self, queries, keys, values, attn_bias, cache=None): + # compute q ,k ,v + q, k, v = self._prepare_qkv(queries, keys, values, cache) # scale dot product attention product = layers.matmul(x=q, @@ -381,7 +397,7 @@ class DecoderLayer(Layer): cross_attn_output = self.cross_attn( self.preprocesser2(self_attn_output), enc_output, enc_output, - cross_attn_bias) + cross_attn_bias, cache) cross_attn_output = self.postprocesser2(cross_attn_output, self_attn_output) @@ -801,15 +817,45 @@ class Transformer(Layer): return finished_seq, finished_scores def beam_search(self, - src_word, - src_pos, - src_slf_attn_bias, - trg_word, - trg_src_attn_bias, - bos_id=0, - eos_id=1, - beam_size=4, - max_len=256): + src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=0, + eos_id=1, + beam_size=4, + max_len=256): + if beam_size == 1: + return self._greedy_search(src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=bos_id, + eos_id=eos_id, + max_len=max_len) + else: + return self._beam_search(src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=bos_id, + eos_id=eos_id, + beam_size=beam_size, + max_len=max_len) + + def _beam_search(self, + src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=0, + eos_id=1, + beam_size=4, + max_len=256): def expand_to_beam_size(tensor, beam_size): tensor = layers.reshape(tensor, [tensor.shape[0], 1] + tensor.shape[1:]) @@ -822,22 +868,30 @@ class Transformer(Layer): tensor.shape[2:]) def split_batch_beams(tensor): - return fluid.layers.reshape(tensor, - shape=[-1, beam_size] + - list(tensor.shape[1:])) + return layers.reshape(tensor, + shape=[-1, beam_size] + + list(tensor.shape[1:])) def mask_probs(probs, finished, noend_mask_tensor): # TODO: use where_op finished = layers.cast(finished, dtype=probs.dtype) - probs = layers.elementwise_mul( - layers.expand(layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]), - noend_mask_tensor, axis=-1) - layers.elementwise_mul(probs, (finished - 1), axis=0) + probs = layers.elementwise_mul(layers.expand( + layers.unsqueeze(finished, [2]), [1, 1, self.trg_vocab_size]), + noend_mask_tensor, + axis=-1) - layers.elementwise_mul( + probs, (finished - 1), axis=0) return probs def gather(x, indices, batch_pos): - topk_coordinates = fluid.layers.stack([batch_pos, indices], axis=2) + topk_coordinates = layers.stack([batch_pos, indices], axis=2) return layers.gather_nd(x, topk_coordinates) + def update_states(func, caches): + for cache in caches: # no need to update static_kv + cache["k"] = func(cache["k"]) + cache["v"] = func(cache["v"]) + return caches + # run encoder enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias) @@ -893,30 +947,29 @@ class Transformer(Layer): trg_pos = layers.fill_constant(shape=trg_word.shape, dtype="int64", value=i) - caches = map_structure( # can not be reshaped since the 0 size + caches = update_states( # can not be reshaped since the 0 size lambda x: x if i == 0 else merge_batch_beams(x), caches) logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias, enc_output, caches) - caches = map_structure(split_batch_beams, caches) + caches = update_states(split_batch_beams, caches) step_log_probs = split_batch_beams( - fluid.layers.log(fluid.layers.softmax(logits))) + layers.log(layers.softmax(logits))) step_log_probs = mask_probs(step_log_probs, finished, noend_mask_tensor) log_probs = layers.elementwise_add(x=step_log_probs, - y=log_probs, - axis=0) + y=log_probs, + axis=0) log_probs = layers.reshape(log_probs, [-1, beam_size * self.trg_vocab_size]) scores = log_probs - topk_scores, topk_indices = fluid.layers.topk(input=scores, - k=beam_size) - beam_indices = fluid.layers.elementwise_floordiv( + topk_scores, topk_indices = layers.topk(input=scores, k=beam_size) + beam_indices = layers.elementwise_floordiv( topk_indices, vocab_size_tensor) - token_indices = fluid.layers.elementwise_mod( + token_indices = layers.elementwise_mod( topk_indices, vocab_size_tensor) # update states - caches = map_structure(lambda x: gather(x, beam_indices, batch_pos), + caches = update_states(lambda x: gather(x, beam_indices, batch_pos), caches) log_probs = gather(log_probs, topk_indices, batch_pos) finished = gather(finished, beam_indices, batch_pos) @@ -936,4 +989,73 @@ class Transformer(Layer): layers.gather_tree(predict_ids, parent_ids), [1, 2, 0]) finished_scores = topk_scores - return finished_seq, finished_scores \ No newline at end of file + return finished_seq, finished_scores + + def _greedy_search(self, + src_word, + src_pos, + src_slf_attn_bias, + trg_word, + trg_src_attn_bias, + bos_id=0, + eos_id=1, + max_len=256): + # run encoder + enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias) + + # constant number + batch_size = enc_output.shape[0] + max_len = (enc_output.shape[1] + 20) if max_len is None else max_len + end_token_tensor = layers.fill_constant(shape=[batch_size, 1], + dtype="int64", + value=eos_id) + + predict_ids = [] + log_probs = layers.fill_constant(shape=[batch_size, 1], + dtype="float32", + value=0) + trg_word = layers.fill_constant(shape=[batch_size, 1], + dtype="int64", + value=bos_id) + + ## init states (caches) for transformer + caches = [{ + "k": + layers.fill_constant( + shape=[batch_size, self.n_head, 0, self.d_key], + dtype=enc_output.dtype, + value=0), + "v": + layers.fill_constant( + shape=[batch_size, self.n_head, 0, self.d_value], + dtype=enc_output.dtype, + value=0), + } for i in range(self.n_layer)] + + for i in range(max_len): + trg_pos = layers.fill_constant(shape=trg_word.shape, + dtype="int64", + value=i) + logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias, + enc_output, caches) + step_log_probs = layers.log(layers.softmax(logits)) + log_probs = layers.elementwise_add(x=step_log_probs, + y=log_probs, + axis=0) + scores = log_probs + topk_scores, topk_indices = layers.topk(input=scores, k=1) + + finished = layers.equal(topk_indices, end_token_tensor) + trg_word = topk_indices + log_probs = topk_scores + + predict_ids.append(topk_indices) + + if layers.reduce_all(finished).numpy(): + break + + predict_ids = layers.stack(predict_ids, axis=0) + finished_seq = layers.transpose(predict_ids, [1, 2, 0]) + finished_scores = topk_scores + + return finished_seq, finished_scores diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py index bbfb2c12..40d675b4 100644 --- a/dygraph/transformer/train.py +++ b/dygraph/transformer/train.py @@ -57,6 +57,25 @@ def do_train(args): max_length=args.max_length, n_head=args.n_head) batch_generator = processor.data_generator(phase="train") + if args.validation_file: + val_processor = reader.DataProcessor( + fpattern=args.validation_file, + src_vocab_fpath=args.src_vocab_fpath, + trg_vocab_fpath=args.trg_vocab_fpath, + token_delimiter=args.token_delimiter, + use_token_batch=args.use_token_batch, + batch_size=args.batch_size, + device_count=trainer_count, + pool_size=args.pool_size, + sort_type=args.sort_type, + shuffle=False, + shuffle_batch=False, + start_mark=args.special_token[0], + end_mark=args.special_token[1], + unk_mark=args.special_token[2], + max_length=args.max_length, + n_head=args.n_head) + val_batch_generator = val_processor.data_generator(phase="train") if trainer_count > 1: # for multi-process gpu training batch_generator = fluid.contrib.reader.distributed_batch_reader( batch_generator) @@ -73,6 +92,9 @@ def do_train(args): # define data loader train_loader = fluid.io.DataLoader.from_generator(capacity=10) train_loader.set_batch_generator(batch_generator, places=place) + if args.validation_file: + val_loader = fluid.io.DataLoader.from_generator(capacity=10) + val_loader.set_batch_generator(val_batch_generator, places=place) # define model transformer = Transformer( @@ -123,6 +145,7 @@ def do_train(args): # train loop for pass_id in range(args.epoch): pass_start_time = time.time() + avg_batch_time = time.time() batch_id = 0 for input_data in train_loader(): (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, @@ -155,7 +178,6 @@ def do_train(args): (step_idx, pass_id, batch_id, total_avg_cost, total_avg_cost - loss_normalizer, np.exp([min(total_avg_cost, 100)]))) - avg_batch_time = time.time() else: logging.info( "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " @@ -164,12 +186,37 @@ def do_train(args): total_avg_cost - loss_normalizer, np.exp([min(total_avg_cost, 100)]), args.print_step / (time.time() - avg_batch_time))) - avg_batch_time = time.time() - if step_idx % args.save_step == 0 and step_idx != 0 and ( - trainer_count == 1 - or fluid.dygraph.parallel.Env().dev_id == 0): - if args.save_model: + + if step_idx % args.save_step == 0 and step_idx != 0: + # validation + if args.validation_file: + transformer.eval() + total_sum_cost = 0 + total_token_num = 0 + for input_data in val_loader(): + (src_word, src_pos, src_slf_attn_bias, trg_word, + trg_pos, trg_slf_attn_bias, trg_src_attn_bias, + lbl_word, lbl_weight) = input_data + logits = transformer(src_word, src_pos, + src_slf_attn_bias, trg_word, + trg_pos, trg_slf_attn_bias, + trg_src_attn_bias) + sum_cost, avg_cost, token_num = criterion( + logits, lbl_word, lbl_weight) + total_sum_cost += sum_cost.numpy() + total_token_num += token_num.numpy() + total_avg_cost = total_sum_cost / total_token_num + logging.info("validation, step_idx: %d, avg loss: %f, " + "normalized loss: %f, ppl: %f" % + (step_idx, total_avg_cost, + total_avg_cost - loss_normalizer, + np.exp([min(total_avg_cost, 100)]))) + transformer.train() + + if args.save_model and ( + trainer_count == 1 + or fluid.dygraph.parallel.Env().dev_id == 0): model_dir = os.path.join(args.save_model, "step_" + str(step_idx)) if not os.path.exists(model_dir): @@ -181,6 +228,7 @@ def do_train(args): optimizer.state_dict(), os.path.join(model_dir, "transformer")) + avg_batch_time = time.time() batch_id += 1 step_idx += 1 diff --git a/dygraph/transformer/transformer.yaml b/dygraph/transformer/transformer.yaml index 76151f20..15d9e783 100644 --- a/dygraph/transformer/transformer.yaml +++ b/dygraph/transformer/transformer.yaml @@ -19,6 +19,8 @@ inference_model_dir: "infer_model" random_seed: None # The pattern to match training data files. training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de" +# The pattern to match validation data files. +validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de" # The pattern to match test data files. predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de" # The file to output the translation results of predict_file to. -- GitLab