未验证 提交 ae6c6e64 编写于 作者: Y Yibing Liu 提交者: GitHub

Fix input_mask shape in inputs (#3992)

上级 4195df96
...@@ -25,8 +25,8 @@ from model.bert import BertModel ...@@ -25,8 +25,8 @@ from model.bert import BertModel
def create_model(args, bert_config, num_labels, is_prediction=False): def create_model(args, bert_config, num_labels, is_prediction=False):
input_fields = { input_fields = {
'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'labels'], 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'labels'],
'shapes': [[None, None], [None, None], [None, None], 'shapes': [[None, None], [None, None], [None, None], [None, None, 1],
[None, args.max_seq_len, 1], [None, 1]], [None, 1]],
'dtypes': ['int64', 'int64', 'int64', 'float32', 'int64'], 'dtypes': ['int64', 'int64', 'int64', 'float32', 'int64'],
'lod_levels': [0, 0, 0, 0, 0], 'lod_levels': [0, 0, 0, 0, 0],
} }
......
...@@ -111,7 +111,7 @@ def create_model(bert_config, is_training=False): ...@@ -111,7 +111,7 @@ def create_model(bert_config, is_training=False):
input_fields = { input_fields = {
'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'start_positions', 'end_positions'], 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'start_positions', 'end_positions'],
'shapes': [[None, None], [None, None], [None, None], 'shapes': [[None, None], [None, None], [None, None],
[None, args.max_seq_len, 1], [None, 1], [None, 1]], [None, None, 1], [None, 1], [None, 1]],
'dtypes': [ 'dtypes': [
'int64', 'int64', 'int64', 'float32', 'int64', 'int64'], 'int64', 'int64', 'int64', 'float32', 'int64', 'int64'],
'lod_levels': [0, 0, 0, 0, 0, 0], 'lod_levels': [0, 0, 0, 0, 0, 0],
...@@ -120,7 +120,7 @@ def create_model(bert_config, is_training=False): ...@@ -120,7 +120,7 @@ def create_model(bert_config, is_training=False):
input_fields = { input_fields = {
'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'unique_id'], 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'unique_id'],
'shapes': [[None, None], [None, None], [None, None], 'shapes': [[None, None], [None, None], [None, None],
[None, args.max_seq_len, 1], [None, 1]], [None, None, 1], [None, 1]],
'dtypes': [ 'dtypes': [
'int64', 'int64', 'int64', 'float32', 'int64'], 'int64', 'int64', 'int64', 'float32', 'int64'],
'lod_levels': [0, 0, 0, 0, 0], 'lod_levels': [0, 0, 0, 0, 0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册