未验证 提交 ae6c6e64 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix input_mask shape in inputs (#3992)

上级 4195df96
......@@ -25,8 +25,8 @@ from model.bert import BertModel
def create_model(args, bert_config, num_labels, is_prediction=False):
input_fields = {
'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'labels'],
'shapes': [[None, None], [None, None], [None, None],
[None, args.max_seq_len, 1], [None, 1]],
'shapes': [[None, None], [None, None], [None, None], [None, None, 1],
[None, 1]],
'dtypes': ['int64', 'int64', 'int64', 'float32', 'int64'],
'lod_levels': [0, 0, 0, 0, 0],
}
......
......@@ -111,7 +111,7 @@ def create_model(bert_config, is_training=False):
input_fields = {
'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'start_positions', 'end_positions'],
'shapes': [[None, None], [None, None], [None, None],
[None, args.max_seq_len, 1], [None, 1], [None, 1]],
[None, None, 1], [None, 1], [None, 1]],
'dtypes': [
'int64', 'int64', 'int64', 'float32', 'int64', 'int64'],
'lod_levels': [0, 0, 0, 0, 0, 0],
......@@ -120,7 +120,7 @@ def create_model(bert_config, is_training=False):
input_fields = {
'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'unique_id'],
'shapes': [[None, None], [None, None], [None, None],
[None, args.max_seq_len, 1], [None, 1]],
[None, None, 1], [None, 1]],
'dtypes': [
'int64', 'int64', 'int64', 'float32', 'int64'],
'lod_levels': [0, 0, 0, 0, 0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册