#!/usr/bin/python
# -*- coding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding("utf-8")

import json
import copy

from preprocess import metric_max_over_ground_truths, f1_score


def compute_paragraph_score(sample):
    """
    For each paragraph, compute the f1 score compared with the question.
    Args:
        sample: a sample in the dataset.
    Returns:
        None
    Raises:
        None
    """
    question = sample["segmented_question"]
    for doc in sample['documents']:
        doc['segmented_paragraphs_scores'] = []
        for p_idx, para_tokens in enumerate(doc['segmented_paragraphs']):
            if len(question) > 0:
                related_score = metric_max_over_ground_truths(
                    f1_score, para_tokens, [question])
            else:
                related_score = 0.0
            doc['segmented_paragraphs_scores'].append(related_score)
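

# A toy illustration of the scoring above (hypothetical tokens, not from the
# dataset): for question [u'百度', u'是', u'什么'] and a paragraph tokenized as
# [u'百度', u'是', u'搜索引擎'], two tokens are shared, so the token-level
# f1_score gives precision = recall = 2/3 and a related_score of ~0.67;
# a paragraph with no token overlap scores 0.0 and ranks below it.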


def dup_remove(doc):
    """
    For each document, remove the duplicated paragraphs.
    Args:
        doc: a doc in the sample
    Returns:
        bool: True if any paragraph was removed, else False
    Raises:
        None
    """
    paragraphs_his = {}
    del_ids = []
    para_id = None
    if 'most_related_para' in doc:
        para_id = doc['most_related_para']
    doc['paragraphs_length'] = []
    for p_idx, (segmented_paragraph, paragraph_score) in \
            enumerate(zip(doc["segmented_paragraphs"],
                          doc["segmented_paragraphs_scores"])):
        doc['paragraphs_length'].append(len(segmented_paragraph))
        paragraph = ''.join(segmented_paragraph)
        if paragraph in paragraphs_his:
            # paragraph already seen: mark it for deletion; if it was the
            # most related one, point to its first occurrence instead
            del_ids.append(p_idx)
            if p_idx == para_id:
                para_id = paragraphs_his[paragraph]
            continue
        paragraphs_his[paragraph] = p_idx
    # delete the duplicated paragraphs
    prev_del_num = 0
    del_num = 0
    for p_idx in del_ids:
        if para_id is not None and p_idx < para_id:
            prev_del_num += 1
        del doc["segmented_paragraphs"][p_idx - del_num]
        del doc["segmented_paragraphs_scores"][p_idx - del_num]
        del doc['paragraphs_length'][p_idx - del_num]
        del_num += 1
    if len(del_ids) != 0:
        if 'most_related_para' in doc:
            doc['most_related_para'] = para_id - prev_del_num
        doc['paragraphs'] = []
        for segmented_para in doc["segmented_paragraphs"]:
            paragraph = ''.join(segmented_para)
            doc['paragraphs'].append(paragraph)
        return True
    else:
        return False


def paragraph_selection(sample, mode):
    """
    For each document, select paragraphs that include as much information
    as possible.
    Args:
        sample: a sample in the dataset.
        mode: string of ("train", "dev", "test"), indicates the type of
            dataset to process.
    Returns:
        None
    Raises:
        None
    """
    # predefined maximum length of paragraph
    MAX_P_LEN = 500
    # predefined splitter token inserted before each paragraph
    splitter = u'<splitter>'
    # topN of related paragraphs to choose
    topN = 3
    doc_id = None
    if 'answer_docs' in sample and len(sample['answer_docs']) > 0:
        doc_id = sample['answer_docs'][0]
        if doc_id >= len(sample['documents']):
            # Data error: answer doc ID > number of documents; this sample
            # will be filtered by dataset.py
            return
    for d_idx, doc in enumerate(sample['documents']):
        if 'segmented_paragraphs_scores' not in doc:
            continue
        status = dup_remove(doc)
        segmented_title = doc["segmented_title"]
        title_len = len(segmented_title)
        para_id = None
        if doc_id is not None:
            para_id = sample['documents'][doc_id]['most_related_para']
        total_len = title_len + sum(doc['paragraphs_length'])
        # add one splitter token per paragraph
        para_num = len(doc["segmented_paragraphs"])
        total_len += para_num
        if total_len <= MAX_P_LEN:
            # the whole document fits: concatenate title and all paragraphs
            incre_len = title_len
            total_segmented_content = copy.deepcopy(segmented_title)
            for p_idx, segmented_para in enumerate(
                    doc["segmented_paragraphs"]):
                if doc_id == d_idx and para_id > p_idx:
                    incre_len += len([splitter] + segmented_para)
                if doc_id == d_idx and para_id == p_idx:
                    incre_len += 1
                total_segmented_content += [splitter] + segmented_para
            if doc_id == d_idx:
                # shift the gold answer span by the tokens now in front of it
                answer_start = incre_len + sample['answer_spans'][0][0]
                answer_end = incre_len + sample['answer_spans'][0][1]
                sample['answer_spans'][0][0] = answer_start
                sample['answer_spans'][0][1] = answer_end
            doc["segmented_paragraphs"] = [total_segmented_content]
            doc["segmented_paragraphs_scores"] = [1.0]
            doc['paragraphs_length'] = [total_len]
            doc['paragraphs'] = [''.join(total_segmented_content)]
            doc['most_related_para'] = 0
            continue
        # find the topN paragraph ids, by score (desc) then length (asc)
        para_infos = []
        for p_idx, (para_tokens, para_scores) in \
                enumerate(zip(doc['segmented_paragraphs'],
                              doc['segmented_paragraphs_scores'])):
            para_infos.append(
                (para_tokens, para_scores, len(para_tokens), p_idx))
        para_infos.sort(key=lambda x: (-x[1], x[2]))
        topN_idx = []
        for para_info in para_infos[:topN]:
            topN_idx.append(para_info[-1])
        final_idx = []
        total_len = title_len
        if doc_id == d_idx:
            if mode == "train":
                # always keep the paragraph that contains the gold answer
                final_idx.append(para_id)
                total_len = title_len + 1 + doc['paragraphs_length'][para_id]
        for idx in topN_idx:
            if total_len > MAX_P_LEN:
                break
            if doc_id == d_idx and idx == para_id and mode == "train":
                continue
            total_len += 1 + doc['paragraphs_length'][idx]
            final_idx.append(idx)
        total_segmented_content = copy.deepcopy(segmented_title)
        final_idx.sort()
        incre_len = title_len
        for idx in final_idx:
            if doc_id == d_idx and idx < para_id:
                incre_len += 1 + doc['paragraphs_length'][idx]
            if doc_id == d_idx and idx == para_id:
                incre_len += 1
            total_segmented_content += \
                [splitter] + doc['segmented_paragraphs'][idx]
        if doc_id == d_idx:
            # shift the gold answer span by the tokens now in front of it
            answer_start = incre_len + sample['answer_spans'][0][0]
            answer_end = incre_len + sample['answer_spans'][0][1]
            sample['answer_spans'][0][0] = answer_start
            sample['answer_spans'][0][1] = answer_end
        doc["segmented_paragraphs"] = [total_segmented_content]
        doc["segmented_paragraphs_scores"] = [1.0]
        doc['paragraphs_length'] = [total_len]
        doc['paragraphs'] = [''.join(total_segmented_content)]
        doc['most_related_para'] = 0
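

# Worked example of the answer-span shift above (hypothetical lengths): with a
# 5-token title, one 20-token selected paragraph sorted before the gold
# paragraph, and one splitter token before each paragraph, incre_len
# = 5 + (1 + 20) + 1 = 27, so an in-paragraph span [3, 7] becomes [30, 34]
# in the concatenated passage.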


if __name__ == "__main__":
    # mode = "train" / "dev" / "test"
    mode = sys.argv[1]
    for line in sys.stdin:
        line = line.strip()
        if line == "":
            continue
        try:
            sample = json.loads(line)
        except ValueError:
            sys.stderr.write(
                "Invalid input json format - '{}' will be ignored\n".format(
                    line))
            continue
        compute_paragraph_score(sample)
        paragraph_selection(sample, mode)
        print(json.dumps(sample, ensure_ascii=False))
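
# Example invocation (a sketch; the script name and data paths are
# assumptions -- the script itself only reads json lines from stdin and
# writes the processed lines to stdout):
#   cat trainset/search.train.json \
#       | python paragraph_extraction.py train > search.train.para.json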