未验证 提交 2c4dbca7 编写于 作者: G Guo Sheng 提交者: GitHub

Fix the shape desc of attn_bias in Transformer. (#3994)

上级 ae6c6e64
......@@ -12,20 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = None
# The placeholder for squence length in compile time.
seq_len = None
# The placeholder for head number in compile time.
n_head = 8
# The placeholder for model dim in compile time.
d_model = 512
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
def get_input_descs(args):
"""
Generate a dict mapping data fields to the corresponding data shapes and
data types.
"""
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = None
# The placeholder for squence length in compile time.
seq_len = None
# The head number.
n_head = getattr(args, "n_head", 8)
# The model dim.
d_model = getattr(args, "d_model", 512)
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch]
"src_word": [(batch_size, seq_len), "int64", 2],
......@@ -70,7 +76,9 @@ input_descs = {
# This input is used in beam-search decoder for the first gather
# (cell states updation)
"init_idx": [(batch_size, ), "int32"],
}
}
return input_descs
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
......
......@@ -93,10 +93,11 @@ def do_save_inference_model(args):
# define input and reader
input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
input_descs = desc.get_input_descs(args.args)
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
"shape": input_descs[name][0],
"dtype": input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
......
......@@ -134,10 +134,11 @@ def do_predict(args):
# define input and reader
input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
input_descs = desc.get_input_descs(args.args)
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
"shape": input_descs[name][0],
"dtype": input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
......
......@@ -175,10 +175,11 @@ def do_train(args):
input_field_names = desc.encoder_data_input_fields + \
desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields
input_descs = desc.get_input_descs(args.args)
input_slots = [{
"name": name,
"shape": desc.input_descs[name][0],
"dtype": desc.input_descs[name][1]
"shape": input_descs[name][0],
"dtype": input_descs[name][1]
} for name in input_field_names]
input_field = InputField(input_slots)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册