未验证 提交 33a279eb 编写于 作者: S smallv0221 提交者: GitHub

Fix dureader api bugs (#5021)

* update lrscheduler

* minor fix

* add pre-commit

* minor fix

* Add __len__ to squad dataset

* minor fix

* Add dureader robust prototype

* implement dataset

* minor fix

* fix var name

* add dureader-yesno train script and dataset

* add readme and fix md5sum

* integrate dureader datasets

* change var names: segment to mode, root to data_file

* minor fix

* update var name

* Fix api bugs
上级 a9be5bbb
...@@ -122,7 +122,7 @@ def do_train(args): ...@@ -122,7 +122,7 @@ def do_train(args):
doc_stride=args.doc_stride, doc_stride=args.doc_stride,
max_query_length=args.max_query_length, max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
segment='train') mode='train')
train_batch_sampler = paddle.io.DistributedBatchSampler( train_batch_sampler = paddle.io.DistributedBatchSampler(
train_ds, batch_size=args.batch_size, shuffle=True) train_ds, batch_size=args.batch_size, shuffle=True)
...@@ -147,7 +147,7 @@ def do_train(args): ...@@ -147,7 +147,7 @@ def do_train(args):
doc_stride=args.doc_stride, doc_stride=args.doc_stride,
max_query_length=args.max_query_length, max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
segment='dev') mode='dev')
dev_batch_sampler = paddle.io.BatchSampler( dev_batch_sampler = paddle.io.BatchSampler(
dev_ds, batch_size=args.batch_size, shuffle=False) dev_ds, batch_size=args.batch_size, shuffle=False)
...@@ -170,7 +170,7 @@ def do_train(args): ...@@ -170,7 +170,7 @@ def do_train(args):
doc_stride=args.doc_stride, doc_stride=args.doc_stride,
max_query_length=args.max_query_length, max_query_length=args.max_query_length,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
segment='test') mode='test')
test_batch_sampler = paddle.io.BatchSampler( test_batch_sampler = paddle.io.BatchSampler(
test_ds, batch_size=args.batch_size, shuffle=False) test_ds, batch_size=args.batch_size, shuffle=False)
......
...@@ -90,7 +90,7 @@ def evaluate(model, data_loader, args): ...@@ -90,7 +90,7 @@ def evaluate(model, data_loader, args):
end_logits=end_logits)) end_logits=end_logits))
all_predictions, all_nbest_json, scores_diff_json = compute_predictions( all_predictions, all_nbest_json, scores_diff_json = compute_predictions(
data_loader.dataset.examples, data_loader.dataset.data, all_results, data_loader.dataset.examples, data_loader.dataset.features, all_results,
args.n_best_size, args.max_answer_length, args.do_lower_case, args.n_best_size, args.max_answer_length, args.do_lower_case,
args.version_2_with_negative, args.null_score_diff_threshold, args.version_2_with_negative, args.null_score_diff_threshold,
args.verbose, data_loader.dataset.tokenizer) args.verbose, data_loader.dataset.tokenizer)
......
...@@ -165,7 +165,7 @@ class DuReader(SQuAD): ...@@ -165,7 +165,7 @@ class DuReader(SQuAD):
examples.append(example) examples.append(example)
self.examples = examples[:2000] self.examples = examples
class DuReaderRobust(SQuAD): class DuReaderRobust(SQuAD):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册