未验证 提交 2c4dbca7 编写于 作者: G Guo Sheng 提交者: GitHub

Fix the shape desc of attn_bias in Transformer. (#3994)

上级 ae6c6e64
......@@ -12,65 +12,73 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = None
# The placeholder for squence length in compile time.
seq_len = None
# The placeholder for head number in compile time.
n_head = 8
# The placeholder for model dim in compile time.
d_model = 512
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch]
"src_word": [(batch_size, seq_len), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(None, 1), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(None, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states updation)
"init_idx": [(batch_size, ), "int32"],
}
def get_input_descs(args):
"""
Generate a dict mapping data fields to the corresponding data shapes and
data types.
"""
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = None
# The placeholder for squence length in compile time.
seq_len = None
# The head number.
n_head = getattr(args, "n_head", 8)
# The model dim.
d_model = getattr(args, "d_model", 512)
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch]
"src_word": [(batch_size, seq_len), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(None, 1), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(None, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states updation)
"init_idx": [(batch_size, ), "int32"],
}
return input_descs
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
......
......@@ -93,10 +93,11 @@ def do_save_inference_model(args):
# define input and reader
input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
input_descs = desc.get_input_descs(args.args)
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
"shape": input_descs[name][0],
"dtype": input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
......
......@@ -134,10 +134,11 @@ def do_predict(args):
# define input and reader
input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
input_descs = desc.get_input_descs(args.args)
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
"shape": input_descs[name][0],
"dtype": input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
......
......@@ -175,10 +175,11 @@ def do_train(args):
input_field_names = desc.encoder_data_input_fields + \
desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields
input_descs = desc.get_input_descs(args.args)
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
"shape": input_descs[name][0],
"dtype": input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册