提交 f18bd94b 编写于 作者: J Jacob Devlin

Adding support for multilingual models

上级 a4dc5daf
# BERT
**\*\*\*\*\* New November 3rd, 2018: Multilingual and Chinese models avalable
\*\*\*\*\***
We have made two new BERT models available:
* **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**:
102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**:
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M
parameters
We use character-based tokenization for Chinese, and WordPiece tokenization for
all other languages. Both models should work out-of-the-box without any code
changes. We did update the implementation of `BasicTokenizer` in
`tokenization.py` to support Chinese character tokenization, so please update if
you forked it. However, we did not change the tokenization API.
For more, see the
[Multilingual README](https://github.com/google-research/bert/blob/master/multilingual.md).
**\*\*\*\*\* End new information \*\*\*\*\***
## Introduction
**BERT**, or **B**idirectional **E**ncoder **R**epresentations from
......@@ -154,6 +176,9 @@ Part-of-Speech tagging).
These models are all released under the same license as the source code (Apache
2.0).
For information about the Multilingual and Chinese model, see the
[Multilingual README](https://github.com/google-research/bert/blob/master/multilingual.md).
The links to the models are here (right-click, 'Save link as...' on the name):
* **[`BERT-Base, Uncased`](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)**:
......@@ -164,6 +189,11 @@ The links to the models are here (right-click, 'Save link as...' on the name):
12-layer, 768-hidden, 12-heads , 110M parameters
* **`BERT-Large, Cased`**: 24-layer, 1024-hidden, 16-heads, 340M parameters
(Not available yet. Needs to be re-generated).
* **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**:
102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**:
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M
parameters
Each .zip file contains three items:
......
## Models
There are two multilingual models currently available. We do not plan to release
more single-language models, but we may release `BERT-Large` versions of these
two in the future:
* **[`BERT-Base, Multilingual`](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip)**:
102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
* **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**:
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M
parameters
See the [list of languages](#list-of-languages) that the Multilingual model
supports. The Multilingual model does include Chinese (and English), but if your
fine-tuning data is Chinese-only, then the Chinese model will likely produce
better results.
## Results
To evaluate these systems, we use the
[XNLI dataset](https://github.com/facebookresearch/XNLI) dataset, which is a
version of [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) where the
dev and test sets have been translated (by humans) into 15 languages. Note that
the training set was *machine* translated (we used the translations provided by
XNLI, not Google NMT). For clarity, we only report on 6 languages below:
| System | English | Chinese | Spanish | German | Arabic | Urdu |
| ---------- | -------- | -------- | -------- | -------- | -------- | -------- |
| XNLI | 73.7 | 67.0 | 68.8 | 66.5 | 65.8 | 56.6 |
: Baseline - : : : : : : :
: Translate : : : : : : :
: Train : : : : : : :
| XNLI | 73.7 | 68.3 | 70.7 | 68.7 | 66.8 | 59.3 |
: Baseline - : : : : : : :
: Translate : : : : : : :
: Test : : : : : : :
| BERT - | **81.4** | **74.2** | **77.3** | **75.2** | **70.5** | 61.7 |
: Translate : : : : : : :
: Train : : : : : : :
| BERT - | 81.4 | 70.1 | 74.9 | 74.4 | 70.4 | **62.1** |
: Translate : : : : : : :
: Test : : : : : : :
| BERT - | 81.4 | 63.8 | 74.3 | 70.5 | 62.1 | 58.3 |
: Zero Shot : : : : : : :
The first two rows are baselines from the XNLI paper and the last three rows are
our results with BERT.
**Translate Train** means that the MultiNLI training set was machine translated
from English into the foreign language. So training and evaluation were both
done in the foreign language. Unfortunately, training was done on
machine-translated data, so it is impossible to quantify how much of the lower
accuracy (compared to English) is due to the quality of the machine translation
vs. the quality of the pre-trained model.
**Translate Test** means that the XNLI test set was machine translated from the
foreign language into English. So training and evaluation were both done on
English. However, test evaluation was done on machine-translated English, so the
accuracy depends on the quality of the machine translation system.
**Zero Shot** means that the system was trained on English, and then evaluated
on the foreign language. In this case, machine translation was not involved at
all in either the pre-training or fine-tuning.
Note that the English result is worse than the 84.2 MultiNLI baseline because
this training used Multilingual BERT rather than English-only BERT. This implies
that for high-resource languages, the Multilingual model is somewhat worse than
a single-language model. However, it is not feasible for us to train and
maintain dozens of single-language model. Therefore, if your goal is to maximize
performance with a language other than English or Chinese, you might find it
beneficial to run pre-training for additional steps starting from our
Multilingual model on data from your language of interest.
Here is a comparison of training Chinese models with the Multilingual
`BERT-Base` and Chinese-only `BERT-Base`:
System | Chinese
----------------------- | -------
XNLI Baseline | 67.0
BERT Multilingual Model | 74.2
BERT Chinese-only Model | 77.2
Similar to English, the single-language model does 3% better than the
Multilingual model.
## Fine-tuning Example
The multilingual model does **not** require any special consideration or API
changes. We did update the implementation of `BasicTokenizer` in
`tokenization.py` to support Chinese character tokenization, so please update if
you forked it. However, we did not change the tokenization API.
To test the new models, we did modify `run_classifier.py` to add support for the
[XNLI dataset](https://github.com/facebookresearch/XNLI). This is a 15-language
version of MultiNLI where the dev/test sets have been human-translated, and the
training set has been machine-translated.
To run the fine-tuning code, please download the
[XNLI dev/test set](https://s3.amazonaws.com/xnli/XNLI-1.0.zip) and the
[XNLI machine-translated training set](https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip)
and then unpack both .zip files into some directory `$XNLI_DIR`.
To run fine-tuning on XNLI. The language is hard-coded into `run_classifier.py`
(Chinese by default), so please modify `XnliProcessor` if you want to run on
another language.
This is a large dataset, so this will training will take a few hours on a GPU
(or about 30 minutes on a Cloud TPU). To run an experiment quickly for
debugging, just set `num_train_epochs` to a small value like `0.1`.
```shell
export BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12 # or multilingual_L-12_H-768_A-12
export XNLI_DIR=/path/to/xnli
python run_classifier.py \
--task_name=XNLI \
--do_train=true \
--do_eval=true \
--data_dir=$XNLI_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=5e-5 \
--num_train_epochs=2.0 \
--output_dir=/tmp/xnli_output/
```
With the Chinese-only model, the results should look something like this:
```
***** Eval results *****
eval_accuracy = 0.774116
eval_loss = 0.83554
global_step = 24543
loss = 0.74603
```
## Details
### Data Source and Sampling
The languages chosen were the
[top 100 languages with the largest Wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias).
The entire Wikipedia dump for each language (excluding user and talk pages) was
taken as the training data for each language
However, the size of the Wikipedia for a given language varies greatly, and
therefore low-resource languages may be "under-represented" in terms of the
neural network model (under the assumption that languages are "competing" for
limited model capacity to some extent).
However, the size of a Wikipedia also correlates with the number of speakers of
a language, and we also don't want to overfit the model by performing thousands
of epochs over a tiny Wikipedia for a particular language.
To balance these two factors, we performed exponentially smoothed weighting of
the data during pre-training data creation (and WordPiece vocab creation). In
other words, let's say that the probability of a language is *P(L)*, e.g.,
*P(English) = 0.21* means that after concatenating all of the Wikipedias
together, 21% of our data is English. We exponentiate each probability by some
factor *S* and then re-normalize, and sample from that distribution. In our case
we use *S=0.7*. So, high-resource languages like English will be under-sampled,
and low-resource languages like Icelandic will be over-sampled. E.g., in the
original distribution English would be sampled 1000x more than Icelandic, but
after smoothing it's only sampled 100x more.
### Tokenization
For tokenization, we use a 110k shared WordPiece vocabulary. The word counts are
weighted the same way as the data, so low-resource languages are upweighted by
some factor. We intentionally do *not* use any marker to denote the input
language (so that zero-shot training can work).
Because Chinese does not have whitespace characters, we add spaces around every
character in the
[CJK Unicode range](https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_\(Unicode_block\))
before applying WordPiece. This means that Chinese is effectively
character-tokenized. Note that the CJK Unicode block only includes
Chinese-origin characters and does *not* include Hangul Korean or
Katakana/Hiragana Japanese, which are tokenized with whitespace+WordPiece like
all other languages.
For all other languages, we apply the
[same recipe as English](https://github.com/google-research/bert#tokenization):
(a) lower casing+accent removal, (b) punctuation splitting, (c) whitespace
tokenization. We understand that accent markers have substantial meaning in some
languages, but felt that the benefits of reducing the effective vocabulary make
up for this. Generally the strong contextual models of BERT should make up for
any ambiguity introduced by stripping accent markers.
### List of Languages
The multilingual model supports the following languages. These languages were
chosen because they are the top 100 languages with the largest Wikipedias:
* Afrikaans
* Albanian
* Arabic
* Aragonese
* Armenian
* Asturian
* Azerbaijani
* Bashkir
* Basque
* Bavarian
* Belarusian
* Bengali
* Bishnupriya Manipuri
* Bosnian
* Breton
* Bulgarian
* Burmese
* Catalan
* Cebuano
* Chechen
* Chinese (Simplified)
* Chinese (Traditional)
* Chuvash
* Croatian
* Czech
* Danish
* Dutch
* English
* Estonian
* Finnish
* French
* Galician
* Georgian
* German
* Greek
* Gujarati
* Haitian
* Hebrew
* Hindi
* Hungarian
* Icelandic
* Ido
* Indonesian
* Irish
* Italian
* Japanese
* Javanese
* Kannada
* Kazakh
* Kirghiz
* Korean
* Latin
* Latvian
* Lithuanian
* Lombard
* Low Saxon
* Luxembourgish
* Macedonian
* Malagasy
* Malay
* Malayalam
* Marathi
* Minangkabau
* Nepali
* Newar
* Norwegian (Bokmal)
* Norwegian (Nynorsk)
* Occitan
* Persian (Farsi)
* Piedmontese
* Polish
* Portuguese
* Punjabi
* Romanian
* Russian
* Scots
* Serbian
* Serbo-Croatian
* Sicilian
* Slovak
* Slovenian
* South Azerbaijani
* Spanish
* Sundanese
* Swahili
* Swedish
* Tagalog
* Tajik
* Tamil
* Tatar
* Telugu
* Turkish
* Ukrainian
* Urdu
* Uzbek
* Vietnamese
* Volapük
* Waray-Waray
* Welsh
* West
* Western Punjabi
* Yoruba
The only language which we had to unfortunately exclude was Thai, since it is
the only language (other than Chinese) that does not use whitespace to delimit
words, and it has too many characters-per-word to use character-based
tokenization. Our WordPiece algorithm is quadratic with respect to the size of
the input token so very long character strings do not work with it.
......@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import os
import modeling
......@@ -174,6 +175,54 @@ class DataProcessor(object):
return lines
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
def __init__(self):
self.language = "zh"
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
text_a = tokenization.convert_to_unicode(line[0])
text_b = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[2])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % (i)
language = tokenization.convert_to_unicode(line[0])
if language != tokenization.convert_to_unicode(self.language):
continue
text_a = tokenization.convert_to_unicode(line[6])
text_b = tokenization.convert_to_unicode(line[7])
label = tokenization.convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
class MnliProcessor(DataProcessor):
"""Processor for the MultiNLI data set (GLUE version)."""
......@@ -269,16 +318,19 @@ class ColaProcessor(DataProcessor):
def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer):
tokenizer, output_file):
"""Loads a data file into a list of `InputBatch`s."""
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
features = []
writer = tf.python_io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
tokens_a = tokenizer.tokenize(example.text_a)
if ex_index % 10000 == 0:
tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
tokens_b = None
if example.text_b:
......@@ -357,13 +409,19 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
tf.logging.info("label: %s (id = %d)" % (example.label, label_id))
features.append(
InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id))
return features
def create_int_feature(values):
feature = tf.train.Feature(
int64_list=tf.train.Int64List(value=list(values)))
return feature
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids)
features["label_ids"] = create_int_feature([label_id])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
......@@ -511,53 +569,47 @@ def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
return model_fn
def input_fn_builder(features, seq_length, is_training, drop_remainder):
def input_fn_builder(input_file, seq_length, is_training, drop_remainder):
"""Creates an `input_fn` closure to be passed to TPUEstimator."""
all_input_ids = []
all_input_mask = []
all_segment_ids = []
all_label_ids = []
name_to_features = {
"input_ids": tf.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.FixedLenFeature([seq_length], tf.int64),
"segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
"label_ids": tf.FixedLenFeature([], tf.int64),
}
for feature in features:
all_input_ids.append(feature.input_ids)
all_input_mask.append(feature.input_mask)
all_segment_ids.append(feature.segment_ids)
all_label_ids.append(feature.label_id)
def _decode_record(record, name_to_features):
"""Decodes a record to a TensorFlow example."""
example = tf.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.to_int32(t)
example[name] = t
return example
def input_fn(params):
"""The actual input function."""
batch_size = params["batch_size"]
num_examples = len(features)
# This is for demo purposes and does NOT scale to large data sets. We do
# not use Dataset.from_generator() because that uses tf.py_func which is
# not TPU compatible. The right way to load data is with TFRecordReader.
d = tf.data.Dataset.from_tensor_slices({
"input_ids":
tf.constant(
all_input_ids, shape=[num_examples, seq_length],
dtype=tf.int32),
"input_mask":
tf.constant(
all_input_mask,
shape=[num_examples, seq_length],
dtype=tf.int32),
"segment_ids":
tf.constant(
all_segment_ids,
shape=[num_examples, seq_length],
dtype=tf.int32),
"label_ids":
tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32),
})
# For training, we want a lot of parallel reading and shuffling.
# For eval, we want no shuffling and parallel reading doesn't matter.
d = tf.data.TFRecordDataset(input_file)
if is_training:
d = d.repeat()
d = d.shuffle(buffer_size=100)
d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
d = d.apply(
tf.contrib.data.map_and_batch(
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size,
drop_remainder=drop_remainder))
return d
return input_fn
......@@ -570,6 +622,7 @@ def main(_):
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
}
if not FLAGS.do_train and not FLAGS.do_eval:
......@@ -642,14 +695,15 @@ def main(_):
eval_batch_size=FLAGS.eval_batch_size)
if FLAGS.do_train:
train_features = convert_examples_to_features(
train_examples, label_list, FLAGS.max_seq_length, tokenizer)
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
convert_examples_to_features(train_examples, label_list,
FLAGS.max_seq_length, tokenizer, train_file)
tf.logging.info("***** Running training *****")
tf.logging.info(" Num examples = %d", len(train_examples))
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
tf.logging.info(" Num steps = %d", num_train_steps)
train_input_fn = input_fn_builder(
features=train_features,
input_file=train_file,
seq_length=FLAGS.max_seq_length,
is_training=True,
drop_remainder=True)
......@@ -657,8 +711,9 @@ def main(_):
if FLAGS.do_eval:
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
eval_features = convert_examples_to_features(
eval_examples, label_list, FLAGS.max_seq_length, tokenizer)
eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
convert_examples_to_features(eval_examples, label_list,
FLAGS.max_seq_length, tokenizer, eval_file)
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Num examples = %d", len(eval_examples))
......@@ -675,7 +730,7 @@ def main(_):
eval_drop_remainder = True if FLAGS.use_tpu else False
eval_input_fn = input_fn_builder(
features=eval_features,
input_file=eval_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=eval_drop_remainder)
......
......@@ -134,6 +134,15 @@ class BasicTokenizer(object):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
......@@ -176,6 +185,41 @@ class BasicTokenizer(object):
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
......
......@@ -44,6 +44,13 @@ class TokenizationTest(tf.test.TestCase):
self.assertAllEqual(
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
def test_chinese(self):
tokenizer = tokenization.BasicTokenizer()
self.assertAllEqual(
tokenizer.tokenize(u"ah\u535A\u63A8zz"),
[u"ah", u"\u535A", u"\u63A8", u"zz"])
def test_basic_tokenizer_lower(self):
tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册