Unverified · Commit 2c4dbca7 authored by Guo Sheng, committed by GitHub

Fix the shape desc of attn_bias in Transformer. (#3994)

Parent ae6c6e64
...
@@ -12,20 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# The placeholder for batch_size in compile time. Must be -1 currently to be
-# consistent with some ops' infer-shape output in compile time, such as the
-# sequence_expand op used in beamsearch decoder.
-batch_size = None
-# The placeholder for squence length in compile time.
-seq_len = None
-# The placeholder for head number in compile time.
-n_head = 8
-# The placeholder for model dim in compile time.
-d_model = 512
-# Here list the data shapes and data types of all inputs.
-# The shapes here act as placeholder and are set to pass the infer-shape in
-# compile time.
-input_descs = {
+def get_input_descs(args):
+    """
+    Generate a dict mapping data fields to the corresponding data shapes and
+    data types.
+    """
+    # The placeholder for batch_size in compile time. Must be -1 currently to be
+    # consistent with some ops' infer-shape output in compile time, such as the
+    # sequence_expand op used in beamsearch decoder.
+    batch_size = None
+    # The placeholder for squence length in compile time.
+    seq_len = None
+    # The head number.
+    n_head = getattr(args, "n_head", 8)
+    # The model dim.
+    d_model = getattr(args, "d_model", 512)
+
+    # Here list the data shapes and data types of all inputs.
+    # The shapes here act as placeholder and are set to pass the infer-shape in
+    # compile time.
+    input_descs = {
     # The actual data shape of src_word is:
     # [batch_size, max_src_len_in_batch]
     "src_word": [(batch_size, seq_len), "int64", 2],
@@ -70,7 +76,9 @@ input_descs = {
     # This input is used in beam-search decoder for the first gather
     # (cell states updation)
     "init_idx": [(batch_size, ), "int32"],
 }
+    return input_descs
+
 
 # Names of word embedding table which might be reused for weight sharing.
 word_emb_param_names = (
...
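With this change, n_head and d_model are resolved from the passed args, falling back to the previous defaults of 8 and 512 via getattr, so the shape descriptors that depend on them (such as the attn_bias entries named in the commit title, elided in the hunks above) follow the configured model size. Below is a minimal sketch of the new entry point, assuming the module is importable as desc and using argparse.Namespace as a stand-in for the real argument object; only the "src_word" field and the [shape, dtype, ...] layout come from the diff above.

from argparse import Namespace

import desc  # assumed import path for the module changed above

# Attributes missing from args fall back to the old defaults via getattr.
args = Namespace(n_head=16, d_model=1024)
input_descs = desc.get_input_descs(args)

# "src_word" is declared as [(batch_size, seq_len), "int64", 2] above;
# batch_size and seq_len stay None as compile-time placeholders.
shape, dtype = input_descs["src_word"][:2]
print(shape, dtype)  # (None, None) int64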
...
@@ -93,10 +93,11 @@ def do_save_inference_model(args):
 
     # define input and reader
     input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
+    input_descs = desc.get_input_descs(args.args)
     input_slots = [{
         "name": name,
-        "shape": desc.input_descs[name][0],
-        "dtype": desc.input_descs[name][1]
+        "shape": input_descs[name][0],
+        "dtype": input_descs[name][1]
     } for name in input_field_names]
 
     input_field = InputField(input_slots)
...
...
@@ -134,10 +134,11 @@ def do_predict(args):
 
     # define input and reader
     input_field_names = desc.encoder_data_input_fields + desc.fast_decoder_data_input_fields
+    input_descs = desc.get_input_descs(args.args)
     input_slots = [{
         "name": name,
-        "shape": desc.input_descs[name][0],
-        "dtype": desc.input_descs[name][1]
+        "shape": input_descs[name][0],
+        "dtype": input_descs[name][1]
     } for name in input_field_names]
 
     input_field = InputField(input_slots)
...
...
@@ -175,10 +175,11 @@ def do_train(args):
 
     input_field_names = desc.encoder_data_input_fields + \
        desc.decoder_data_input_fields[:-1] + desc.label_data_input_fields
+    input_descs = desc.get_input_descs(args.args)
     input_slots = [{
         "name": name,
-        "shape": desc.input_descs[name][0],
-        "dtype": desc.input_descs[name][1]
+        "shape": input_descs[name][0],
+        "dtype": input_descs[name][1]
     } for name in input_field_names]
 
     input_field = InputField(input_slots)
...
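All three entry points (do_save_inference_model, do_predict, do_train) now share the same pattern: input_descs is built once per run from the parsed arguments instead of being read from the old module-level dict, so the declared shapes match the run's n_head and d_model. The sketch below is a hedged illustration of that shared pattern factored into a helper; build_input_slots and the example field list are hypothetical, while desc.get_input_descs, the [shape, dtype] layout, and InputField(input_slots) appear in the diffs above.

import desc  # assumed import path, as used in the files above

def build_input_slots(args, field_names):
    # One per-run lookup: shapes/dtypes now reflect args.n_head / args.d_model.
    input_descs = desc.get_input_descs(args)
    return [{
        "name": name,
        "shape": input_descs[name][0],
        "dtype": input_descs[name][1],
    } for name in field_names]

# e.g. the inference/prediction entry points would use:
#   slots = build_input_slots(args.args,
#                             desc.encoder_data_input_fields +
#                             desc.fast_decoder_data_input_fields)
#   input_field = InputField(slots)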