From 2c4dbca75da6bf313274638d050f64c7fb9d6b8f Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Wed, 27 Nov 2019 17:15:09 +0800 Subject: [PATCH] Fix the shape desc of attn_bias in Transformer. (#3994) --- PaddleNLP/PaddleMT/transformer/desc.py | 126 ++++++++++-------- .../PaddleMT/transformer/inference_model.py | 5 +- PaddleNLP/PaddleMT/transformer/predict.py | 5 +- PaddleNLP/PaddleMT/transformer/train.py | 5 +- 4 files changed, 76 insertions(+), 65 deletions(-) diff --git a/PaddleNLP/PaddleMT/transformer/desc.py b/PaddleNLP/PaddleMT/transformer/desc.py index d6c34191..f6fa768a 100644 --- a/PaddleNLP/PaddleMT/transformer/desc.py +++ b/PaddleNLP/PaddleMT/transformer/desc.py @@ -12,65 +12,73 @@ # See the License for the specific language governing permissions and # limitations under the License. -# The placeholder for batch_size in compile time. Must be -1 currently to be -# consistent with some ops' infer-shape output in compile time, such as the -# sequence_expand op used in beamsearch decoder. -batch_size = None -# The placeholder for squence length in compile time. -seq_len = None -# The placeholder for head number in compile time. -n_head = 8 -# The placeholder for model dim in compile time. -d_model = 512 -# Here list the data shapes and data types of all inputs. -# The shapes here act as placeholder and are set to pass the infer-shape in -# compile time. -input_descs = { - # The actual data shape of src_word is: - # [batch_size, max_src_len_in_batch] - "src_word": [(batch_size, seq_len), "int64", 2], - # The actual data shape of src_pos is: - # [batch_size, max_src_len_in_batch, 1] - "src_pos": [(batch_size, seq_len), "int64"], - # This input is used to remove attention weights on paddings in the - # encoder. - # The actual data shape of src_slf_attn_bias is: - # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch] - "src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], - # The actual data shape of trg_word is: - # [batch_size, max_trg_len_in_batch, 1] - "trg_word": [(batch_size, seq_len), "int64", - 2], # lod_level is only used in fast decoder. - # The actual data shape of trg_pos is: - # [batch_size, max_trg_len_in_batch, 1] - "trg_pos": [(batch_size, seq_len), "int64"], - # This input is used to remove attention weights on paddings and - # subsequent words in the decoder. - # The actual data shape of trg_slf_attn_bias is: - # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch] - "trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], - # This input is used to remove attention weights on paddings of the source - # input in the encoder-decoder attention. - # The actual data shape of trg_src_attn_bias is: - # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch] - "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], - # This input is used in independent decoder program for inference. - # The actual data shape of enc_output is: - # [batch_size, max_src_len_in_batch, d_model] - "enc_output": [(batch_size, seq_len, d_model), "float32"], - # The actual data shape of label_word is: - # [batch_size * max_trg_len_in_batch, 1] - "lbl_word": [(None, 1), "int64"], - # This input is used to mask out the loss of paddding tokens. - # The actual data shape of label_weight is: - # [batch_size * max_trg_len_in_batch, 1] - "lbl_weight": [(None, 1), "float32"], - # This input is used in beam-search decoder. 
- "init_score": [(batch_size, 1), "float32", 2], - # This input is used in beam-search decoder for the first gather - # (cell states updation) - "init_idx": [(batch_size, ), "int32"], -} +def get_input_descs(args): + """ + Generate a dict mapping data fields to the corresponding data shapes and + data types. + """ + # The placeholder for batch_size in compile time. Must be -1 currently to be + # consistent with some ops' infer-shape output in compile time, such as the + # sequence_expand op used in beamsearch decoder. + batch_size = None + # The placeholder for squence length in compile time. + seq_len = None + # The head number. + n_head = getattr(args, "n_head", 8) + # The model dim. + d_model = getattr(args, "d_model", 512) + + # Here list the data shapes and data types of all inputs. + # The shapes here act as placeholder and are set to pass the infer-shape in + # compile time. + input_descs = { + # The actual data shape of src_word is: + # [batch_size, max_src_len_in_batch] + "src_word": [(batch_size, seq_len), "int64", 2], + # The actual data shape of src_pos is: + # [batch_size, max_src_len_in_batch, 1] + "src_pos": [(batch_size, seq_len), "int64"], + # This input is used to remove attention weights on paddings in the + # encoder. + # The actual data shape of src_slf_attn_bias is: + # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch] + "src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], + # The actual data shape of trg_word is: + # [batch_size, max_trg_len_in_batch, 1] + "trg_word": [(batch_size, seq_len), "int64", + 2], # lod_level is only used in fast decoder. + # The actual data shape of trg_pos is: + # [batch_size, max_trg_len_in_batch, 1] + "trg_pos": [(batch_size, seq_len), "int64"], + # This input is used to remove attention weights on paddings and + # subsequent words in the decoder. + # The actual data shape of trg_slf_attn_bias is: + # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch] + "trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], + # This input is used to remove attention weights on paddings of the source + # input in the encoder-decoder attention. + # The actual data shape of trg_src_attn_bias is: + # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch] + "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], + # This input is used in independent decoder program for inference. + # The actual data shape of enc_output is: + # [batch_size, max_src_len_in_batch, d_model] + "enc_output": [(batch_size, seq_len, d_model), "float32"], + # The actual data shape of label_word is: + # [batch_size * max_trg_len_in_batch, 1] + "lbl_word": [(None, 1), "int64"], + # This input is used to mask out the loss of paddding tokens. + # The actual data shape of label_weight is: + # [batch_size * max_trg_len_in_batch, 1] + "lbl_weight": [(None, 1), "float32"], + # This input is used in beam-search decoder. + "init_score": [(batch_size, 1), "float32", 2], + # This input is used in beam-search decoder for the first gather + # (cell states updation) + "init_idx": [(batch_size, ), "int32"], + } + + return input_descs # Names of word embedding table which might be reused for weight sharing. 
word_emb_param_names = ( diff --git a/PaddleNLP/PaddleMT/transformer/inference_model.py b/PaddleNLP/PaddleMT/transformer/inference_model.py index 40fc7ede..5de0a107 100644 --- a/PaddleNLP/PaddleMT/transformer/inference_model.py +++ b/PaddleNLP/PaddleMT/transformer/inference_model.py @@ -93,10 +93,11 @@ def do_save_inference_model(args): # define input and reader input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields + input_descs = desc.get_input_descs(args.args) input_slots = [{ "name": name, - "shape": desc.input_descs[name][0], - "dtype": desc.input_descs[name][1] + "shape": input_descs[name][0], + "dtype": input_descs[name][1] } for name in input_field_names] input_field = InputField(input_slots) diff --git a/PaddleNLP/PaddleMT/transformer/predict.py b/PaddleNLP/PaddleMT/transformer/predict.py index 7ad847fd..2ad93e58 100644 --- a/PaddleNLP/PaddleMT/transformer/predict.py +++ b/PaddleNLP/PaddleMT/transformer/predict.py @@ -134,10 +134,11 @@ def do_predict(args): # define input and reader input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields + input_descs = desc.get_input_descs(args.args) input_slots = [{ "name": name, - "shape": desc.input_descs[name][0], - "dtype": desc.input_descs[name][1] + "shape": input_descs[name][0], + "dtype": input_descs[name][1] } for name in input_field_names] input_field = InputField(input_slots) diff --git a/PaddleNLP/PaddleMT/transformer/train.py b/PaddleNLP/PaddleMT/transformer/train.py index b3555be6..c9fb5d72 100644 --- a/PaddleNLP/PaddleMT/transformer/train.py +++ b/PaddleNLP/PaddleMT/transformer/train.py @@ -175,10 +175,11 @@ def do_train(args): input_field_names = desc.encoder_data_input_fields + \ desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields + input_descs = desc.get_input_descs(args.args) input_slots = [{ "name": name, - "shape": desc.input_descs[name][0], - "dtype": desc.input_descs[name][1] + "shape": input_descs[name][0], + "dtype": input_descs[name][1] } for name in input_field_names] input_field = InputField(input_slots) -- GitLab
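
Review note: after this change, train.py, predict.py, and inference_model.py
all build their feed slots from the new helper in the same way. Below is a
minimal sketch of that call pattern, assuming the patched desc.py is on the
import path; the SimpleNamespace stands in for the PDConfig-style args.args
object used at the real call sites and is only illustrative.

    from types import SimpleNamespace

    import desc  # desc.py as modified by this patch

    # Stand-in config; get_input_descs only reads n_head and d_model and
    # falls back to 8 and 512 when they are absent.
    args = SimpleNamespace(n_head=8, d_model=512)

    input_field_names = desc.encoder_data_input_fields
    input_descs = desc.get_input_descs(args)

    # The same slot-building comprehension used at all three call sites.
    input_slots = [{
        "name": name,
        "shape": input_descs[name][0],
        "dtype": input_descs[name][1]
    } for name in input_field_names]

    for slot in input_slots:
        print(slot["name"], slot["shape"], slot["dtype"])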
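
The substance of the fix is that the head count and model width in these
shape descriptors now follow the run's configuration instead of hard-coded
module-level constants, so the attn_bias descriptors stay correct for
non-default settings. A quick check of that behavior (again assuming the
patched desc.py is importable; the 16-head/1024-dim configuration is just an
illustrative value):

    from types import SimpleNamespace

    from desc import get_input_descs

    # With the old module-level n_head = 8, these descriptors would have
    # claimed an 8-head bias shape regardless of the actual configuration.
    big_args = SimpleNamespace(n_head=16, d_model=1024)
    descs = get_input_descs(big_args)

    assert descs["src_slf_attn_bias"][0] == (None, 16, None, None)
    assert descs["enc_output"][0] == (None, None, 1024)

Keeping the placeholders inside a function also means each entry point gets a
fresh descriptor dict rather than sharing mutable module-level state.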