From ae6c6e644bb3e0bd63e685f1bd48371091540aff Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 27 Nov 2019 15:35:52 +0800 Subject: [PATCH] Fix input_mask shape in inputs (#3992) --- PaddleNLP/PaddleLARK/BERT/model/classifier.py | 4 ++-- PaddleNLP/PaddleLARK/BERT/run_squad.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/PaddleNLP/PaddleLARK/BERT/model/classifier.py b/PaddleNLP/PaddleLARK/BERT/model/classifier.py index 03bfb7aa..079d5683 100644 --- a/PaddleNLP/PaddleLARK/BERT/model/classifier.py +++ b/PaddleNLP/PaddleLARK/BERT/model/classifier.py @@ -25,8 +25,8 @@ from model.bert import BertModel def create_model(args, bert_config, num_labels, is_prediction=False): input_fields = { 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'labels'], - 'shapes': [[None, None], [None, None], [None, None], - [None, args.max_seq_len, 1], [None, 1]], + 'shapes': [[None, None], [None, None], [None, None], [None, None, 1], + [None, 1]], 'dtypes': ['int64', 'int64', 'int64', 'float32', 'int64'], 'lod_levels': [0, 0, 0, 0, 0], } diff --git a/PaddleNLP/PaddleLARK/BERT/run_squad.py b/PaddleNLP/PaddleLARK/BERT/run_squad.py index fc3659b6..e005b243 100644 --- a/PaddleNLP/PaddleLARK/BERT/run_squad.py +++ b/PaddleNLP/PaddleLARK/BERT/run_squad.py @@ -111,7 +111,7 @@ def create_model(bert_config, is_training=False): input_fields = { 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'start_positions', 'end_positions'], 'shapes': [[None, None], [None, None], [None, None], - [None, args.max_seq_len, 1], [None, 1], [None, 1]], + [None, None, 1], [None, 1], [None, 1]], 'dtypes': [ 'int64', 'int64', 'int64', 'float32', 'int64', 'int64'], 'lod_levels': [0, 0, 0, 0, 0, 0], @@ -120,7 +120,7 @@ def create_model(bert_config, is_training=False): input_fields = { 'names': ['src_ids', 'pos_ids', 'sent_ids', 'input_mask', 'unique_id'], 'shapes': [[None, None], [None, None], [None, None], - [None, args.max_seq_len, 1], [None, 1]], + [None, None, 1], [None, 1]], 'dtypes': [ 'int64', 'int64', 'int64', 'float32', 'int64'], 'lod_levels': [0, 0, 0, 0, 0], -- GitLab