Unverified commit 2c4dbca7, authored by Guo Sheng, committed by GitHub

Fix the shape desc of attn_bias in Transformer. (#3994)

Parent: ae6c6e64
@@ -12,65 +12,73 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# The placeholder for batch_size in compile time. Must be -1 currently to be
-# consistent with some ops' infer-shape output in compile time, such as the
-# sequence_expand op used in beamsearch decoder.
-batch_size = None
-# The placeholder for sequence length in compile time.
-seq_len = None
-# The placeholder for head number in compile time.
-n_head = 8
-# The placeholder for model dim in compile time.
-d_model = 512
-
-# Here list the data shapes and data types of all inputs.
-# The shapes here act as placeholders and are set to pass the infer-shape in
-# compile time.
-input_descs = {
-    # The actual data shape of src_word is:
-    # [batch_size, max_src_len_in_batch]
-    "src_word": [(batch_size, seq_len), "int64", 2],
-    # The actual data shape of src_pos is:
-    # [batch_size, max_src_len_in_batch, 1]
-    "src_pos": [(batch_size, seq_len), "int64"],
-    # This input is used to remove attention weights on paddings in the
-    # encoder.
-    # The actual data shape of src_slf_attn_bias is:
-    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
-    "src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
-    # The actual data shape of trg_word is:
-    # [batch_size, max_trg_len_in_batch, 1]
-    "trg_word": [(batch_size, seq_len), "int64",
-                 2],  # lod_level is only used in fast decoder.
-    # The actual data shape of trg_pos is:
-    # [batch_size, max_trg_len_in_batch, 1]
-    "trg_pos": [(batch_size, seq_len), "int64"],
-    # This input is used to remove attention weights on paddings and
-    # subsequent words in the decoder.
-    # The actual data shape of trg_slf_attn_bias is:
-    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
-    "trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
-    # This input is used to remove attention weights on paddings of the source
-    # input in the encoder-decoder attention.
-    # The actual data shape of trg_src_attn_bias is:
-    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
-    "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
-    # This input is used in independent decoder program for inference.
-    # The actual data shape of enc_output is:
-    # [batch_size, max_src_len_in_batch, d_model]
-    "enc_output": [(batch_size, seq_len, d_model), "float32"],
-    # The actual data shape of label_word is:
-    # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_word": [(None, 1), "int64"],
-    # This input is used to mask out the loss of padding tokens.
-    # The actual data shape of label_weight is:
-    # [batch_size * max_trg_len_in_batch, 1]
-    "lbl_weight": [(None, 1), "float32"],
-    # This input is used in beam-search decoder.
-    "init_score": [(batch_size, 1), "float32", 2],
-    # This input is used in beam-search decoder for the first gather
-    # (cell states update).
-    "init_idx": [(batch_size, ), "int32"],
-}
+def get_input_descs(args):
+    """
+    Generate a dict mapping data fields to the corresponding data shapes and
+    data types.
+    """
+    # The placeholder for batch_size in compile time. Must be -1 currently to
+    # be consistent with some ops' infer-shape output in compile time, such as
+    # the sequence_expand op used in beamsearch decoder.
+    batch_size = None
+    # The placeholder for sequence length in compile time.
+    seq_len = None
+    # The head number.
+    n_head = getattr(args, "n_head", 8)
+    # The model dim.
+    d_model = getattr(args, "d_model", 512)
+
+    # Here list the data shapes and data types of all inputs.
+    # The shapes here act as placeholders and are set to pass the infer-shape
+    # in compile time.
+    input_descs = {
+        # The actual data shape of src_word is:
+        # [batch_size, max_src_len_in_batch]
+        "src_word": [(batch_size, seq_len), "int64", 2],
+        # The actual data shape of src_pos is:
+        # [batch_size, max_src_len_in_batch, 1]
+        "src_pos": [(batch_size, seq_len), "int64"],
+        # This input is used to remove attention weights on paddings in the
+        # encoder.
+        # The actual data shape of src_slf_attn_bias is:
+        # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
+        "src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
+        # The actual data shape of trg_word is:
+        # [batch_size, max_trg_len_in_batch, 1]
+        "trg_word": [(batch_size, seq_len), "int64",
+                     2],  # lod_level is only used in fast decoder.
+        # The actual data shape of trg_pos is:
+        # [batch_size, max_trg_len_in_batch, 1]
+        "trg_pos": [(batch_size, seq_len), "int64"],
+        # This input is used to remove attention weights on paddings and
+        # subsequent words in the decoder.
+        # The actual data shape of trg_slf_attn_bias is:
+        # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
+        "trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
+        # This input is used to remove attention weights on paddings of the
+        # source input in the encoder-decoder attention.
+        # The actual data shape of trg_src_attn_bias is:
+        # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
+        "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
+        # This input is used in independent decoder program for inference.
+        # The actual data shape of enc_output is:
+        # [batch_size, max_src_len_in_batch, d_model]
+        "enc_output": [(batch_size, seq_len, d_model), "float32"],
+        # The actual data shape of label_word is:
+        # [batch_size * max_trg_len_in_batch, 1]
+        "lbl_word": [(None, 1), "int64"],
+        # This input is used to mask out the loss of padding tokens.
+        # The actual data shape of label_weight is:
+        # [batch_size * max_trg_len_in_batch, 1]
+        "lbl_weight": [(None, 1), "float32"],
+        # This input is used in beam-search decoder.
+        "init_score": [(batch_size, 1), "float32", 2],
+        # This input is used in beam-search decoder for the first gather
+        # (cell states update).
+        "init_idx": [(batch_size, ), "int32"],
+    }
+    return input_descs
 
 # Names of word embedding table which might be reused for weight sharing.
 word_emb_param_names = (
...
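The practical effect of this hunk is that the attention-bias shape descriptions now pick up the head count and model dim from the model configuration instead of the hardcoded `n_head = 8` and `d_model = 512`. A minimal sketch of the new call path (the `Namespace` stands in for the repo's parsed arguments, which is an assumption):

```python
from argparse import Namespace

import desc  # the module containing get_input_descs above

# Hypothetical config: a Transformer with 16 heads and d_model 1024.
args = Namespace(n_head=16, d_model=1024)
input_descs = desc.get_input_descs(args)

# batch_size and seq_len stay None (compile-time placeholders), but the
# head and model dims now reflect the configured model, not old defaults.
assert input_descs["src_slf_attn_bias"][0] == (None, 16, None, None)
assert input_descs["enc_output"][0] == (None, None, 1024)
```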
@@ -93,10 +93,11 @@ def do_save_inference_model(args):
     # define input and reader
     input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
+    input_descs = desc.get_input_descs(args.args)
     input_slots = [{
         "name": name,
-        "shape": desc.input_descs[name][0],
-        "dtype": desc.input_descs[name][1]
+        "shape": input_descs[name][0],
+        "dtype": input_descs[name][1]
     } for name in input_field_names]
 
     input_field = InputField(input_slots)
...
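At run time the attention-bias inputs described by these descs are dense tensors built on the host side. A rough illustration of how a padding bias matching `input_descs["src_slf_attn_bias"]` could be assembled with NumPy (sizes are made up; the repo's own reader differs in detail):

```python
import numpy as np

# Hypothetical batch: two source sentences padded to length 5, 8 heads.
batch_size, n_head, max_src_len = 2, 8, 5
src_lens = [3, 5]

bias = np.zeros((batch_size, n_head, max_src_len, max_src_len), dtype="float32")
for i, length in enumerate(src_lens):
    # A large negative bias on padding columns drives their post-softmax
    # attention weights toward zero, as the comments in the diff describe.
    bias[i, :, :, length:] = -1e9
```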
@@ -134,10 +134,11 @@ def do_predict(args):
     # define input and reader
     input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
+    input_descs = desc.get_input_descs(args.args)
     input_slots = [{
         "name": name,
-        "shape": desc.input_descs[name][0],
-        "dtype": desc.input_descs[name][1]
+        "shape": input_descs[name][0],
+        "dtype": input_descs[name][1]
     } for name in input_field_names]
 
     input_field = InputField(input_slots)
...
@@ -175,10 +175,11 @@ def do_train(args):
     input_field_names = desc.encoder_data_input_fields + \
         desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields
+    input_descs = desc.get_input_descs(args.args)
     input_slots = [{
         "name": name,
-        "shape": desc.input_descs[name][0],
-        "dtype": desc.input_descs[name][1]
+        "shape": input_descs[name][0],
+        "dtype": input_descs[name][1]
     } for name in input_field_names]
 
     input_field = InputField(input_slots)
...
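For reference, each slot carries just enough to declare a feed variable. A minimal sketch of what a wrapper like `InputField` might do with the slots (`InputField` lives in the repo's utilities and is not shown in this diff, so the body below is an assumption, using Paddle 1.x static-graph `fluid.data`):

```python
import paddle.fluid as fluid

def make_feed_vars(input_slots):
    # Declare one feed variable per slot; None dims in the shape are the
    # compile-time placeholders produced by get_input_descs.
    return {
        slot["name"]: fluid.data(
            name=slot["name"],
            shape=list(slot["shape"]),
            dtype=slot["dtype"])
        for slot in input_slots
    }
```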