Commit 271883bf authored by: K kinghuin, committed by: wuzewu

support ChineseGLUE (#217)

* machine reading comprehension
Parent 8419f9d5
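For orientation before the file-by-file diff: a minimal sketch of how the reading-comprehension demo introduced by this change wires one of the newly supported Chinese datasets into PaddleHub. It only restates API calls that appear in the diff below (the module and reader names are the PaddleHub 1.x identifiers this commit uses); treat it as an illustration, not as the demo script itself.

    import paddlehub as hub

    # Pick one of the datasets this commit adds support for:
    # CMRC2018(), DRCD(), or SQUAD(version_2_with_negative=True/False).
    dataset = hub.dataset.CMRC2018()
    module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
    inputs, outputs, program = module.context(trainable=True, max_seq_len=512)

    # The reader turns raw examples into the sliding-window features the model expects.
    reader = hub.reader.ReadingComprehensionReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
        max_seq_len=512,
        doc_stride=128,
        max_query_length=64)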
...@@ -113,3 +113,8 @@ dmypy.json
# Pyre type checker
.pyre/
# pycharm
.DS_Store
.idea/
FETCH_HEAD
\ No newline at end of file
...@@ -33,8 +33,7 @@ import time
import paddle
import paddle.fluid as fluid
import paddlehub as hub
import evaluate_v1
import evaluate_v2
from paddlehub.finetune.task.reading_comprehension_task import write_predictions
hub.common.logger.logger.setLevel("INFO")
...@@ -54,354 +53,36 @@ parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True,
parser.add_argument("--max_answer_length", type=int, default=30, help="Max answer length.")
parser.add_argument("--n_best_size", type=int, default=20, help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
parser.add_argument("--null_score_diff_threshold", type=float, default=0.0, help="If null_score - best_non_null is greater than the threshold predict null.")
parser.add_argument("--version_2_with_negative", type=ast.literal_eval, default=False, help="If true, the SQuAD examples contain some that do not have an answer. If using squad v2.0, it should be set true.")
parser.add_argument("--dataset", type=str, default="squad", help="Support squad, squad2.0, drcd and cmrc2018")
args = parser.parse_args()
# yapf: enable.
def write_predictions(
all_examples,
all_features,
all_results,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
n_best_size=20,
max_answer_length=30,
do_lower_case=True,
version_2_with_negative=False,
null_score_diff_threshold=0.0,
):
"""Write final predictions to the json file and log-odds of null if needed."""
print("Writing predictions to: %s" % (output_prediction_file))
print("Writing nbest to: %s" % (output_nbest_file))
example_index_to_features = collections.defaultdict(list)
for feature in all_features:
example_index_to_features[feature.example_index].append(feature)
unique_id_to_result = {}
for result in all_results:
unique_id_to_result[result.unique_id] = result
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction", [
"feature_index", "start_index", "end_index", "start_logit",
"end_logit"
])
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict()
for (example_index, example) in enumerate(all_examples):
features = example_index_to_features[example_index]
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
min_null_feature_index = 0 # the paragraph slice with min mull score
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id]
start_indexes = get_best_indexes(result.start_logits, n_best_size)
end_indexes = get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[
0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes:
for end_index in end_indexes:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(feature.tokens):
continue
if end_index >= len(feature.tokens):
continue
if start_index not in feature.token_to_orig_map:
continue
if end_index not in feature.token_to_orig_map:
continue
if not feature.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
feature_index=feature_index,
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]))
if version_2_with_negative:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit))
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True)
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_logit", "end_logit"])
seen_predictions = {}
nbest = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(
pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(
orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, do_lower_case)
if final_text in seen_predictions:
continue
seen_predictions[final_text] = True
else:
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't inlude the empty option in the n-best, inlcude it
if version_2_with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
# debug
if best_non_null_entry is None:
print("Emmm..., sth wrong")
probs = compute_softmax(total_scores)
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
nbest_json.append(output)
assert len(nbest_json) >= 1
if not version_2_with_negative:
all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
with open(output_nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
with open(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
# now `orig_text` contains the span of our original text corresponding to the
# span that we predicted.
#
# However, `orig_text` may contain extra characters that we don't want in
# our prediction.
#
# For example, let's say:
# pred_text = steve smith
# orig_text = Steve Smith's
#
# We don't want to return `orig_text` because it contains the extra "'s".
#
# We don't want to return `pred_text` because it's already been normalized
# (the SQuAD eval script also does punctuation stripping/lower casing but
# our tokenizer does additional normalization like stripping accent
# characters).
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heruistic between
# `pred_text` and `orig_text` to get a character-to-charcter alignment. This
# can fail in certain cases in which case we just return `orig_text`.
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = hub.reader.tokenization.BasicTokenizer(
do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
def get_best_indexes(logits, n_best_size):
"""Get the n-best logits from a list."""
index_and_score = sorted(
enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
for i in range(len(index_and_score)):
if i >= n_best_size:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
def compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
if __name__ == '__main__':
# Load Paddlehub bert_uncased_L-12_H-768_A-12 pretrained model
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
# module = hub.Module(module_dir=["./bert_uncased_L-12_H-768_A-12.hub_module"])
# Download dataset and use ReadingComprehensionReader to read dataset
if args.dataset == "squad":
dataset = hub.dataset.SQUAD(version_2_with_negative=False)
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset == "squad2.0" or args.dataset == "squad2":
args.dataset = "squad2.0"
dataset = hub.dataset.SQUAD(version_2_with_negative=True)
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset == "drcd":
dataset = hub.dataset.DRCD()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
elif args.dataset == "cmrc2018":
dataset = hub.dataset.CMRC2018()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
else:
raise Exception(
"Only support datasets: squad, squad2.0, drcd and cmrc2018")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
# Download dataset and use ReadingComprehensionReader to read dataset
dataset = hub.dataset.SQUAD(
version_2_with_negative=args.version_2_with_negative)
reader = hub.reader.ReadingComprehensionReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_length=args.max_seq_len,
max_seq_len=args.max_seq_len,
doc_stride=128,
max_query_length=64)
...@@ -444,82 +125,5 @@ if __name__ == '__main__':
config=config)
# Data to be predicted
data = dataset.predict_examples
data = dataset.dev_examples[97:98]
reading_comprehension_task.predict(data=data)
features = reader.convert_examples_to_features(
examples=data, is_training=False)
run_states = reading_comprehension_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
RawResult = collections.namedtuple(
"RawResult", ["unique_id", "start_logits", "end_logits"])
all_results = []
for batch_idx, batch_result in enumerate(results):
np_unique_ids = batch_result[0]
np_start_logits = batch_result[1]
np_end_logits = batch_result[2]
np_num_seqs = batch_result[3]
for idx in range(np_unique_ids.shape[0]):
unique_id = int(np_unique_ids[idx])
start_logits = [float(x) for x in np_start_logits[idx].flat]
end_logits = [float(x) for x in np_end_logits[idx].flat]
all_results.append(
RawResult(
unique_id=unique_id,
start_logits=start_logits,
end_logits=end_logits))
output_prediction_file = os.path.join(args.result_dir, "predictions.json")
output_nbest_file = os.path.join(args.result_dir, "nbest_predictions.json")
output_null_log_odds_file = os.path.join(args.result_dir, "null_odds.json")
write_predictions(
data,
features,
all_results,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
max_answer_length=args.max_answer_length,
n_best_size=args.n_best_size,
version_2_with_negative=args.version_2_with_negative,
null_score_diff_threshold=args.null_score_diff_threshold)
with io.open(dataset.predict_file, 'r', encoding="utf8") as dataset_file:
dataset_json = json.load(dataset_file)
dataset = dataset_json['data']
with io.open(
output_prediction_file, 'r', encoding="utf8") as prediction_file:
predictions = json.load(prediction_file)
if not args.version_2_with_negative:
print(json.dumps(evaluate_v1.evaluate(dataset, predictions)))
else:
with io.open(
output_null_log_odds_file, 'r', encoding="utf8") as odds_file:
na_probs = json.load(odds_file)
# Maps qid to true/false
qid_to_has_ans = evaluate_v2.make_qid_to_has_ans(dataset)
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = evaluate_v2.get_raw_scores(dataset, predictions)
exact_thresh = evaluate_v2.apply_no_ans_threshold(
exact_raw, na_probs, qid_to_has_ans, na_prob_thresh=1.0)
f1_thresh = evaluate_v2.apply_no_ans_threshold(
f1_raw, na_probs, qid_to_has_ans, na_prob_thresh=1.0)
out_eval = evaluate_v2.make_eval_dict(exact_thresh, f1_thresh)
if has_ans_qids:
has_ans_eval = evaluate_v2.make_eval_dict(
exact_thresh, f1_thresh, qid_list=has_ans_qids)
evaluate_v2.merge_eval(out_eval, has_ans_eval, 'HasAns')
if no_ans_qids:
no_ans_eval = evaluate_v2.make_eval_dict(
exact_thresh, f1_thresh, qid_list=no_ans_qids)
evaluate_v2.merge_eval(out_eval, no_ans_eval, 'NoAns')
evaluate_v2.find_all_best_thresh(out_eval, predictions, exact_raw,
f1_raw, na_probs, qid_to_has_ans)
print(json.dumps(out_eval, indent=4))
...@@ -31,28 +31,42 @@ parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight dec
parser.add_argument("--warmup_proportion", type=float, default=0.0, help="Warmup proportion params for warmup strategy")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=384, help="Number of words of the longest seqence.")
parser.add_argument("--null_score_diff_threshold", type=float, default=0.0, help="If null_score - best_non_null is greater than the threshold predict null.")
parser.add_argument("--n_best_size", type=int, default=20,help="The total number of n-best predictions to generate in the ""nbest_predictions.json output file.")
parser.add_argument("--max_answer_length", type=int, default=30,help="The maximum length of an answer that can be generated. This is needed ""because the start and end predictions are not conditioned on one another.")
parser.add_argument("--batch_size", type=int, default=8, help="Total examples' number in batch for training.") parser.add_argument("--batch_size", type=int, default=8, help="Total examples' number in batch for training.")
parser.add_argument("--use_pyreader", type=ast.literal_eval, default=True, help="Whether use pyreader to feed data.") parser.add_argument("--use_pyreader", type=ast.literal_eval, default=False, help="Whether use pyreader to feed data.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=True, help="Whether use data parallel.") parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
parser.add_argument("--version_2_with_negative", type=ast.literal_eval, default=False, help="If true, the SQuAD examples contain some that do not have an answer. If using squad v2.0, it should be set true.") parser.add_argument("--dataset", type=str, default="squad", help="Support squad, squad2.0, drcd and cmrc2018")
args = parser.parse_args() args = parser.parse_args()
# yapf: enable. # yapf: enable.
if __name__ == '__main__':
# Load Paddlehub bert_uncased_L-12_H-768_A-12 pretrained model
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
# Download dataset and use ReadingComprehensionReader to read dataset
if args.dataset == "squad":
dataset = hub.dataset.SQUAD(version_2_with_negative=False)
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset == "squad2.0" or args.dataset == "squad2":
args.dataset = "squad2.0"
dataset = hub.dataset.SQUAD(version_2_with_negative=True)
module = hub.Module(name="bert_uncased_L-12_H-768_A-12")
elif args.dataset == "drcd":
dataset = hub.dataset.DRCD()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
elif args.dataset == "cmrc2018":
dataset = hub.dataset.CMRC2018()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
else:
raise Exception(
"Only support datasets: squad, squad2.0, drcd and cmrc2018")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
# Download dataset and use ReadingComprehensionReader to read dataset
dataset = hub.dataset.SQUAD(
version_2_with_negative=args.version_2_with_negative)
reader = hub.reader.ReadingComprehensionReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_length=args.max_seq_len,
max_seq_len=args.max_seq_len,
doc_stride=128,
max_query_length=64)
...@@ -76,9 +90,10 @@ if __name__ == '__main__':
# Setup runing config for PaddleHub Finetune API
config = hub.RunConfig(
log_interval=10,
eval_interval=300,
save_ckpt_interval=10000,
use_pyreader=args.use_pyreader,
use_data_parallel=args.use_data_parallel,
save_ckpt_interval=1000,
use_cuda=args.use_gpu,
num_epoch=args.num_epoch,
batch_size=args.batch_size,
...@@ -91,7 +106,9 @@ if __name__ == '__main__':
data_reader=reader,
feature=seq_output,
feed_list=feed_list,
config=config)
config=config,
sub_task=args.dataset,
)
# Finetune by PaddleHub's API
reading_comprehension_task.finetune()
reading_comprehension_task.finetune_and_eval()
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0,1
# Recommending hyper parameters for difference task
# squad: batch_size=8, weight_decay=0, num_epoch=3, max_seq_len=512, lr=5e-5
# squad2.0: batch_size=8, weight_decay=0, num_epoch=3, max_seq_len=512, lr=5e-5
# cmrc2018: batch_size=8, weight_decay=0, num_epoch=2, max_seq_len=512, lr=2.5e-5
# drcd: batch_size=8, weight_decay=0, num_epoch=2, max_seq_len=512, lr=2.5e-5
dataset=cmrc2018
python -u reading_comprehension.py \
--batch_size=12 \
--batch_size=8 \
--use_gpu=True \
--checkpoint_dir="./ckpt_rc" \
--learning_rate=3e-5 \
--checkpoint_dir=./ckpt_${dataset} \
--learning_rate=2.5e-5 \
--weight_decay=0.01 \
--warmup_proportion=0.1 \
--num_epoch=2 \
--max_seq_len=384 \
--use_pyreader=True \
--use_data_parallel=True \
--version_2_with_negative=False
--max_seq_len=512 \
--dataset=${dataset}
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_rc"
RES_DIR="./result"
mkdir $RES_DIR
CKPT_DIR="./ckpt_cmrc2018"
dataset=cmrc2018
python -u predict.py \
--batch_size=12 \
--batch_size=8 \
--use_gpu=True \
--dataset=${dataset} \
--checkpoint_dir=${CKPT_DIR} \
--learning_rate=3e-5 \
--learning_rate=2.5e-5 \
--weight_decay=0.01 \
--warmup_proportion=0.1 \
--num_epoch=1 \
--max_seq_len=384 \
--max_seq_len=512 \
--use_pyreader=False \
--use_data_parallel=False \
--version_2_with_negative=False \
--result_dir=${RES_DIR}
--use_data_parallel=False
export FLAGS_eager_delete_tensor_gb=0.0
export CUDA_VISIBLE_DEVICES=0
CKPT_DIR="./ckpt_sequence_label"
python -u sequence_label.py \
......
...@@ -37,7 +37,7 @@ args = parser.parse_args()
if __name__ == '__main__':
# Load Paddlehub ERNIE pretrained model
module = hub.Module(name="ernie")
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=args.max_seq_len)
...@@ -69,6 +69,9 @@ if __name__ == '__main__':
# Setup runing config for PaddleHub Finetune API
config = hub.RunConfig(
log_interval=10,
eval_interval=300,
save_ckpt_interval=10000,
use_data_parallel=args.use_data_parallel,
use_pyreader=args.use_pyreader,
use_cuda=args.use_gpu,
......
...@@ -4,19 +4,24 @@
Among these, the classification tasks fall into two broad categories (a loading sketch follows the list):
* **Single-sentence classification**
- ChnSentiCorp
- ChineseGLUE-IFLYTEK
- ChineseGLUE-THUCNEWS
- GLUE-Cola
- GLUE-SST2
- ChnSentiCorp
* **Sentence-pair classification**
- LCQMC
- NLPCC-DBQA
- ChineseGLUE-LCQMC
- ChineseGLUE-INEWS
- ChineseGLUE-TNEWS
- ChineseGLUE-BQ
- ChineseGLUE-XNLI_zh
- GLUE-MNLI
- GLUE-QQP
- GLUE-QNLI
- GLUE-STS-B
- GLUE-MRPC
- GLUE-RTE
- NLPCC-DBQA
- XNLI
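To make the new options concrete, here is a minimal loading sketch (not part of the README itself), based on the text_classifier.py demo changed in this commit; the RoBERTa module name is the one this commit uses for the Chinese tasks, and LCQMC stands in for any of the datasets listed above.

    import paddlehub as hub

    # Any ChineseGLUE classification set listed above can be substituted here.
    dataset = hub.dataset.LCQMC()
    module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")

    # ClassifyReader converts the raw examples into padded token-id batches.
    reader = hub.reader.ClassifyReader(
        dataset=dataset,
        vocab_path=module.get_vocab_path(),
        max_seq_len=128)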
## How to start Finetune
......
...@@ -5,11 +5,36 @@ export CUDA_VISIBLE_DEVICES=0
DATASET="chnsenticorp"
CKPT_DIR="./ckpt_${DATASET}"
python -u text_classifier.py \
--batch_size=24 \
--use_gpu=True \
--dataset=${DATASET} \
--checkpoint_dir=${CKPT_DIR} \
--learning_rate=5e-5 \
--weight_decay=0.01 \
--max_seq_len=128 \
--num_epoch=3 \
--use_pyreader=True \
--use_data_parallel=True \
--use_taskid=False
# Recommending hyper parameters for difference task
# for ChineseGLUE:
# TNews: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# LCQMC: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# XNLI_zh: batch_size=32, weight_decay=0, num_epoch=2, max_seq_len=128, lr=5e-5
# INEWS: batch_size=4, weight_decay=0, num_epoch=3, max_seq_len=512, lr=5e-5
# DRCD: see demo: reading-comprehension
# CMRC2018: see demo: reading-comprehension
# BQ: batch_size=32, weight_decay=0, num_epoch=2, max_seq_len=100, lr=1e-5
# MSRANER: see demo: sequence-labeling
# THUCNEWS: batch_size=8, weight_decay=0, num_epoch=2, max_seq_len=512, lr=5e-5
# IFLYTEKDATA: batch_size=16, weight_decay=0, num_epoch=5, max_seq_len=256, lr=1e-5
# for other tasks:
# ChnSentiCorp: batch_size=24, weight_decay=0.01, num_epoch=3, max_seq_len=128, lr=5e-5
# NLPCC_DBQA: batch_size=8, weight_decay=0.01, num_epoch=3, max_seq_len=512, lr=2e-5
# LCQMC: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=2e-5
# TNews: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# QQP: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# QNLI: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# SST-2: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
...@@ -22,23 +47,10 @@ CKPT_DIR="./ckpt_${DATASET}"
# mnli_mm: dev and test in mismatched dataset.
# The difference can be seen in https://www.nyu.edu/projects/bowman/multinli/paper.pdf.
# If you are not sure which one to pick, just use mnli or mnli_m.
# XNLI: batch_size=32, weight_decay=0, num_epoch=2, max_seq_len=128, lr=5e-5
# XNLI: batch_size=32, weight_decay=0, num_epoch=3, max_seq_len=128, lr=5e-5
# Specify the language with an underscore like xnli_zh.
# ar- Arabic bg- Bulgarian de- German
# el- Greek en- English es- Spanish
# fr- French hi- Hindi ru- Russian
# sw- Swahili th- Thai tr- Turkish
# ur- Urdu vi- Vietnamese zh- Chinese (Simplified)
python -u text_classifier.py \
--batch_size=24 \
--use_gpu=True \
--dataset=${DATASET} \
--checkpoint_dir=${CKPT_DIR} \
--learning_rate=5e-5 \
--weight_decay=0.01 \
--max_seq_len=128 \
--num_epoch=3 \
--use_pyreader=True \
--use_data_parallel=True \
--use_taskid=False \
...@@ -43,20 +43,36 @@ if __name__ == '__main__':
# Download dataset and use ClassifyReader to read dataset
if args.dataset.lower() == "chnsenticorp":
dataset = hub.dataset.ChnSentiCorp()
module = hub.Module(name="ernie")
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == "tnews":
dataset = hub.dataset.TNews()
module = hub.Module(name="ernie")
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
elif args.dataset.lower() == "nlpcc_dbqa":
dataset = hub.dataset.NLPCC_DBQA()
module = hub.Module(name="ernie")
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == "lcqmc":
dataset = hub.dataset.LCQMC()
module = hub.Module(name="ernie")
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"]
elif args.dataset.lower() == 'inews':
dataset = hub.dataset.INews()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
elif args.dataset.lower() == 'bq':
dataset = hub.dataset.BQ()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
elif args.dataset.lower() == 'thucnews':
dataset = hub.dataset.THUCNEWS()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
elif args.dataset.lower() == 'iflytek':
dataset = hub.dataset.IFLYTEK()
module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc", "f1"]
elif args.dataset.lower() == "mrpc": elif args.dataset.lower() == "mrpc":
dataset = hub.dataset.GLUE("MRPC") dataset = hub.dataset.GLUE("MRPC")
if args.use_taskid: if args.use_taskid:
...@@ -116,7 +132,7 @@ if __name__ == '__main__': ...@@ -116,7 +132,7 @@ if __name__ == '__main__':
metrics_choices = ["acc"] metrics_choices = ["acc"]
elif args.dataset.lower().startswith("xnli"): elif args.dataset.lower().startswith("xnli"):
dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:]) dataset = hub.dataset.XNLI(language=args.dataset.lower()[-2:])
module = hub.Module(name="bert_multi_cased_L-12_H-768_A-12") module = hub.Module(name="roberta_wwm_ext_chinese_L-24_H-1024_A-16")
metrics_choices = ["acc"] metrics_choices = ["acc"]
else: else:
raise ValueError("%s dataset is not defined" % args.dataset) raise ValueError("%s dataset is not defined" % args.dataset)
...@@ -140,7 +156,7 @@ if __name__ == '__main__': ...@@ -140,7 +156,7 @@ if __name__ == '__main__':
pooled_output = outputs["pooled_output"] pooled_output = outputs["pooled_output"]
# Setup feed list for data feeder # Setup feed list for data feeder
# Must feed all the tensor of ERNIE's module need # Must feed all the tensor of module need
feed_list = [ feed_list = [
inputs["input_ids"].name, inputs["input_ids"].name,
inputs["position_ids"].name, inputs["position_ids"].name,
......
...@@ -24,6 +24,12 @@ from .squad import SQUAD
from .xnli import XNLI
from .glue import GLUE
from .tnews import TNews
from .inews import INews
from .drcd import DRCD
from .cmrc2018 import CMRC2018
from .bq import BQ
from .iflytek import IFLYTEK
from .thucnews import THUCNEWS
# CV Dataset
from .dogcat import DogCatDataset as DogCat
......
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import io
import os
import csv
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/bq.tar.gz"
class BQ(HubDataset):
"""BQ is a Chinese sentence-pair classification dataset with binary labels."""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "bq")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_test_examples()
self._load_dev_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, "train.txt")
self.train_examples = self._read_file(self.train_file)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, "dev.txt")
self.dev_examples = self._read_file(self.dev_file)
def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir, "test.txt")
self.test_examples = self._read_file(self.test_file)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_labels(self):
return ["0", "1"]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_file(self, input_file):
"""Reads a tab separated value file."""
with io.open(input_file, "r", encoding="UTF-8") as file:
examples = []
for (i, line) in enumerate(file):
data = line.strip().split("\t")
example = InputExample(
guid=i, label=data[2], text_a=data[0], text_b=data[1])
examples.append(example)
return examples
if __name__ == "__main__":
ds = BQ()
for e in ds.get_dev_examples()[:10]:
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on cmrc2018"""
import json
import os
import sys
from paddlehub.reader import tokenization
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/cmrc2018.tar.gz"
SPIECE_UNDERLINE = '▁'
class CMRC2018Example(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None):
self.qas_id = qas_id
self.question_text = question_text
self.doc_tokens = doc_tokens
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.end_position = end_position
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += ", question_text: %s" % (tokenization.printable_text(
self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position is not None:
s += ", orig_answer_text: %s" % (self.orig_answer_text)
s += ", start_position: %d" % (self.start_position)
s += ", end_position: %d" % (self.end_position)
return s
class CMRC2018(object):
"""A single set of features of data."""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "cmrc2018")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_dev_examples()
self._load_test_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, "cmrc2018_train.json")
self.train_examples = self._read_json(self.train_file, is_training=True)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, "cmrc2018_dev.json")
self.dev_examples = self._read_json(self.dev_file, is_training=False)
def _load_test_examples(self):
pass
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return []
def _read_json(self, input_file, is_training=False):
"""Read a cmrc2018 json file into a list of CRCDExample."""
def _is_chinese_char(cp):
if ((cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF)
or (cp >= 0x20000 and cp <= 0x2A6DF)
or (cp >= 0x2A700 and cp <= 0x2B73F)
or (cp >= 0x2B740 and cp <= 0x2B81F)
or (cp >= 0x2B820 and cp <= 0x2CEAF)
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F)):
return True
return False
def _is_punctuation(c):
if c in [
'。', ',', '!', '?', ';', '、', ':', '(', ')', '-', '~', '「',
'《', '》', ',', '」', '"', '“', '”', '$', '『', '』', '—', ';',
'。', '(', ')', '-', '~', '。', '‘', '’', '─', ':'
]:
return True
return False
def _tokenize_chinese_chars(text):
"""Because Chinese (and Japanese Kanji and Korean Hanja) does not have whitespace
characters, we add spaces around every character in the CJK Unicode range before
applying WordPiece. This means that Chinese is effectively character-tokenized.
Note that the CJK Unicode block only includes Chinese-origin characters and
does not include Hangul Korean or Katakana/Hiragana Japanese, which are tokenized
with whitespace+WordPiece like all other languages."""
output = []
for char in text:
cp = ord(char)
if _is_chinese_char(cp) or _is_punctuation(char):
if len(output) > 0 and output[-1] != SPIECE_UNDERLINE:
output.append(SPIECE_UNDERLINE)
output.append(char)
output.append(SPIECE_UNDERLINE)
else:
output.append(char)
return "".join(output)
def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(
c) == 0x202F or ord(c) == 0x3000 or c == SPIECE_UNDERLINE:
return True
return False
examples = []
drop = 0
with open(input_file, "r") as reader:
input_data = json.load(reader)["data"]
for entry in input_data:
for paragraph in entry["paragraphs"]:
paragraph_text = paragraph["context"]
context = _tokenize_chinese_chars(paragraph_text)
doc_tokens = []
char_to_word_offset = []
prev_is_whitespace = True
for c in context:
if is_whitespace(c):
prev_is_whitespace = True
else:
if prev_is_whitespace:
doc_tokens.append(c)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
if c != SPIECE_UNDERLINE:
char_to_word_offset.append(len(doc_tokens) - 1)
for qa in paragraph["qas"]:
qas_id = qa["id"]
question_text = qa["question"]
# Only select the first answer
answer = qa["answers"][0]
orig_answer_text = answer["text"]
answer_offset = answer["answer_start"]
while paragraph_text[answer_offset] in [
" ", "\t", "\r", "\n", "。", ",", ":", ":", ".", ","
]:
answer_offset += 1
start_position = char_to_word_offset[answer_offset]
answer_length = len(orig_answer_text)
end_offset = answer_offset + answer_length - 1
if end_offset >= len(char_to_word_offset):
end_offset = len(char_to_word_offset) - 1
end_position = char_to_word_offset[end_offset]
# Only add answers where the text can be exactly recovered from the
# document. If this CAN'T happen it's likely due to weird Unicode
# stuff so we will just skip the example.
#
# Note that this means for training mode, every example is NOT
# guaranteed to be preserved.
if is_training:
actual_text = "".join(
doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = "".join(
tokenization.whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
drop += 1
# logger.warning((actual_text, " vs ",
# cleaned_answer_text, " in ", qa))
continue
example = CMRC2018Example(
qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position)
examples.append(example)
logger.warning("%i bad examples has been dropped" % drop)
return examples
if __name__ == "__main__":
print("begin")
ds = CMRC2018()
print("train")
examples = ds.get_train_examples()
for index, e in enumerate(examples):
if index < 10:
print(e)
print("dev")
examples = ds.get_dev_examples()
for index, e in enumerate(examples):
if index < 10:
print(e)
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on DRCD"""
import json
import os
import sys
from paddlehub.reader import tokenization
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/drcd.tar.gz"
SPIECE_UNDERLINE = '▁'
class DRCDExample(object):
"""A single training/test example for simple sequence classification.
For examples without an answer, the start and end position are -1.
"""
def __init__(self,
qas_id,
question_text,
doc_tokens,
orig_answer_text=None,
start_position=None,
end_position=None,
is_impossible=False):
self.qas_id = qas_id
self.question_text = question_text
self.doc_tokens = doc_tokens
self.orig_answer_text = orig_answer_text
self.start_position = start_position
self.end_position = end_position
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += ", question_text: %s" % (tokenization.printable_text(
self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position is not None:
s += ", orig_answer_text: %s" % (self.orig_answer_text)
s += ", start_position: %d" % (self.start_position)
s += ", end_position: %d" % (self.end_position)
return s
class DRCD(object):
"""A single set of features of data."""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "drcd")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_dev_examples()
self._load_test_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, "DRCD_training.json")
self.train_examples = self._read_json(self.train_file)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, "DRCD_dev.json")
self.dev_examples = self._read_json(self.dev_file)
def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir, "DRCD_test.json")
self.test_examples = self._read_json(self.test_file)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def _read_json(self, input_file):
"""Read a DRCD json file into a list of CRCDExample."""
def _is_chinese_char(cp):
if ((cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF)
or (cp >= 0x20000 and cp <= 0x2A6DF)
or (cp >= 0x2A700 and cp <= 0x2B73F)
or (cp >= 0x2B740 and cp <= 0x2B81F)
or (cp >= 0x2B820 and cp <= 0x2CEAF)
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F)):
return True
return False
def _is_punctuation(c):
if c in [
'。', ',', '!', '?', ';', '、', ':', '(', ')', '-', '~', '「',
'《', '》', ',', '」', '"', '“', '”', '$', '『', '』', '—', ';',
'。', '(', ')', '-', '~', '。', '‘', '’', '─', ':'
]:
return True
return False
def _tokenize_chinese_chars(text):
"""Because Chinese (and Japanese Kanji and Korean Hanja) does not have whitespace
characters, we add spaces around every character in the CJK Unicode range before
applying WordPiece. This means that Chinese is effectively character-tokenized.
Note that the CJK Unicode block only includes Chinese-origin characters and
does not include Hangul Korean or Katakana/Hiragana Japanese, which are tokenized
with whitespace+WordPiece like all other languages."""
output = []
for char in text:
cp = ord(char)
if _is_chinese_char(cp) or _is_punctuation(char):
if len(output) > 0 and output[-1] != SPIECE_UNDERLINE:
output.append(SPIECE_UNDERLINE)
output.append(char)
output.append(SPIECE_UNDERLINE)
else:
output.append(char)
return "".join(output)
def is_whitespace(c):
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(
c) == 0x202F or ord(c) == 0x3000 or c == SPIECE_UNDERLINE:
return True
return False
examples = []
with open(input_file, "r") as reader:
input_data = json.load(reader)["data"]
for entry in input_data:
for paragraph in entry["paragraphs"]:
paragraph_text = paragraph["context"]
context = _tokenize_chinese_chars(paragraph_text)
doc_tokens = []
char_to_word_offset = []
prev_is_whitespace = True
for c in context:
if is_whitespace(c):
prev_is_whitespace = True
else:
if prev_is_whitespace:
doc_tokens.append(c)
else:
doc_tokens[-1] += c
prev_is_whitespace = False
if c != SPIECE_UNDERLINE:
char_to_word_offset.append(len(doc_tokens) - 1)
for qa in paragraph["qas"]:
qas_id = qa["id"]
question_text = qa["question"]
# Only select the first answer
answer = qa["answers"][0]
orig_answer_text = answer["text"]
answer_offset = answer["answer_start"]
while paragraph_text[answer_offset] in [
" ", "\t", "\r", "\n", "。", ",", ":", ":", ".", ","
]:
answer_offset += 1
start_position = char_to_word_offset[answer_offset]
answer_length = len(orig_answer_text)
end_position = char_to_word_offset[answer_offset +
answer_length - 1]
# Only add answers where the text can be exactly recovered from the
# document. If this CAN'T happen it's likely due to weird Unicode
# stuff so we will just skip the example.
#
# Note that this means for training mode, every example is NOT
# guaranteed to be preserved.
actual_text = "".join(
doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = "".join(
tokenization.whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logger.warning((actual_text, " vs ",
cleaned_answer_text, " in ", qa))
continue
example = DRCDExample(
qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position)
examples.append(example)
return examples
if __name__ == "__main__":
ds = DRCD()
print("train")
examples = ds.get_train_examples()
for index, e in enumerate(examples):
if index < 10:
print(e)
print("dev")
examples = ds.get_dev_examples()
for index, e in enumerate(examples):
if index < 10:
print(e)
print("test")
examples = ds.get_test_examples()
for index, e in enumerate(examples):
if index < 10:
print(e)
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import io
import os
import csv
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/iflytek.tar.gz"
class IFLYTEK(HubDataset):
"""IFLYTEK is a Chinese text classification dataset with 119 class labels."""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "iflytek")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_test_examples()
self._load_dev_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, "train.txt")
self.train_examples = self._read_file(self.train_file)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, "dev.txt")
self.dev_examples = self._read_file(self.dev_file)
def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir, "test.txt")
self.test_examples = self._read_file(self.test_file)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_labels(self):
return [str(i) for i in range(119)]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_file(self, input_file):
"""Reads a tab separated value file."""
with io.open(input_file, "r", encoding="UTF-8") as file:
examples = []
for (i, line) in enumerate(file):
data = line.strip().split("_!_")
try:
example = InputExample(
guid=i, label=str(data[0]), text_a=data[1], text_b=None)
examples.append(example)
except:
# Skip malformed lines that do not split into a label and text.
pass
return examples
if __name__ == "__main__":
ds = IFLYTEK()
for e in ds.get_train_examples()[:10]:
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import io
import os
import csv
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/inews.tar.gz"
class INews(HubDataset):
"""
INews is a sentiment analysis dataset for Internet News
"""
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "inews")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_test_examples()
self._load_dev_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, "train.txt")
self.train_examples = self._read_file(self.train_file)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, "dev.txt")
self.dev_examples = self._read_file(self.dev_file)
def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir, "test.txt")
self.test_examples = self._read_file(self.test_file)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_labels(self):
return ["0", "1", "2"]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_file(self, input_file):
"""Reads a tab separated value file."""
with io.open(input_file, "r", encoding="UTF-8") as file:
examples = []
for (i, line) in enumerate(file):
if i == 0:
continue
data = line.strip().split("_!_")
example = InputExample(
guid=i, label=data[0], text_a=data[2], text_b=data[3])
examples.append(example)
return examples
if __name__ == "__main__":
ds = INews()
for e in ds.get_train_examples()[:10]:
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
...@@ -76,42 +76,50 @@ class SQUAD(object):
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples(version_2_with_negative, is_training=True)
self._load_predict_examples(version_2_with_negative, is_training=False)
self.version_2_with_negative = version_2_with_negative
self._load_train_examples(version_2_with_negative, if_has_answer=True)
self._load_dev_examples(version_2_with_negative, if_has_answer=True)
def _load_train_examples(self,
version_2_with_negative=False,
is_training=True):
if_has_answer=True):
if not version_2_with_negative:
self.train_file = os.path.join(self.dataset_dir, "train-v1.1.json")
else:
self.train_file = os.path.join(self.dataset_dir, "train-v2.0.json")
self.train_examples = self._read_json(self.train_file, is_training,
self.train_examples = self._read_json(self.train_file, if_has_answer,
version_2_with_negative)
def _load_predict_examples(self,
def _load_dev_examples(self,
version_2_with_negative=False,
is_training=False):
if_has_answer=True):
if not version_2_with_negative:
self.predict_file = os.path.join(self.dataset_dir, "dev-v1.1.json")
self.dev_file = os.path.join(self.dataset_dir, "dev-v1.1.json")
else:
self.predict_file = os.path.join(self.dataset_dir, "dev-v2.0.json")
self.predict_examples = self._read_json(self.predict_file, is_training,
version_2_with_negative)
self.dev_file = os.path.join(self.dataset_dir, "dev-v2.0.json")
self.dev_examples = self._read_json(self.dev_file, if_has_answer,
version_2_with_negative)
def _load_test_examples(self,
version_2_with_negative=False,
is_training=False):
self.test_file = None
logger.error("not test_file")
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return []
return self.dev_examples
def get_test_examples(self):
return []
def _read_json(self, input_file, is_training,
def _read_json(self,
input_file,
if_has_answer,
version_2_with_negative=False):
"""Read a SQuAD json file into a list of SquadExample."""
with open(input_file, "r") as reader:
...@@ -148,14 +156,13 @@ class SQUAD(object):
end_position = None
orig_answer_text = None
is_impossible = False
if is_training:
if if_has_answer:
if version_2_with_negative:
is_impossible = qa["is_impossible"]
if (len(qa["answers"]) != 1) and (not is_impossible):
raise ValueError(
"For training, each question should have exactly 1 answer."
)
# if (len(qa["answers"]) != 1) and (not is_impossible):
# raise ValueError(
# "For training, each question should have exactly 1 answer."
# )
if not is_impossible:
answer = qa["answers"][0]
orig_answer_text = answer["text"]
...@@ -177,8 +184,8 @@ class SQUAD(object):
orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logger.warning(
"Could not find answer: '%s' vs. '%s'",
actual_text, cleaned_answer_text)
"Could not find answer: '%s' vs. '%s'" %
(actual_text, cleaned_answer_text))
continue
else:
start_position = -1
...@@ -199,8 +206,8 @@ class SQUAD(object):
if __name__ == "__main__":
ds = SQUAD(version_2_with_negative=True)
examples = ds.get_dev_examples()
ds = SQUAD(version_2_with_negative=False)
examples = ds.get_train_examples()
for index, e in enumerate(examples):
if index < 10:
print(e)
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import io
import os
import csv
from paddlehub.dataset import InputExample, HubDataset
from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/thucnews.tar.gz"
class THUCNEWS(HubDataset):
def __init__(self):
self.dataset_dir = os.path.join(DATA_HOME, "thucnews")
if not os.path.exists(self.dataset_dir):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
else:
logger.info("Dataset {} already cached.".format(self.dataset_dir))
self._load_train_examples()
self._load_test_examples()
self._load_dev_examples()
def _load_train_examples(self):
self.train_file = os.path.join(self.dataset_dir, "train.txt")
self.train_examples = self._read_file(self.train_file)
def _load_dev_examples(self):
self.dev_file = os.path.join(self.dataset_dir, "dev.txt")
self.dev_examples = self._read_file(self.dev_file)
def _load_test_examples(self):
self.test_file = os.path.join(self.dataset_dir, "test.txt")
self.test_examples = self._read_file(self.test_file)
def get_train_examples(self):
return self.train_examples
def get_dev_examples(self):
return self.dev_examples
def get_test_examples(self):
return self.test_examples
def get_labels(self):
return [str(i) for i in range(14)]
@property
def num_labels(self):
"""
Return the number of labels in the dataset.
"""
return len(self.get_labels())
def _read_file(self, input_file):
"""Reads a tab separated value file."""
with io.open(input_file, "r", encoding="UTF-8") as file:
examples = []
for (i, line) in enumerate(file):
data = line.strip().split("_!_")
try:
example = InputExample(
guid=i, label=data[0], text_a=data[3], text_b=None)
examples.append(example)
                except IndexError:
                    # skip malformed lines that do not contain enough "_!_" fields
                    continue
return examples
if __name__ == "__main__":
ds = THUCNEWS()
for e in ds.get_train_examples()[:10]:
print("{}\t{}\t{}\t{}".format(e.guid, e.text_a, e.text_b, e.label))
@@ -32,7 +32,7 @@ _DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/tnews.tar.gz"
 class TNews(HubDataset):
     """
-    TNews is the chinese news classification dataset on JinRiTouDiao App.
+    TNews is the chinese news classification dataset on Jinri Toutiao App.
     """

     def __init__(self):
......
@@ -17,14 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import time
-import paddle.fluid as fluid
 import numpy as np
-from paddlehub.common.logger import logger
-import paddlehub as hub

 # Sequence label evaluation functions
 def chunk_eval(np_labels, np_infers, np_lens, tag_num, dev_count=1):
......
# -*- coding: utf-8 -*-
'''
Evaluation script for CMRC 2018
version: v5 - special
Note:
v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
v5: formatted output, add usage description
v4: fixed segmentation issues
'''
from __future__ import print_function
from collections import OrderedDict
import re
import json
import nltk
import sys
try:
nltk.data.find('tokenizers/punkt')
except LookupError:
print("Downloading nltk punkt")
nltk.download('punkt')
# split Chinese with English
def mixed_segmentation(in_str, rm_punc=False):
in_str = str(in_str).lower().strip()
segs_out = []
temp_str = ""
sp_char = [
'-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
'?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
')', '-', '~', '『', '』'
]
for char in in_str:
if rm_punc and char in sp_char:
continue
if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
if temp_str != "":
ss = nltk.word_tokenize(temp_str)
segs_out.extend(ss)
temp_str = ""
segs_out.append(char)
else:
temp_str += char
# handling last part
if temp_str != "":
ss = nltk.word_tokenize(temp_str)
segs_out.extend(ss)
return segs_out
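# Illustration (not part of the original script): the input is lower-cased, Chinese
# characters become individual segments, and non-Chinese runs go through nltk, e.g.
#   mixed_segmentation("ABC中文123")  ->  ['abc', '中', '文', '123']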
# remove punctuation
def remove_punctuation(in_str):
in_str = str(in_str).lower().strip()
sp_char = [
'-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':',
'?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(',
')', '-', '~', '『', '』'
]
out_segs = []
for char in in_str:
if char in sp_char:
continue
else:
out_segs.append(char)
return ''.join(out_segs)
# find the longest common substring (contiguous run of segments) of two segment lists
def find_lcs(s1, s2):
m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
mmax = 0
p = 0
for i in range(len(s1)):
for j in range(len(s2)):
if s1[i] == s2[j]:
m[i + 1][j + 1] = m[i][j] + 1
if m[i + 1][j + 1] > mmax:
mmax = m[i + 1][j + 1]
p = i + 1
return s1[p - mmax:p], mmax
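# Illustration (not part of the original script): the longest *contiguous* common run
# and its length are returned, e.g.
#   find_lcs(['我', '爱', '北', '京'], ['我', '在', '北', '京'])  ->  (['北', '京'], 2)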
def evaluate(ground_truth_file, prediction_file):
f1 = 0
em = 0
total_count = 0
skip_count = 0
for instance in ground_truth_file:
# context_id = instance['context_id'].strip()
# context_text = instance['context_text'].strip()
for para in instance["paragraphs"]:
for qas in para['qas']:
total_count += 1
query_id = qas['id'].strip()
query_text = qas['question'].strip()
answers = [x["text"] for x in qas['answers']]
if query_id not in prediction_file:
print('Unanswered question: {}\n'.format(query_id))
skip_count += 1
continue
prediction = str(prediction_file[query_id])
f1 += calc_f1_score(answers, prediction)
em += calc_em_score(answers, prediction)
f1_score = 100.0 * f1 / total_count
em_score = 100.0 * em / total_count
return f1_score, em_score, total_count, skip_count
def calc_f1_score(answers, prediction):
f1_scores = []
for ans in answers:
ans_segs = mixed_segmentation(ans, rm_punc=True)
prediction_segs = mixed_segmentation(prediction, rm_punc=True)
lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
if lcs_len == 0:
f1_scores.append(0)
continue
precision = 1.0 * lcs_len / len(prediction_segs)
recall = 1.0 * lcs_len / len(ans_segs)
f1 = (2 * precision * recall) / (precision + recall)
f1_scores.append(f1)
return max(f1_scores)
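# Illustration (not part of the original script): with answers=['北京天安门'] and
# prediction='天安门', the longest common run is ['天', '安', '门'], so precision=3/3,
# recall=3/5 and calc_f1_score(...) returns 0.75.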
def calc_em_score(answers, prediction):
em = 0
for ans in answers:
ans_ = remove_punctuation(ans)
prediction_ = remove_punctuation(prediction)
if ans_ == prediction_:
em = 1
break
return em
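# Illustration (not part of the original script): punctuation is stripped before the
# exact-match comparison, so calc_em_score(['北京'], '北京。') returns 1.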
def get_eval(original_file, prediction_file):
F1, EM, TOTAL, SKIP = evaluate(original_file, prediction_file)
AVG = (EM + F1) * 0.5
output_result = OrderedDict()
output_result['AVERAGE'] = AVG
output_result['F1'] = F1
output_result['EM'] = EM
output_result['TOTAL'] = TOTAL
output_result['SKIP'] = SKIP
return output_result
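A rough usage sketch for this script; the file names below are hypothetical, and it assumes the usual SQuAD-style CMRC 2018 layout (a top-level "data" list of articles with "paragraphs") plus a predictions JSON keyed by question id:

import io
import json

with io.open("cmrc2018_dev.json", encoding="utf-8") as f:           # hypothetical path
    ground_truth = json.load(f)["data"]                             # articles with "paragraphs"
with io.open("cmrc2018_predictions.json", encoding="utf-8") as f:   # hypothetical path
    predictions = json.load(f)                                      # {question_id: predicted answer}

print(get_eval(ground_truth, predictions))  # OrderedDict with AVERAGE / F1 / EM / TOTAL / SKIP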
@@ -138,6 +138,7 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs,
     main_eval['best_exact_thresh'] = exact_thresh
     main_eval['best_f1'] = best_f1
     main_eval['best_f1_thresh'] = f1_thresh
+    return main_eval


 def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
@@ -161,3 +162,28 @@ def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
             best_score = cur_score
             best_thresh = na_probs[qid]
     return 100.0 * best_score / len(scores), best_thresh
+
+
+def evaluate(dataset, predictions, na_probs):
+    qid_to_has_ans = make_qid_to_has_ans(dataset)
+    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
+    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
+    exact_raw, f1_raw = get_raw_scores(dataset, predictions)
+    exact_thresh = apply_no_ans_threshold(
+        exact_raw, na_probs, qid_to_has_ans, na_prob_thresh=1.0)
+    f1_thresh = apply_no_ans_threshold(
+        f1_raw, na_probs, qid_to_has_ans, na_prob_thresh=1.0)
+    out_eval = make_eval_dict(exact_thresh, f1_thresh)
+    if has_ans_qids:
+        has_ans_eval = make_eval_dict(
+            exact_thresh, f1_thresh, qid_list=has_ans_qids)
+        merge_eval(out_eval, has_ans_eval, 'HasAns')
+    if no_ans_qids:
+        no_ans_eval = make_eval_dict(
+            exact_thresh, f1_thresh, qid_list=no_ans_qids)
+        merge_eval(out_eval, no_ans_eval, 'NoAns')
+    find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, na_probs,
+                         qid_to_has_ans)
+    return out_eval
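A hedged usage sketch for the evaluate() entry point added above; the file paths are hypothetical, and make_qid_to_has_ans, get_raw_scores, apply_no_ans_threshold, make_eval_dict, merge_eval and find_all_best_thresh are the helpers already defined in this SQuAD 2.0 evaluation module:

import json

with open("dev-v2.0.json") as f:        # SQuAD 2.0 dev set
    dataset = json.load(f)["data"]
with open("predictions.json") as f:     # hypothetical {qid: answer text} output of write_predictions
    predictions = json.load(f)
with open("null_odds.json") as f:       # hypothetical {qid: no-answer log-odds} file
    na_probs = json.load(f)

result = evaluate(dataset, predictions, na_probs)
print(result["best_f1"], result["best_exact"])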
@@ -414,11 +414,7 @@ class CombinedStrategy(DefaultStrategy):
         # self.num_examples = {'train': -1, 'dev': -1, 'test': -1} before data_generator
         data_reader.data_generator(
             batch_size=config.batch_size, phase='train', shuffle=True)
-        data_reader.data_generator(
-            batch_size=config.batch_size, phase='dev', shuffle=False)
-        data_reader.data_generator(
-            batch_size=config.batch_size, phase='test', shuffle=False)
-        num_train_examples = len(data_reader.get_train_examples())
+        num_train_examples = data_reader.num_examples['train']
         max_train_steps = config.num_epoch * num_train_examples // config.batch_size // dev_count
......
@@ -165,6 +165,10 @@ class BasicTask(object):
     def enter_phase(self, phase):
         if phase not in ["train", "val", "dev", "test", "predict", "inference"]:
             raise RuntimeError()
+        if phase in ["val", "dev"]:
+            phase = "dev"
+        elif phase in ["predict", "inference"]:
+            phase = "predict"
         self._phases.append(phase)

     def exit_phase(self):
@@ -330,7 +334,7 @@ class BasicTask(object):
     def env(self):
         phase = self.phase
         if phase in ["val", "dev", "test"]:
-            phase = "val"
+            phase = "dev"
         if not phase in self._envs:
             self._envs[phase] = RunEnv()
         return self._envs[phase]
@@ -468,18 +472,19 @@ class BasicTask(object):
     def _eval_end_event(self, run_states):
         eval_scores, eval_loss, run_speed = self._calculate_metrics(run_states)
-        self.tb_writer.add_scalar(
-            tag="Loss_{}".format(self.phase),
-            scalar_value=eval_loss,
-            global_step=self._envs['train'].current_step)
-        log_scores = ""
-        for metric in eval_scores:
-            self.tb_writer.add_scalar(
-                tag="{}_{}".format(metric, self.phase),
-                scalar_value=eval_scores[metric],
-                global_step=self._envs['train'].current_step)
+        if 'train' in self._envs:
+            self.tb_writer.add_scalar(
+                tag="Loss_{}".format(self.phase),
+                scalar_value=eval_loss,
+                global_step=self._envs['train'].current_step)
+        log_scores = ""
+        for metric in eval_scores:
+            if 'train' in self._envs:
+                self.tb_writer.add_scalar(
+                    tag="{}_{}".format(metric, self.phase),
+                    scalar_value=eval_scores[metric],
+                    global_step=self._envs['train'].current_step)
             log_scores += "%s=%.5f " % (metric, eval_scores[metric])
         logger.info(
             "[%s dataset evaluation result] loss=%.5f %s[step/sec: %.2f]" %
@@ -501,6 +506,7 @@ class BasicTask(object):
                 "best_model")
             logger.info("best model saved to %s [best %s=%.5f]" %
                         (model_saved_dir, main_metric, main_value))
+
             save_result = fluid.io.save_persistables(
                 executor=self.exe,
                 dirname=model_saved_dir,
......
@@ -583,7 +583,7 @@ class Module(object):
         if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
             raise ValueError(
                 "max_seq_len({}) should be in the range of [1, {}]".format(
-                    MAX_SEQ_LENGTH))
+                    max_seq_len, MAX_SEQ_LENGTH))
         logger.info(
             "Set maximum sequence length of input tensor to {}".format(
                 max_seq_len))
......
(This diff has been collapsed and is not shown.)