# coding:utf-8
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on DRCD"""

import json
import os

from paddlehub.reader import tokenization
from paddlehub.common.dir import DATA_HOME
from paddlehub.common.logger import logger
from paddlehub.dataset.base_nlp_dataset import BaseNLPDataset

_DATA_URL = "https://bj.bcebos.com/paddlehub-dataset/drcd.tar.gz"
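# The SentencePiece underline character; inserted below as a soft token
# boundary around CJK characters and punctuation.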
SPIECE_UNDERLINE = '▁'


class DRCDExample(object):
    """A single training/test example for simple sequence classification.

     For examples without an answer, the start and end position are -1.
  """

    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 start_position=None,
                 end_position=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
        s += ", question_text: %s" % (tokenization.printable_text(
            self.question_text))
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        if self.start_position is not None:
            s += ", orig_answer_text: %s" % (self.orig_answer_text)
            s += ", start_position: %d" % (self.start_position)
            s += ", end_position: %d" % (self.end_position)
        return s


class DRCD(BaseNLPDataset):
    """A single set of features of data."""

    def __init__(self):
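        # Download and extract the dataset into DATA_HOME on first use, then
        # register the train/dev/test json files with the base dataset class.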
        dataset_dir = os.path.join(DATA_HOME, "drcd")
        base_path = self._download_dataset(dataset_dir, url=_DATA_URL)
        super(DRCD, self).__init__(
            base_path=base_path,
            train_file="DRCD_training.json",
            dev_file="DRCD_dev.json",
            test_file="DRCD_test.json",
            label_file=None,
            label_list=None,
        )

    def _read_file(self, input_file, phase=None):
        """Read a DRCD json file into a list of CRCDExample."""

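        # True if the code point lies in one of the CJK ideograph Unicode
        # blocks (CJK Unified Ideographs, its extensions A-E, and the
        # compatibility ideograph blocks).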
        def _is_chinese_char(cp):
            if ((cp >= 0x4E00 and cp <= 0x9FFF)
                    or (cp >= 0x3400 and cp <= 0x4DBF)
                    or (cp >= 0x20000 and cp <= 0x2A6DF)
                    or (cp >= 0x2A700 and cp <= 0x2B73F)
                    or (cp >= 0x2B740 and cp <= 0x2B81F)
                    or (cp >= 0x2B820 and cp <= 0x2CEAF)
                    or (cp >= 0xF900 and cp <= 0xFAFF)
                    or (cp >= 0x2F800 and cp <= 0x2FA1F)):
                return True
            return False

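        # True for common (mostly full-width) punctuation marks, so each mark
        # becomes its own token just like a CJK character.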
        def _is_punctuation(c):
            if c in [
                    '。', ',', '!', '?', ';', '、', ':', '(', ')', '-', '~', '「',
                    '《', '》', ',', '」', '"', '“', '”', '$', '『', '』', '—', ';',
                    '。', '(', ')', '-', '~', '。', '‘', '’', '─', ':'
            ]:
                return True
            return False

        def _tokenize_chinese_chars(text):
            """Because Chinese (and Japanese Kanji and Korean Hanja) does not have whitespace
            characters, we add spaces around every character in the CJK Unicode range before
            applying WordPiece. This means that Chinese is effectively character-tokenized.
            Note that the CJK Unicode block only includes Chinese-origin characters and
            does not include Hangul Korean or Katakana/Hiragana Japanese, which are tokenized
            with whitespace+WordPiece like all other languages."""
            output = []
            for char in text:
                cp = ord(char)
                if _is_chinese_char(cp) or _is_punctuation(char):
                    if len(output) > 0 and output[-1] != SPIECE_UNDERLINE:
                        output.append(SPIECE_UNDERLINE)
                    output.append(char)
                    output.append(SPIECE_UNDERLINE)
                else:
                    output.append(char)
            return "".join(output)

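        # Treat SPIECE_UNDERLINE, the narrow no-break space (U+202F) and the
        # ideographic space (U+3000) as whitespace in addition to the ASCII ones.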
        def is_whitespace(c):
            if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(
                    c) == 0x202F or ord(c) == 0x3000 or c == SPIECE_UNDERLINE:
                return True
            return False

        examples = []
        with open(input_file, "r", encoding="utf8") as reader:
            input_data = json.load(reader)["data"]
        for entry in input_data:
            for paragraph in entry["paragraphs"]:
                paragraph_text = paragraph["context"]
                context = _tokenize_chinese_chars(paragraph_text)

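                # Split the marked-up context into tokens on whitespace /
                # SPIECE_UNDERLINE, and record for each remaining character the
                # index of the token it belongs to; these offsets line up with
                # positions in the original paragraph_text.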
                doc_tokens = []
                char_to_word_offset = []
                prev_is_whitespace = True
                for c in context:
                    if is_whitespace(c):
                        prev_is_whitespace = True
                    else:
                        if prev_is_whitespace:
                            doc_tokens.append(c)
                        else:
                            doc_tokens[-1] += c
                        prev_is_whitespace = False
                    if c != SPIECE_UNDERLINE:
                        char_to_word_offset.append(len(doc_tokens) - 1)

                for qa in paragraph["qas"]:
                    qas_id = qa["id"]
                    question_text = qa["question"]

                    # Only select the first answer
                    answer = qa["answers"][0]
                    orig_answer_text = answer["text"]
                    answer_offset = answer["answer_start"]
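                    # Skip leading whitespace and punctuation so the answer
                    # span starts on a real content character.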
                    while paragraph_text[answer_offset] in [
                            " ", "\t", "\r", "\n", "。", ",", ":", ":", ".", ","
                    ]:
                        answer_offset += 1
                    start_position = char_to_word_offset[answer_offset]
                    answer_length = len(orig_answer_text)
                    end_position = char_to_word_offset[answer_offset +
                                                       answer_length - 1]
                    # Only add answers where the text can be exactly recovered from the
                    # document. If this CAN'T happen it's likely due to weird Unicode
                    # stuff so we will just skip the example.
                    #
                    # Note that this means for training mode, every example is NOT
                    # guaranteed to be preserved.
                    actual_text = "".join(
                        doc_tokens[start_position:(end_position + 1)])
                    cleaned_answer_text = "".join(
                        tokenization.whitespace_tokenize(orig_answer_text))
                    if actual_text.find(cleaned_answer_text) == -1:
                        logger.warning(
                            "Could not find answer: '%s' vs. '%s' in %s" %
                            (actual_text, cleaned_answer_text, qa))
                        continue
                    example = DRCDExample(
                        qas_id=qas_id,
                        question_text=question_text,
                        doc_tokens=doc_tokens,
                        orig_answer_text=orig_answer_text,
                        start_position=start_position,
                        end_position=end_position)
                    examples.append(example)
        return examples


if __name__ == "__main__":
    ds = DRCD()
    print("train")
    examples = ds.get_train_examples()
    for index, e in enumerate(examples):
        if index < 10:
            print(e)
    print("dev")
    examples = ds.get_dev_examples()
    for index, e in enumerate(examples):
        if index < 10:
            print(e)
    print("test")
    examples = ds.get_test_examples()
    for index, e in enumerate(examples):
        if index < 10:
            print(e)