未验证 提交 34aab4f3 编写于 作者: K kinghuin 提交者: GitHub

add peoplesner, fix msraner (#5183)

add peoplesner, fix msraner (#5183)
上级 96109f97
......@@ -5,11 +5,11 @@
MSRA-NER 数据集由微软亚研院发布,其目标是识别文本中具有特定意义的实体,主要包括人名、地名、机构名等。示例如下:
```
海钓比赛地点在厦门与金门之间的海域。 OOOOOOOB-LOCI-LOCOB-LOCI-LOCOOOOOO
座依山傍水的博物馆由国内一流的设计师主持设计,整个建筑群精美而恢宏。 OOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOO
不\002久\002前\002,\002中\002国\002共\002产\002党\002召\002开\002了\002举\002世\002瞩\002目\002的\002第\002十\002五\002次\002全\002国\002代\002表\002大\002会\002。 O\002O\002O\002O\002B-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002O\002O\002O\002O\002O\002O\002O\002O\002B-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002I-ORG\002O
\002次\002代\002表\002大\002会\002是\002在\002中\002国\002改\002革\002开\002放\002和\002社\002会\002主\002义\002现\002代\002化\002建\002设\002发\002展\002的\002关\002键\002时\002刻\002召\002开\002的\002历\002史\002性\002会\002议\002。 O\002O\002O\002O\002O\002O\002O\002O\002B-LOC\002I-LOC\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O\002O
```
数据集中以特殊字符"\t"分隔文本、标签,以特殊字符"\002"分隔每个字
PaddleNLP集成的数据集MSRA-NER数据集对文件格式做了调整:每一行文本、标签以特殊字符"\t"进行分隔,每个字之间以特殊字符"\002"分隔
## 2. 快速开始
......@@ -52,22 +52,21 @@ python -u ./run_msra_ner.py \
训练过程将按照 `logging_steps``save_steps` 的设置打印如下日志:
```
global step 1496, epoch: 2, batch: 192, loss: 0.010747, speed: 4.77 step/s
global step 1497, epoch: 2, batch: 193, loss: 0.004837, speed: 4.46 step/s
global step 1498, epoch: 2, batch: 194, loss: 0.011281, speed: 4.24 step/s
global step 1499, epoch: 2, batch: 195, loss: 0.005711, speed: 4.73 step/s
global step 1500, epoch: 2, batch: 196, loss: 0.003150, speed: 4.52 step/s
eval loss: 0.010307, precision: 0.884222, recall: 0.903190, f1: 0.893605
global step 3996, epoch: 2, batch: 1184, loss: 0.008593, speed: 4.15 step/s
global step 3997, epoch: 2, batch: 1185, loss: 0.008453, speed: 4.17 step/s
global step 3998, epoch: 2, batch: 1186, loss: 0.002294, speed: 4.19 step/s
global step 3999, epoch: 2, batch: 1187, loss: 0.005351, speed: 4.16 step/s
global step 4000, epoch: 2, batch: 1188, loss: 0.004734, speed: 4.18 step/s
eval loss: 0.006829, precision: 0.908957, recall: 0.926683, f1: 0.917734
```
使用以上命令进行单卡 Fine-tuning ,在验证集上有如下结果:
Metric | Result |
------------------------------|-------------|
precision | 0.884222 |
recall | 0.903190 |
f1 | 0.893605 |
precision | 0.908957 |
recall | 0.926683 |
f1 | 0.917734 |
## 参考
[Microsoft Research Asia Chinese Word-Segmentation Data Set](https://www.microsoft.com/en-us/download/details.aspx?id=52531)
[The third international Chinese language processing bakeoff: Word segmentation and named entity recognition](https://faculty.washington.edu/levow/papers/sighan06.pdf)
......@@ -245,8 +245,8 @@ def do_train(args):
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
train_dataset, dev_dataset = ppnlp.datasets.MSRA_NER.get_datasets(
["train", "dev"])
train_dataset, test_dataset = ppnlp.datasets.MSRA_NER.get_datasets(
["train", "test"])
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
label_list = train_dataset.get_labels()
......@@ -276,11 +276,11 @@ def do_train(args):
num_workers=0,
return_list=True)
dev_dataset = dev_dataset.apply(trans_func, lazy=True)
test_dataset = test_dataset.apply(trans_func, lazy=True)
dev_batch_sampler = paddle.io.BatchSampler(
dev_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)
dev_data_loader = DataLoader(
dataset=dev_dataset,
test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)
test_data_loader = DataLoader(
dataset=test_dataset,
batch_sampler=dev_batch_sampler,
collate_fn=batchify_fn,
num_workers=0,
......@@ -336,7 +336,7 @@ def do_train(args):
lr_scheduler.step()
optimizer.clear_gradients()
if global_step % args.save_steps == 0:
evaluate(model, loss_fct, metric, dev_data_loader, label_num)
evaluate(model, loss_fct, metric, test_data_loader, label_num)
if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
paddle.save(model.state_dict(),
os.path.join(args.output_dir,
......
......@@ -17,6 +17,7 @@ from .dataset import *
from .glue import *
from .lcqmc import *
from .msra_ner import *
from .peoples_daily_ner import *
from .ptb import *
from .squad import *
from .translation import *
......
......@@ -27,7 +27,7 @@ __all__ = ['MSRA_NER']
class MSRA_NER(TSVDataset):
URL = "https://bj.bcebos.com/paddlehub-dataset/msra_ner.tar.gz"
URL = "https://paddlenlp.bj.bcebos.com/datasets/msra_ner.tar.gz"
MD5 = None
META_INFO = collections.namedtuple(
'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
......@@ -37,11 +37,6 @@ class MSRA_NER(TSVDataset):
'67d3c93a37daba60ef43c03271f119d7',
(0, 1),
1, ),
'dev': META_INFO(
os.path.join('msra_ner', 'dev.tsv'),
'ec772f3ba914bca5269f6e785bb3375d',
(0, 1),
1, ),
'test': META_INFO(
os.path.join('msra_ner', 'test.tsv'),
'2f27ae68b5f61d6553ffa28bb577c8a7',
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import warnings
from paddle.io import Dataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from .dataset import TSVDataset
__all__ = ['PeoplesDailyNER']
class PeoplesDailyNER(TSVDataset):
URL = "https://paddlenlp.bj.bcebos.com/datasets/peoples_daily_ner.tar.gz"
MD5 = None
META_INFO = collections.namedtuple(
'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples'))
SPLITS = {
'train': META_INFO(
os.path.join('peoples_daily_ner', 'train.tsv'),
'67d3c93a37daba60ef43c03271f119d7',
(0, 1),
1, ),
'dev': META_INFO(
os.path.join('peoples_daily_ner', 'dev.tsv'),
'ec772f3ba914bca5269f6e785bb3375d',
(0, 1),
1, ),
'test': META_INFO(
os.path.join('peoples_daily_ner', 'test.tsv'),
'2f27ae68b5f61d6553ffa28bb577c8a7',
(0, 1),
1, ),
}
def __init__(self, mode='train', root=None, **kwargs):
default_root = os.path.join(DATA_HOME, 'peoples_daily_ner')
filename, data_hash, field_indices, num_discard_samples = self.SPLITS[
mode]
fullname = os.path.join(default_root,
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
if not os.path.exists(fullname) or (data_hash and
not md5file(fullname) == data_hash):
if root is not None: # not specified, and no need to warn
warnings.warn(
'md5 check failed for {}, download {} data to {}'.format(
filename, self.__class__.__name__, default_root))
path = get_path_from_url(self.URL, default_root, self.MD5)
fullname = os.path.join(default_root, filename)
super(PeoplesDailyNER, self).__init__(
fullname,
field_indices=field_indices,
num_discard_samples=num_discard_samples,
**kwargs)
def get_labels(self):
"""
Return labels of the GlueCoLA object.
"""
return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册