提交 8d32dc74 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2932 delete package to_mindrecord

Merge pull request !2932 from shenwei41/sw_r0.5
# Guideline to Convert Training Data CLUERNER2020 to MindRecord For Bert Fine Tuning
<!-- TOC -->
- [What does the example do](#what-does-the-example-do)
- [How to use the example to process CLUERNER2020](#how-to-use-the-example-to-process-cluerner2020)
- [Download CLUERNER2020 and unzip](#download-cluerner2020-and-unzip)
- [Generate MindRecord](#generate-mindrecord)
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
<!-- /TOC -->
## What does the example do
This example is based on [CLUERNER2020](https://www.cluebenchmarks.com/introduce.html) training data, generating MindRecord file, and finally used for Bert Fine Tuning progress.
1. run.sh: generate MindRecord entry script
2. run_read.py: create MindDataset by MindRecord entry script.
- create_dataset.py: use MindDataset to read MindRecord to generate dataset.
## How to use the example to process CLUERNER2020
Download CLUERNER2020, convert it to MindRecord, use MindDataset to read MindRecord.
### Download CLUERNER2020 and unzip
1. Download the training data zip.
> [CLUERNER2020 dataset download address](https://www.cluebenchmarks.com/introduce.html) **-> 任务介绍 -> CLUENER 细粒度命名实体识别 -> cluener下载链接**
2. Unzip the training data to dir example/nlp_to_mindrecord/CLUERNER2020/cluener_public.
```
unzip -d {your-mindspore}/example/nlp_to_mindrecord/CLUERNER2020/data/cluener_public cluener_public.zip
```
### Generate MindRecord
1. Run the run.sh script.
```bash
bash run.sh
```
2. Output like this:
```
...
[INFO] ME(17603:139620983514944,MainProcess):2020-04-28-16:56:12.498.235 [mindspore/mindrecord/filewriter.py:313] The list of mindrecord files created are: ['data/train.mindrecord'], and the list of index files are: ['data/train.mindrecord.db']
...
[INFO] ME(17603,python):2020-04-28-16:56:13.400.175 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
[INFO] ME(17603,python):2020-04-28-16:56:13.400.863 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
[INFO] ME(17603,python):2020-04-28-16:56:13.401.534 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
[INFO] ME(17603,python):2020-04-28-16:56:13.402.179 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
[INFO] ME(17603,python):2020-04-28-16:56:13.402.702 [mindspore/ccsrc/mindrecord/io/shard_writer.cc:667] WriteRawData] Write 1 records successfully.
...
[INFO] ME(17603:139620983514944,MainProcess):2020-04-28-16:56:13.431.208 [mindspore/mindrecord/filewriter.py:313] The list of mindrecord files created are: ['data/dev.mindrecord'], and the list of index files are: ['data/dev.mindrecord.db']
```
3. Generate files like this:
```bash
$ ls output/
dev.mindrecord dev.mindrecord.db README.md train.mindrecord train.mindrecord.db
```
### Create MindDataset By MindRecord
1. Run the run_read.sh script.
```bash
bash run_read.sh
```
2. Output like this:
```
...
example 1340: input_ids: [ 101 3173 1290 4852 7676 3949 122 3299 123 126 3189 4510 8020 6381 5442 7357 2590 3636 8021 7676 3949 4294 1166 6121 3124 1277 6121 3124 7270 2135 3295 5789 3326 123 126 3189 1355 6134 1093 1325 3173 2399 6590 6791 8024 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1340: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1340: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1340: label_ids: [ 0 18 19 20 2 4 0 0 0 0 0 0 0 34 36 26 27 28 0 34 35 35 35 35 35 35 35 35 35 36 26 27 28 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1341: input_ids: [ 101 1728 711 4293 3868 1168 2190 2150 3791 934 3633 3428 4638 6237 7025 8024 3297 1400 5310 3362 6206 5023 5401 1744 3297 7770 3791 7368 976 1139 1104 2137 511 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1341: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1341: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 1341: label_ids: [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 18 19 19 19 19 20 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
...
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create MindDataset by MindRecord"""
import mindspore.dataset as ds
def create_dataset(data_file):
"""create MindDataset"""
num_readers = 4
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
index = 0
for item in data_set.create_dict_iterator():
# print("example {}: {}".format(index, item))
print("example {}: input_ids: {}".format(index, item['input_ids']))
print("example {}: input_mask: {}".format(index, item['input_mask']))
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
print("example {}: label_ids: {}".format(index, item['label_ids']))
index += 1
if index % 1000 == 0:
print("read rows: {}".format(index))
print("total rows: {}".format(index))
if __name__ == '__main__':
create_dataset('output/train.mindrecord')
create_dataset('output/dev.mindrecord')
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
rm -f output/train.mindrecord*
rm -f output/dev.mindrecord*
if [ ! -d "../../../third_party/to_mindrecord/CLUERNER2020" ]; then
echo "The patch base dir ../../../third_party/to_mindrecord/CLUERNER2020 is not exist."
exit 1
fi
if [ ! -f "../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch" ]; then
echo "The patch file ../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch is not exist."
exit 1
fi
# patch for data_processor_seq.py
patch -p0 -d ../../../third_party/to_mindrecord/CLUERNER2020/ -o data_processor_seq_patched.py < ../../../third_party/patch/to_mindrecord/CLUERNER2020/data_processor_seq.patch
if [ $? -ne 0 ]; then
echo "Patch ../../../third_party/to_mindrecord/CLUERNER2020/data_processor_seq.py failed"
exit 1
fi
# use patched script
python ../../../third_party/to_mindrecord/CLUERNER2020/data_processor_seq_patched.py \
--vocab_file=../../../third_party/to_mindrecord/CLUERNER2020/vocab.txt \
--label2id_file=../../../third_party/to_mindrecord/CLUERNER2020/label2id.json
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
python create_dataset.py
# Guideline to Convert Training Data enwiki to MindRecord For Bert Pre Training
<!-- TOC -->
- [What does the example do](#what-does-the-example-do)
- [How to use the example to process enwiki](#how-to-use-the-example-to-process-enwiki)
- [Download enwiki training data](#download-enwiki-training-data)
- [Process the enwiki](#process-the-enwiki)
- [Generate MindRecord](#generate-mindrecord)
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
<!-- /TOC -->
## What does the example do
This example is based on [enwiki](https://dumps.wikimedia.org/enwiki) training data, generating MindRecord file, and finally used for Bert network training.
1. run.sh: generate MindRecord entry script.
2. run_read.py: create MindDataset by MindRecord entry script.
- create_dataset.py: use MindDataset to read MindRecord to generate dataset.
## How to use the example to process enwiki
Download enwiki data, process it, convert it to MindRecord, use MindDataset to read MindRecord.
### Download enwiki training data
> [enwiki dataset download address](https://dumps.wikimedia.org/enwiki) **-> 20200501 -> enwiki-20200501-pages-articles-multistream.xml.bz2**
### Process the enwiki
1. Please follow the steps in [process enwiki](https://github.com/mlperf/training/tree/master/language_model/tensorflow/bert)
- All permissions of this step belong to the link address website.
### Generate MindRecord
1. Run the run.sh script.
```
bash run.sh input_dir output_dir vocab_file
```
- input_dir: the directory which contains files like 'part-00251-of-00500'.
- output_dir: which will store the output mindrecord files.
- vocab_file: the vocab file which you can download from other opensource project.
2. The output like this:
```
...
Begin preprocess Wed Jun 10 09:21:23 CST 2020
Begin preprocess input file: /mnt/data/results/part-00000-of-00500
Begin output file: part-00000-of-00500.mindrecord
Total task: 510, processing: 1
Begin preprocess input file: /mnt/data/results/part-00001-of-00500
Begin output file: part-00001-of-00500.mindrecord
Total task: 510, processing: 2
Begin preprocess input file: /mnt/data/results/part-00002-of-00500
Begin output file: part-00002-of-00500.mindrecord
Total task: 510, processing: 3
Begin preprocess input file: /mnt/data/results/part-00003-of-00500
Begin output file: part-00003-of-00500.mindrecord
Total task: 510, processing: 4
Begin preprocess input file: /mnt/data/results/part-00004-of-00500
Begin output file: part-00004-of-00500.mindrecord
Total task: 510, processing: 4
...
```
3. Generate files like this:
```bash
$ ls {your_output_dir}/
part-00000-of-00500.mindrecord part-00000-of-00500.mindrecord.db part-00001-of-00500.mindrecord part-00001-of-00500.mindrecord.db part-00002-of-00500.mindrecord part-00002-of-00500.mindrecord.db ...
```
### Create MindDataset By MindRecord
1. Run the run_read.sh script.
```bash
bash run_read.sh input_dir
```
- input_dir: the directory which contains mindrecord files.
2. The output like this:
```
...
example 633: input_ids: [ 101 2043 19781 4305 2140 4520 2041 1010 103 2034 2455 2002
7879 2003 1996 2455 1997 103 26378 4160 1012 102 7291 2001
1996 103 1011 2343 1997 6327 1010 3423 1998 103 4262 2005
1996 2118 1997 2329 3996 103 102 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0]
example 633: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 633: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 633: masked_lm_positions: [ 8 17 20 25 33 41 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0]
example 633: masked_lm_ids: [ 1996 16137 1012 3580 2451 1012 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0]
example 633: masked_lm_weights: [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0.]
example 633: next_sentence_labels: [1]
...
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create MindDataset by MindRecord"""
import argparse
import mindspore.dataset as ds
def create_dataset(data_file):
"""create MindDataset"""
num_readers = 4
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
index = 0
for item in data_set.create_dict_iterator():
# print("example {}: {}".format(index, item))
print("example {}: input_ids: {}".format(index, item['input_ids']))
print("example {}: input_mask: {}".format(index, item['input_mask']))
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions']))
print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids']))
print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights']))
print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels']))
index += 1
if index % 1000 == 0:
print("read rows: {}".format(index))
print("total rows: {}".format(index))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindreord file')
args = parser.parse_args()
create_dataset(args.input_file)
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]; then
echo "Usage: $0 input_dir output_dir vocab_file"
exit 1
fi
if [ ! -d $1 ]; then
echo "The input dir: $1 is not exist."
exit 1
fi
if [ ! -d $2 ]; then
echo "The output dir: $2 is not exist."
exit 1
fi
rm -fr $2/*.mindrecord*
if [ ! -f $3 ]; then
echo "The vocab file: $3 is not exist."
exit 1
fi
data_dir=$1
output_dir=$2
vocab_file=$3
file_list=()
output_filename=()
file_index=0
function getdir() {
elements=`ls $1`
for element in ${elements[*]};
do
dir_or_file=$1"/"$element
if [ -d $dir_or_file ];
then
getdir $dir_or_file
else
file_list[$file_index]=$dir_or_file
echo "${dir_or_file}" | tr '/' '\n' > dir_file_list.txt # dir dir file to mapfile
mapfile parent_dir < dir_file_list.txt
rm dir_file_list.txt >/dev/null 2>&1
tmp_output_filename=${parent_dir[${#parent_dir[@]}-1]}".mindrecord"
output_filename[$file_index]=`echo ${tmp_output_filename} | sed 's/ //g'`
file_index=`expr $file_index + 1`
fi
done
}
getdir "${data_dir}"
# echo "The input files: "${file_list[@]}
# echo "The output files: "${output_filename[@]}
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
exit 1
fi
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
exit 1
fi
# patch for create_pretraining_data.py
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
if [ $? -ne 0 ]; then
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
exit 1
fi
# get the cpu core count
num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
avaiable_core_size=`expr $num_cpu_core / 3 \* 2`
echo "Begin preprocess `date`"
# using patched script to generate mindrecord
file_list_len=`expr ${#file_list[*]} - 1`
for index in $(seq 0 $file_list_len); do
echo "Begin preprocess input file: ${file_list[$index]}"
echo "Begin output file: ${output_filename[$index]}"
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
--input_file=${file_list[$index]} \
--output_file=${output_dir}/${output_filename[$index]} \
--partition_number=1 \
--vocab_file=${vocab_file} \
--do_lower_case=True \
--max_seq_length=512 \
--max_predictions_per_seq=76 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 &
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
echo "Total task: ${#file_list[*]}, processing: ${process_count}"
if [ $process_count -ge $avaiable_core_size ]; then
while [ 1 ]; do
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
if [ $process_count -gt $process_num ]; then
process_count=$process_num
break;
fi
sleep 2
done
fi
done
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
while [ 1 ]; do
if [ $process_num -eq 0 ]; then
break;
fi
echo "There are still ${process_num} preprocess running ..."
sleep 2
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
done
echo "Preprocess all the data success."
echo "End preprocess `date`"
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]; then
echo "Usage: $0 input_dir"
exit 1
fi
if [ ! -d $1 ]; then
echo "The input dir: $1 is not exist."
exit 1
fi
file_list=()
file_index=0
# get all the mindrecord file from output dir
function getdir() {
elements=`ls $1/part-*.mindrecord`
for element in ${elements[*]};
do
file_list[$file_index]=$element
file_index=`expr $file_index + 1`
done
}
getdir $1
echo "Get all the mindrecord files: "${file_list[*]}
# create dataset for train
python create_dataset.py --input_file ${file_list[*]}
# Guideline to Convert Training Data zhwiki to MindRecord For Bert Pre Training
<!-- TOC -->
- [What does the example do](#what-does-the-example-do)
- [Run simple test](#run-simple-test)
- [How to use the example to process zhwiki](#how-to-use-the-example-to-process-zhwiki)
- [Download zhwiki training data](#download-zhwiki-training-data)
- [Extract the zhwiki](#extract-the-zhwiki)
- [Generate MindRecord](#generate-mindrecord)
- [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord)
<!-- /TOC -->
## What does the example do
This example is based on [zhwiki](https://dumps.wikimedia.org/zhwiki) training data, generating MindRecord file, and finally used for Bert network training.
1. run.sh: generate MindRecord entry script.
2. run_read.py: create MindDataset by MindRecord entry script.
- create_dataset.py: use MindDataset to read MindRecord to generate dataset.
## Run simple test
Follow the step:
```bash
bash run_simple.sh # generate output/simple.mindrecord* by ../../../third_party/to_mindrecord/zhwiki/sample_text.txt
bash run_read_simple.sh # use MindDataset to read output/simple.mindrecord*
```
## How to use the example to process zhwiki
Download zhwiki data, extract it, convert it to MindRecord, use MindDataset to read MindRecord.
### Download zhwiki training data
> [zhwiki dataset download address](https://dumps.wikimedia.org/zhwiki) **-> 20200401 -> zhwiki-20200401-pages-articles-multistream.xml.bz2**
- put the zhwiki-20200401-pages-articles-multistream.xml.bz2 in {your-mindspore}/example/nlp_to_mindrecord/zhwiki/data directory.
### Extract the zhwiki
1. Download [wikiextractor](https://github.com/attardi/wikiextractor) script to {your-mindspore}/example/nlp_to_mindrecord/zhwiki/data directory.
```
$ ls data/
README.md wikiextractor zhwiki-20200401-pages-articles-multistream.xml.bz2
```
2. Extract the zhwiki.
```python
python data/wikiextractor/WikiExtractor.py data/zhwiki-20200401-pages-articles-multistream.xml.bz2 --processes 4 --templates data/template --bytes 8M --min_text_length 0 --filter_disambig_pages --output data/extract
```
3. Generate like this:
```
$ ls data/extract
AA AB
```
### Generate MindRecord
1. Run the run.sh script.
```
bash run.sh
```
> Caution: This process maybe slow, please wait patiently. If you do not have a machine with enough memory and cpu, it is recommended that you modify the script to generate mindrecord in step by step.
2. The output like this:
```
patching file create_pretraining_data_patched.py (read from create_pretraining_data.py)
Begin preprocess input file: ./data/extract/AA/wiki_00
Begin output file: AAwiki_00.mindrecord
Total task: 5, processing: 1
Begin preprocess input file: ./data/extract/AA/wiki_01
Begin output file: AAwiki_01.mindrecord
Total task: 5, processing: 2
Begin preprocess input file: ./data/extract/AA/wiki_02
Begin output file: AAwiki_02.mindrecord
Total task: 5, processing: 3
Begin preprocess input file: ./data/extract/AB/wiki_02
Begin output file: ABwiki_02.mindrecord
Total task: 5, processing: 4
...
```
3. Generate files like this:
```bash
$ ls output/
AAwiki_00.mindrecord AAwiki_00.mindrecord.db AAwiki_01.mindrecord AAwiki_01.mindrecord.db AAwiki_02.mindrecord AAwiki_02.mindrecord.db ... ABwiki_00.mindrecord ABwiki_00.mindrecord.db ...
```
### Create MindDataset By MindRecord
1. Run the run_read.sh script.
```bash
bash run_read.sh
```
2. The output like this:
```
...
example 74: input_ids: [ 101 8168 118 12847 8783 9977 15908 117 8256 9245 11643 8168 8847 8588 11575 8154 8228 143 8384 8376 9197 10241 103 10564 11421 8199 12268 112 161 8228 11541 9586 8436 8174 8363 9864 9702 103 103 119 103 9947 10564 103 8436 8806 11479 103 8912 119 103 103 103 12209 8303 103 8757 8824 117 8256 103 8619 8168 11541 102 11684 8196 103 8228 8847 11523 117 9059 9064 12410 8358 8181 10764 117 11167 11706 9920 148 8332 11390 8936 8205 10951 11997 103 8154 117 103 8670 10467 112 161 10951 13139 12413 117 10288 143 10425 8205 152 10795 8472 8196 103 161 12126 9172 13129 12106 8217 8174 12244 8205 143 103 8461 8277 10628 160 8221 119 102]
example 74: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
example 74: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
example 74: masked_lm_positions: [ 6 22 37 38 40 43 47 50 51 52 55 60 67 76 89 92 98 109 120 0]
example 74: masked_lm_ids: [ 8118 8165 8329 8890 8554 8458 119 8850 8565 10392 8174 11467 10291 8181 8549 12718 13139 112 158 0]
example 74: masked_lm_weights: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0.]
example 74: next_sentence_labels: [0]
...
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""create MindDataset by MindRecord"""
import argparse
import mindspore.dataset as ds
def create_dataset(data_file):
"""create MindDataset"""
num_readers = 4
data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True)
index = 0
for item in data_set.create_dict_iterator():
# print("example {}: {}".format(index, item))
print("example {}: input_ids: {}".format(index, item['input_ids']))
print("example {}: input_mask: {}".format(index, item['input_mask']))
print("example {}: segment_ids: {}".format(index, item['segment_ids']))
print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions']))
print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids']))
print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights']))
print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels']))
index += 1
if index % 1000 == 0:
print("read rows: {}".format(index))
print("total rows: {}".format(index))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindreord file')
args = parser.parse_args()
create_dataset(args.input_file)
wikiextractor/
zhwiki-20200401-pages-articles-multistream.xml.bz2
extract/
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
rm -f output/*.mindrecord*
data_dir="./data/extract"
file_list=()
output_filename=()
file_index=0
function getdir() {
elements=`ls $1`
for element in ${elements[*]};
do
dir_or_file=$1"/"$element
if [ -d $dir_or_file ];
then
getdir $dir_or_file
else
file_list[$file_index]=$dir_or_file
echo "${dir_or_file}" | tr '/' '\n' > dir_file_list.txt # dir dir file to mapfile
mapfile parent_dir < dir_file_list.txt
rm dir_file_list.txt >/dev/null 2>&1
tmp_output_filename=${parent_dir[${#parent_dir[@]}-2]}${parent_dir[${#parent_dir[@]}-1]}".mindrecord"
output_filename[$file_index]=`echo ${tmp_output_filename} | sed 's/ //g'`
file_index=`expr $file_index + 1`
fi
done
}
getdir "${data_dir}"
# echo "The input files: "${file_list[@]}
# echo "The output files: "${output_filename[@]}
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
exit 1
fi
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
exit 1
fi
# patch for create_pretraining_data.py
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
if [ $? -ne 0 ]; then
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
exit 1
fi
# get the cpu core count
num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
avaiable_core_size=`expr $num_cpu_core / 3 \* 2`
echo "Begin preprocess `date`"
# using patched script to generate mindrecord
file_list_len=`expr ${#file_list[*]} - 1`
for index in $(seq 0 $file_list_len); do
echo "Begin preprocess input file: ${file_list[$index]}"
echo "Begin output file: ${output_filename[$index]}"
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
--input_file=${file_list[$index]} \
--output_file=output/${output_filename[$index]} \
--partition_number=1 \
--vocab_file=../../../third_party/to_mindrecord/zhwiki/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 & # user defined
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
echo "Total task: ${#file_list[*]}, processing: ${process_count}"
if [ $process_count -ge $avaiable_core_size ]; then
while [ 1 ]; do
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
if [ $process_count -gt $process_num ]; then
process_count=$process_num
break;
fi
sleep 2
done
fi
done
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
while [ 1 ]; do
if [ $process_num -eq 0 ]; then
break;
fi
echo "There are still ${process_num} preprocess running ..."
sleep 2
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
done
echo "Preprocess all the data success."
echo "End preprocess `date`"
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
file_list=()
file_index=0
# get all the mindrecord file from output dir
function getdir() {
elements=`ls $1/[A-Z]*.mindrecord`
for element in ${elements[*]};
do
file_list[$file_index]=$element
file_index=`expr $file_index + 1`
done
}
getdir "./output"
echo "Get all the mindrecord files: "${file_list[*]}
# create dataset for train
python create_dataset.py --input_file ${file_list[*]}
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# create dataset for train
python create_dataset.py --input_file=output/simple.mindrecord0
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
rm -f output/simple.mindrecord*
if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then
echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist."
exit 1
fi
if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then
echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist."
exit 1
fi
# patch for create_pretraining_data.py
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
if [ $? -ne 0 ]; then
echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
exit 1
fi
# using patched script to generate mindrecord
python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
--input_file=../../../third_party/to_mindrecord/zhwiki/sample_text.txt \
--output_file=output/simple.mindrecord \
--partition_number=4 \
--vocab_file=../../../third_party/to_mindrecord/zhwiki/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=10 # user defined
## the file is a patch which is about just change data_processor_seq.py the part of generated tfrecord to MindRecord in [CLUEbenchmark/CLUENER2020](https://github.com/CLUEbenchmark/CLUENER2020/tree/master/tf_version)
--- data_processor_seq.py 2020-05-28 10:07:13.365947168 +0800
+++ data_processor_seq.py 2020-05-28 10:14:33.298177130 +0800
@@ -4,11 +4,18 @@
@author: Cong Yu
@time: 2019-12-07 17:03
"""
+import sys
+sys.path.append("../../../third_party/to_mindrecord/CLUERNER2020")
+
+import argparse
import json
import tokenization
import collections
-import tensorflow as tf
+import numpy as np
+from mindspore.mindrecord import FileWriter
+
+# pylint: skip-file
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
@@ -80,11 +87,18 @@ def process_one_example(tokenizer, label
return feature
-def prepare_tf_record_data(tokenizer, max_seq_len, label2id, path, out_path):
+def prepare_mindrecord_data(tokenizer, max_seq_len, label2id, path, out_path):
"""
- 生成训练数据, tf.record, 单标签分类模型, 随机打乱数据
+ 生成训练数据, *.mindrecord, 单标签分类模型, 随机打乱数据
"""
- writer = tf.python_io.TFRecordWriter(out_path)
+ writer = FileWriter(out_path)
+
+ data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
+ "input_mask": {"type": "int64", "shape": [-1]},
+ "segment_ids": {"type": "int64", "shape": [-1]},
+ "label_ids": {"type": "int64", "shape": [-1]}}
+ writer.add_schema(data_schema, "CLUENER2020 schema")
+
example_count = 0
for line in open(path):
@@ -113,16 +127,12 @@ def prepare_tf_record_data(tokenizer, ma
feature = process_one_example(tokenizer, label2id, list(_["text"]), labels,
max_seq_len=max_seq_len)
- def create_int_feature(values):
- f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
- return f
-
features = collections.OrderedDict()
# 序列标注任务
- features["input_ids"] = create_int_feature(feature[0])
- features["input_mask"] = create_int_feature(feature[1])
- features["segment_ids"] = create_int_feature(feature[2])
- features["label_ids"] = create_int_feature(feature[3])
+ features["input_ids"] = np.asarray(feature[0])
+ features["input_mask"] = np.asarray(feature[1])
+ features["segment_ids"] = np.asarray(feature[2])
+ features["label_ids"] = np.asarray(feature[3])
if example_count < 5:
print("*** Example ***")
print(_["text"])
@@ -132,8 +142,7 @@ def prepare_tf_record_data(tokenizer, ma
print("segment_ids: %s" % " ".join([str(x) for x in feature[2]]))
print("label: %s " % " ".join([str(x) for x in feature[3]]))
- tf_example = tf.train.Example(features=tf.train.Features(feature=features))
- writer.write(tf_example.SerializeToString())
+ writer.write_raw_data([features])
example_count += 1
# if example_count == 20:
@@ -141,17 +150,22 @@ def prepare_tf_record_data(tokenizer, ma
if example_count % 3000 == 0:
print(example_count)
print("total example:", example_count)
- writer.close()
+ writer.commit()
if __name__ == "__main__":
- vocab_file = "./vocab.txt"
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vocab_file", type=str, required=True, help='The vocabulary file.')
+ parser.add_argument("--label2id_file", type=str, required=True, help='The label2id.json file.')
+ args = parser.parse_args()
+
+ vocab_file = args.vocab_file
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file)
- label2id = json.loads(open("label2id.json").read())
+ label2id = json.loads(open(args.label2id_file).read())
max_seq_len = 64
- prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_train.json",
- out_path="data/train.tf_record")
- prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_valid.json",
- out_path="data/dev.tf_record")
+ prepare_mindrecord_data(tokenizer, max_seq_len, label2id, path="data/cluener_public/train.json",
+ out_path="output/train.mindrecord")
+ prepare_mindrecord_data(tokenizer, max_seq_len, label2id, path="data/cluener_public/dev.json",
+ out_path="output/dev.mindrecord")
## the file is a patch which is about just change create_pretraining_data.py the part of generated tfrecord to MindRecord in [google-research/bert](https://github.com/google-research/bert)
--- create_pretraining_data.py 2020-05-27 17:02:14.285363720 +0800
+++ create_pretraining_data.py 2020-05-27 17:30:52.427767841 +0800
@@ -12,57 +12,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Create masked LM/next sentence masked_lm TF examples for BERT."""
+"""Create masked LM/next sentence masked_lm MindRecord files for BERT."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+sys.path.append("../../../third_party/to_mindrecord/zhwiki")
+
+import argparse
import collections
+import logging
import random
import tokenization
-import tensorflow as tf
-
-flags = tf.flags
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string("input_file", None,
- "Input raw text file (or comma-separated list of files).")
-
-flags.DEFINE_string(
- "output_file", None,
- "Output TF example file (or comma-separated list of files).")
-
-flags.DEFINE_string("vocab_file", None,
- "The vocabulary file that the BERT model was trained on.")
-
-flags.DEFINE_bool(
- "do_lower_case", True,
- "Whether to lower case the input text. Should be True for uncased "
- "models and False for cased models.")
-
-flags.DEFINE_bool(
- "do_whole_word_mask", False,
- "Whether to use whole word masking rather than per-WordPiece masking.")
-
-flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
-flags.DEFINE_integer("max_predictions_per_seq", 20,
- "Maximum number of masked LM predictions per sequence.")
+import numpy as np
+from mindspore.mindrecord import FileWriter
-flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
+# pylint: skip-file
-flags.DEFINE_integer(
- "dupe_factor", 10,
- "Number of times to duplicate the input data (with different masks).")
-
-flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
-
-flags.DEFINE_float(
- "short_seq_prob", 0.1,
- "Probability of creating sequences which are shorter than the "
- "maximum length.")
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+ datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
class TrainingInstance(object):
@@ -94,13 +65,19 @@ class TrainingInstance(object):
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
- max_predictions_per_seq, output_files):
- """Create TF example files from `TrainingInstance`s."""
- writers = []
- for output_file in output_files:
- writers.append(tf.python_io.TFRecordWriter(output_file))
-
- writer_index = 0
+ max_predictions_per_seq, output_file, partition_number):
+ """Create MindRecord files from `TrainingInstance`s."""
+ writer = FileWriter(output_file, int(partition_number))
+
+ data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
+ "input_mask": {"type": "int64", "shape": [-1]},
+ "segment_ids": {"type": "int64", "shape": [-1]},
+ "masked_lm_positions": {"type": "int64", "shape": [-1]},
+ "masked_lm_ids": {"type": "int64", "shape": [-1]},
+ "masked_lm_weights": {"type": "float32", "shape": [-1]},
+ "next_sentence_labels": {"type": "int64", "shape": [-1]},
+ }
+ writer.add_schema(data_schema, "zhwiki schema")
total_written = 0
for (inst_index, instance) in enumerate(instances):
@@ -130,55 +107,35 @@ def write_instance_to_example_files(inst
next_sentence_label = 1 if instance.is_random_next else 0
features = collections.OrderedDict()
- features["input_ids"] = create_int_feature(input_ids)
- features["input_mask"] = create_int_feature(input_mask)
- features["segment_ids"] = create_int_feature(segment_ids)
- features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
- features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
- features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
- features["next_sentence_labels"] = create_int_feature([next_sentence_label])
-
- tf_example = tf.train.Example(features=tf.train.Features(feature=features))
-
- writers[writer_index].write(tf_example.SerializeToString())
- writer_index = (writer_index + 1) % len(writers)
+ features["input_ids"] = np.asarray(input_ids, np.int64)
+ features["input_mask"] = np.asarray(input_mask, np.int64)
+ features["segment_ids"] = np.asarray(segment_ids, np.int64)
+ features["masked_lm_positions"] = np.asarray(masked_lm_positions, np.int64)
+ features["masked_lm_ids"] = np.asarray(masked_lm_ids, np.int64)
+ features["masked_lm_weights"] = np.asarray(masked_lm_weights, np.float32)
+ features["next_sentence_labels"] = np.asarray([next_sentence_label], np.int64)
total_written += 1
if inst_index < 20:
- tf.logging.info("*** Example ***")
- tf.logging.info("tokens: %s" % " ".join(
+ logging.info("*** Example ***")
+ logging.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in instance.tokens]))
for feature_name in features.keys():
feature = features[feature_name]
- values = []
- if feature.int64_list.value:
- values = feature.int64_list.value
- elif feature.float_list.value:
- values = feature.float_list.value
- tf.logging.info(
- "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
-
- for writer in writers:
- writer.close()
-
- tf.logging.info("Wrote %d total instances", total_written)
-
-
-def create_int_feature(values):
- feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
- return feature
+ logging.info(
+ "%s: %s" % (feature_name, " ".join([str(x) for x in feature])))
+ writer.write_raw_data([features])
+ writer.commit()
-def create_float_feature(values):
- feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
- return feature
+ logging.info("Wrote %d total instances", total_written)
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
- max_predictions_per_seq, rng):
+ max_predictions_per_seq, rng, do_whole_word_mask):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
@@ -189,7 +146,7 @@ def create_training_instances(input_file
# (2) Blank lines between documents. Document boundaries are needed so
# that the "next sentence prediction" task doesn't span between documents.
for input_file in input_files:
- with tf.gfile.GFile(input_file, "r") as reader:
+ with open(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
@@ -214,7 +171,7 @@ def create_training_instances(input_file
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
- masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng, do_whole_word_mask))
rng.shuffle(instances)
return instances
@@ -222,7 +179,7 @@ def create_training_instances(input_file
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
- masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng, do_whole_word_mask):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
@@ -320,7 +277,7 @@ def create_instances_from_document(
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
- tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
+ tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng, do_whole_word_mask)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
@@ -340,7 +297,7 @@ MaskedLmInstance = collections.namedtupl
def create_masked_lm_predictions(tokens, masked_lm_prob,
- max_predictions_per_seq, vocab_words, rng):
+ max_predictions_per_seq, vocab_words, rng, do_whole_word_mask):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
@@ -356,7 +313,7 @@ def create_masked_lm_predictions(tokens,
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
- if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
+ if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
@@ -433,37 +390,42 @@ def truncate_seq_pair(tokens_a, tokens_b
trunc_tokens.pop()
-def main(_):
- tf.logging.set_verbosity(tf.logging.INFO)
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_file", type=str, required=True, help='Input raw text file (or comma-separated list of files).')
+ parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
+ parser.add_argument("--partition_number", type=int, default=1, help='The MindRecord file will be split into the number of partition.')
+ parser.add_argument("--vocab_file", type=str, required=True, help='The vocabulary file than the BERT model was trained on.')
+ parser.add_argument("--do_lower_case", type=bool, default=False, help='Whether to lower case the input text. Should be True for uncased models and False for cased models.')
+ parser.add_argument("--do_whole_word_mask", type=bool, default=False, help='Whether to use whole word masking rather than per-WordPiece masking.')
+ parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
+ parser.add_argument("--max_predictions_per_seq", type=int, default=20, help='Maximum number of masked LM predictions per sequence.')
+ parser.add_argument("--random_seed", type=int, default=12345, help='Random seed for data generation.')
+ parser.add_argument("--dupe_factor", type=int, default=10, help='Number of times to duplicate the input data (with diffrent masks).')
+ parser.add_argument("--masked_lm_prob", type=float, default=0.15, help='Masked LM probability.')
+ parser.add_argument("--short_seq_prob", type=float, default=0.1, help='Probability of creating sequences which are shorter than the maximum length.')
+ args = parser.parse_args()
tokenizer = tokenization.FullTokenizer(
- vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
input_files = []
- for input_pattern in FLAGS.input_file.split(","):
- input_files.extend(tf.gfile.Glob(input_pattern))
+ for input_pattern in args.input_file.split(","):
+ input_files.append(input_pattern)
- tf.logging.info("*** Reading from input files ***")
+ logging.info("*** Reading from input files ***")
for input_file in input_files:
- tf.logging.info(" %s", input_file)
+ logging.info(" %s", input_file)
- rng = random.Random(FLAGS.random_seed)
+ rng = random.Random(args.random_seed)
instances = create_training_instances(
- input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
- FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
- rng)
-
- output_files = FLAGS.output_file.split(",")
- tf.logging.info("*** Writing to output files ***")
- for output_file in output_files:
- tf.logging.info(" %s", output_file)
+ input_files, tokenizer, args.max_seq_length, args.dupe_factor,
+ args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq,
+ rng, args.do_whole_word_mask)
- write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
- FLAGS.max_predictions_per_seq, output_files)
+ write_instance_to_example_files(instances, tokenizer, args.max_seq_length,
+ args.max_predictions_per_seq, args.output_file, args.partition_number)
if __name__ == "__main__":
- flags.mark_flag_as_required("input_file")
- flags.mark_flag_as_required("output_file")
- flags.mark_flag_as_required("vocab_file")
- tf.app.run()
+ main()
## All the scripts here come from [CLUEbenchmark/CLUENER2020](https://github.com/CLUEbenchmark/CLUENER2020/tree/master/tf_version)
#!/usr/bin/python
# coding:utf8
"""
@author: Cong Yu
@time: 2019-12-07 17:03
"""
import json
import tokenization
import collections
import tensorflow as tf
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def process_one_example(tokenizer, label2id, text, label, max_seq_len=128):
# textlist = text.split(' ')
# labellist = label.split(' ')
textlist = list(text)
labellist = list(label)
tokens = []
labels = []
for i, word in enumerate(textlist):
token = tokenizer.tokenize(word)
tokens.extend(token)
label_1 = labellist[i]
for m in range(len(token)):
if m == 0:
labels.append(label_1)
else:
print("some unknown token...")
labels.append(labels[0])
# tokens = tokenizer.tokenize(example.text) -2 的原因是因为序列需要加一个句首和句尾标志
if len(tokens) >= max_seq_len - 1:
tokens = tokens[0:(max_seq_len - 2)]
labels = labels[0:(max_seq_len - 2)]
ntokens = []
segment_ids = []
label_ids = []
ntokens.append("[CLS]") # 句子开始设置CLS 标志
segment_ids.append(0)
# [CLS] [SEP] 可以为 他们构建标签,或者 统一到某个标签,反正他们是不变的,基本不参加训练 即:x-l 永远不变
label_ids.append(0) # label2id["[CLS]"]
for i, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
label_ids.append(label2id[labels[i]])
ntokens.append("[SEP]")
segment_ids.append(0)
# append("O") or append("[SEP]") not sure!
label_ids.append(0) # label2id["[SEP]"]
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_len:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(0)
ntokens.append("**NULL**")
assert len(input_ids) == max_seq_len
assert len(input_mask) == max_seq_len
assert len(segment_ids) == max_seq_len
assert len(label_ids) == max_seq_len
feature = (input_ids, input_mask, segment_ids, label_ids)
return feature
def prepare_tf_record_data(tokenizer, max_seq_len, label2id, path, out_path):
"""
生成训练数据, tf.record, 单标签分类模型, 随机打乱数据
"""
writer = tf.python_io.TFRecordWriter(out_path)
example_count = 0
for line in open(path):
if not line.strip():
continue
_ = json.loads(line.strip())
len_ = len(_["text"])
labels = ["O"] * len_
for k, v in _["label"].items():
for kk, vv in v.items():
for vvv in vv:
span = vvv
s = span[0]
e = span[1] + 1
# print(s, e)
if e - s == 1:
labels[s] = "S_" + k
else:
labels[s] = "B_" + k
for i in range(s + 1, e - 1):
labels[i] = "M_" + k
labels[e - 1] = "E_" + k
# print()
# feature = process_one_example(tokenizer, label2id, row[column_name_x1], row[column_name_y],
# max_seq_len=max_seq_len)
feature = process_one_example(tokenizer, label2id, list(_["text"]), labels,
max_seq_len=max_seq_len)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
features = collections.OrderedDict()
# 序列标注任务
features["input_ids"] = create_int_feature(feature[0])
features["input_mask"] = create_int_feature(feature[1])
features["segment_ids"] = create_int_feature(feature[2])
features["label_ids"] = create_int_feature(feature[3])
if example_count < 5:
print("*** Example ***")
print(_["text"])
print(_["label"])
print("input_ids: %s" % " ".join([str(x) for x in feature[0]]))
print("input_mask: %s" % " ".join([str(x) for x in feature[1]]))
print("segment_ids: %s" % " ".join([str(x) for x in feature[2]]))
print("label: %s " % " ".join([str(x) for x in feature[3]]))
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
example_count += 1
# if example_count == 20:
# break
if example_count % 3000 == 0:
print(example_count)
print("total example:", example_count)
writer.close()
if __name__ == "__main__":
vocab_file = "./vocab.txt"
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file)
label2id = json.loads(open("label2id.json").read())
max_seq_len = 64
prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_train.json",
out_path="data/train.tf_record")
prepare_tf_record_data(tokenizer, max_seq_len, label2id, path="data/thuctc_valid.json",
out_path="data/dev.tf_record")
{
"O": 0,
"S_address": 1,
"B_address": 2,
"M_address": 3,
"E_address": 4,
"S_book": 5,
"B_book": 6,
"M_book": 7,
"E_book": 8,
"S_company": 9,
"B_company": 10,
"M_company": 11,
"E_company": 12,
"S_game": 13,
"B_game": 14,
"M_game": 15,
"E_game": 16,
"S_government": 17,
"B_government": 18,
"M_government": 19,
"E_government": 20,
"S_movie": 21,
"B_movie": 22,
"M_movie": 23,
"E_movie": 24,
"S_name": 25,
"B_name": 26,
"M_name": 27,
"E_name": 28,
"S_organization": 29,
"B_organization": 30,
"M_organization": 31,
"E_organization": 32,
"S_position": 33,
"B_position": 34,
"M_position": 35,
"E_position": 36,
"S_scene": 37,
"B_scene": 38,
"M_scene": 39,
"E_scene": 40
}
\ No newline at end of file
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
# pylint: skip-file
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
if item in vocab:
output.append(vocab[item])
else:
output.append(vocab['[UNK]'])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
## All the scripts here come from [google-research/bert](https://github.com/google-research/bert)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import random
import tokenization
import tensorflow as tf
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("input_file", None,
"Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
"output_file", None,
"Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_bool(
"do_whole_word_mask", False,
"Whether to use whole word masking rather than per-WordPiece masking.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
"Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
flags.DEFINE_integer(
"dupe_factor", 10,
"Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float(
"short_seq_prob", 0.1,
"Probability of creating sequences which are shorter than the "
"maximum length.")
class TrainingInstance(object):
"""A single training instance (sentence pair)."""
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens
self.segment_ids = segment_ids
self.is_random_next = is_random_next
self.masked_lm_positions = masked_lm_positions
self.masked_lm_labels = masked_lm_labels
def __str__(self):
s = ""
s += "tokens: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.tokens]))
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
s += "is_random_next: %s\n" % self.is_random_next
s += "masked_lm_positions: %s\n" % (" ".join(
[str(x) for x in self.masked_lm_positions]))
s += "masked_lm_labels: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files):
"""Create TF example files from `TrainingInstance`s."""
writers = []
for output_file in output_files:
writers.append(tf.python_io.TFRecordWriter(output_file))
writer_index = 0
total_written = 0
for (inst_index, instance) in enumerate(instances):
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
input_mask = [1] * len(input_ids)
segment_ids = list(instance.segment_ids)
assert len(input_ids) <= max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = list(instance.masked_lm_positions)
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < max_predictions_per_seq:
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance.is_random_next else 0
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids)
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writers[writer_index].write(tf_example.SerializeToString())
writer_index = (writer_index + 1) % len(writers)
total_written += 1
if inst_index < 20:
tf.logging.info("*** Example ***")
tf.logging.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in instance.tokens]))
for feature_name in features.keys():
feature = features[feature_name]
values = []
if feature.int64_list.value:
values = feature.int64_list.value
elif feature.float_list.value:
values = feature.float_list.value
tf.logging.info(
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
for writer in writers:
writer.close()
tf.logging.info("Wrote %d total instances", total_written)
def create_int_feature(values):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return feature
def create_float_feature(values):
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return feature
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
# Input file format:
# (1) One sentence per line. These should ideally be actual sentences, not
# entire paragraphs or arbitrary spans of text. (Because we use the
# sentence boundaries for the "next sentence prediction" task).
# (2) Blank lines between documents. Document boundaries are needed so
# that the "next sentence prediction" task doesn't span between documents.
for input_file in input_files:
with tf.gfile.GFile(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
# Empty lines are used as document delimiters
if not line:
all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
all_documents[-1].append(tokens)
# Remove empty documents
all_documents = [x for x in all_documents if x]
rng.shuffle(all_documents)
vocab_words = list(tokenizer.vocab.keys())
instances = []
for _ in range(dupe_factor):
for document_index in range(len(all_documents)):
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
rng.shuffle(instances)
return instances
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
# Account for [CLS], [SEP], [SEP]
max_num_tokens = max_seq_length - 3
# We *usually* want to fill up the entire sequence since we are padding
# to `max_seq_length` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `max_seq_length` is a hard limit.
target_seq_length = max_num_tokens
if rng.random() < short_seq_prob:
target_seq_length = rng.randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# Random next
is_random_next = False
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing.
for _ in range(10):
random_document_index = rng.randint(0, len(all_documents) - 1)
if random_document_index != document_index:
break
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
is_random_next=is_random_next,
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels)
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
"""Truncates a pair of sequences to a maximum sequence length."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
input_files = []
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
tf.logging.info("*** Reading from input files ***")
for input_file in input_files:
tf.logging.info(" %s", input_file)
rng = random.Random(FLAGS.random_seed)
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng)
output_files = FLAGS.output_file.split(",")
tf.logging.info("*** Writing to output files ***")
for output_file in output_files:
tf.logging.info(" %s", output_file)
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files)
if __name__ == "__main__":
flags.mark_flag_as_required("input_file")
flags.mark_flag_as_required("output_file")
flags.mark_flag_as_required("vocab_file")
tf.app.run()
This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
Text should be one-sentence-per-line, with empty lines between documents.
This sample text is public domain and was randomly selected from Project Guttenberg.
The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them.
"Cass" Beard had risen early that morning, but not with a view to discovery.
A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets.
The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency.
This was nearly opposite.
Mr. Cassius crossed the highway, and stopped suddenly.
Something glittered in the nearest red pool before him.
Gold, surely!
But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring.
Looking at it more attentively, he saw that it bore the inscription, "May to Cass."
Like most of his fellow gold-seekers, Cass was superstitious.
The fountain of classic wisdom, Hypatia herself.
As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge.
From my youth I felt in me a soul above the matter-entangled herd.
She revealed to me the glorious fact, that I am a spark of Divinity itself.
A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's.
There is a philosophic pleasure in opening one's treasures to the modest young.
Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street.
Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide;
but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind.
Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now.
His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert;
while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts.
At last they reached the quay at the opposite end of the street;
and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers.
He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him.
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
# pylint: skip-file
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册