提交 f3cf41f0 编写于 作者: SYSU_BOND's avatar SYSU_BOND 提交者: pkpk

update downloads.py (#3672)

上级 89077088
...@@ -37,18 +37,18 @@ PaddlePaddle的版本要求是:Python 2 版本是 2.7.15+、Python 3 版本是 ...@@ -37,18 +37,18 @@ PaddlePaddle的版本要求是:Python 2 版本是 2.7.15+、Python 3 版本是
本项目涉及的**数据集****预训练模型**的数据可通过执行以下脚本进行快速下载,若仅需使用部分数据,可根据需要参照下列介绍进行部分下载 本项目涉及的**数据集****预训练模型**的数据可通过执行以下脚本进行快速下载,若仅需使用部分数据,可根据需要参照下列介绍进行部分下载
```bash ```bash
python download.py all python downloads.py all
``` ```
或在支持运行shell脚本的环境下执行: 或在支持运行shell脚本的环境下执行:
```bash ```bash
sh download.sh sh downloads.sh
``` ```
#### 2. 训练数据集 #### 2. 训练数据集
下载数据集文件,解压后会生成 `./data/` 文件夹 下载数据集文件,解压后会生成 `./data/` 文件夹
```bash ```bash
python download.py dataset python downloads.py dataset
``` ```
#### 3. 预训练模型 #### 3. 预训练模型
...@@ -56,10 +56,10 @@ python download.py dataset ...@@ -56,10 +56,10 @@ python download.py dataset
我们开源了在自建数据集上训练的词法分析模型,可供用户直接使用,可通过下述链接进行下载: 我们开源了在自建数据集上训练的词法分析模型,可供用户直接使用,可通过下述链接进行下载:
```bash ```bash
# download baseline model # download baseline model
python download.py lac python downloads.py lac
# download ERNIE finetuned model # download ERNIE finetuned model
python download.py finetuned python downloads.py finetuned
``` ```
注:若需进行ERNIE Finetune训练,需自行下载 [ERNIE](https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz) 开放的模型,下载链接为: [https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz](https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz),下载后解压至 `./pretrained/` 目录下。 注:若需进行ERNIE Finetune训练,需自行下载 [ERNIE](https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz) 开放的模型,下载链接为: [https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz](https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz),下载后解压至 `./pretrained/` 目录下。
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Download script, download dataset and pretrain models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import sys
import time
import hashlib
import tarfile
import requests
FILE_INFO = {
'BASE_URL': 'https://baidu-nlp.bj.bcebos.com/',
'DATA': {
'name': 'lexical_analysis-dataset-2.0.0.tar.gz',
'md5': '71e4a9a36d0f0177929a1bccedca7dba'
},
'LAC_MODEL': {
'name': 'lexical_analysis-2.0.0.tar.gz',
'md5': "fc1daef00de9564083c7dc7b600504ca"
},
'ERNIE_MODEL': {
'name': 'ERNIE_stable-1.0.1.tar.gz',
'md5': "bab876a874b5374a78d7af93384d3bfa"
},
'FINETURN_MODEL': {
'name': 'lexical_analysis_finetuned-1.0.0.tar.gz',
'md5': "ee2c7614b06dcfd89561fbbdaac34342"
}
}
def usage():
desc = ("\nDownload datasets and pretrained models for LAC.\n"
"Usage:\n"
" 1. python download.py all\n"
" 2. python download.py dataset\n"
" 3. python download.py lac\n"
" 4. python download.py finetuned\n"
" 5. python download.py ernie\n")
print(desc)
def md5file(fname):
hash_md5 = hashlib.md5()
with io.open(fname, "rb") as fin:
for chunk in iter(lambda: fin.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def extract(fname, dir_path):
"""
Extract tar.gz file
"""
try:
tar = tarfile.open(fname, "r:gz")
file_names = tar.getnames()
for file_name in file_names:
tar.extract(file_name, dir_path)
print(file_name)
tar.close()
except Exception as e:
raise e
def _download(url, filename, md5sum):
"""
Download file and check md5
"""
retry = 0
retry_limit = 3
chunk_size = 4096
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if retry < retry_limit:
retry += 1
else:
raise RuntimeError(
"Cannot download dataset ({0}) with retry {1} times.".format(
url, retry_limit))
try:
start = time.time()
size = 0
res = requests.get(url, stream=True)
filesize = int(res.headers['content-length'])
if res.status_code == 200:
print("[Filesize]: %0.2f MB" % (filesize / 1024 / 1024))
# save by chunk
with io.open(filename, "wb") as fout:
for chunk in res.iter_content(chunk_size=chunk_size):
if chunk:
fout.write(chunk)
size += len(chunk)
pr = '>' * int(size * 50 / filesize)
print(
'\r[Process ]: %s%.2f%%' %
(pr, float(size / filesize * 100)),
end='')
end = time.time()
print("\n[CostTime]: %.2f s" % (end - start))
except Exception as e:
print(e)
def download(name, dir_path):
url = FILE_INFO['BASE_URL'] + FILE_INFO[name]['name']
file_path = os.path.join(dir_path, FILE_INFO[name]['name'])
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# download data
print("Downloading : %s" % name)
_download(url, file_path, FILE_INFO[name]['md5'])
# extract data
print("Extracting : %s" % file_path)
extract(file_path, dir_path)
os.remove(file_path)
if __name__ == '__main__':
if len(sys.argv) != 2:
usage()
sys.exit(1)
pwd = os.path.join(os.path.dirname(__file__), './')
ernie_dir = os.path.join(os.path.dirname(__file__), './pretrained')
if sys.argv[1] == 'all':
download('DATA', pwd)
download('LAC_MODEL', pwd)
download('FINETURN_MODEL', pwd)
download('ERNIE_MODEL', ernie_dir)
if sys.argv[1] == "dataset":
download('DATA', pwd)
elif sys.argv[1] == "lac":
download('LAC_MODEL', pwd)
elif sys.argv[1] == "finetuned":
download('FINETURN_MODEL', pwd)
elif sys.argv[1] == "ernie":
download('ERNIE_MODEL', ernie_dir)
else:
usage()
...@@ -73,7 +73,7 @@ class Dataset(object): ...@@ -73,7 +73,7 @@ class Dataset(object):
def get_num_examples(self, filename): def get_num_examples(self, filename):
"""num of line of file""" """num of line of file"""
return sum(1 for line in open(filename, "r")) return sum(1 for line in io.open(filename, "r", encoding='utf8'))
def word_to_ids(self, words): def word_to_ids(self, words):
"""convert word to word index""" """convert word to word index"""
...@@ -107,16 +107,17 @@ class Dataset(object): ...@@ -107,16 +107,17 @@ class Dataset(object):
fread = io.open(filename, "r", encoding="utf-8") fread = io.open(filename, "r", encoding="utf-8")
if mode == "infer": if mode == "infer":
for line in fread: for line in fread:
words= line.strip() words = line.strip()
word_ids = self.word_to_ids(words) word_ids = self.word_to_ids(words)
yield (word_ids[0:max_seq_len],) yield (word_ids[0:max_seq_len], )
else: else:
headline = next(fread) headline = next(fread)
headline = headline.strip().split('\t') headline = headline.strip().split('\t')
assert len(headline) == 2 and headline[0] == "text_a" and headline[1] == "label" assert len(headline) == 2 and headline[
0] == "text_a" and headline[1] == "label"
for line in fread: for line in fread:
words, labels = line.strip("\n").split("\t") words, labels = line.strip("\n").split("\t")
if len(words)<1: if len(words) < 1:
continue continue
word_ids = self.word_to_ids(words.split("\002")) word_ids = self.word_to_ids(words.split("\002"))
label_ids = self.label_to_ids(labels.split("\002")) label_ids = self.label_to_ids(labels.split("\002"))
......
...@@ -48,19 +48,21 @@ class ArgumentGroup(object): ...@@ -48,19 +48,21 @@ class ArgumentGroup(object):
help=help + ' Default: %(default)s.', help=help + ' Default: %(default)s.',
**kwargs) **kwargs)
def load_yaml(parser, file_name, **kwargs): def load_yaml(parser, file_name, **kwargs):
with open(file_name) as f: with open(file_name) as f:
args = yaml.load(f, Loader=yaml.FullLoader) args = yaml.load(f)
for title in args: for title in args:
group = parser.add_argument_group(title=title, description='') group = parser.add_argument_group(title=title, description='')
for name in args[title]: for name in args[title]:
_type = type(args[title][name]['val']) _type = type(args[title][name]['val'])
_type = str2bool if _type==bool else _type _type = str2bool if _type == bool else _type
group.add_argument( group.add_argument(
"--"+name, "--" + name,
default=args[title][name]['val'], default=args[title][name]['val'],
type=_type, type=_type,
help=args[title][name]['meaning'] + ' Default: %(default)s.', help=args[title][name]['meaning'] +
' Default: %(default)s.',
**kwargs) **kwargs)
...@@ -115,7 +117,9 @@ def parse_result(words, crf_decode, dataset): ...@@ -115,7 +117,9 @@ def parse_result(words, crf_decode, dataset):
for sent_index in range(batch_size): for sent_index in range(batch_size):
begin, end = offset_list[sent_index], offset_list[sent_index + 1] begin, end = offset_list[sent_index], offset_list[sent_index + 1]
sent = [dataset.id2word_dict[str(id[0])] for id in words[begin:end]] sent = [dataset.id2word_dict[str(id[0])] for id in words[begin:end]]
tags = [dataset.id2label_dict[str(id[0])] for id in crf_decode[begin:end]] tags = [
dataset.id2label_dict[str(id[0])] for id in crf_decode[begin:end]
]
sent_out = [] sent_out = []
tags_out = [] tags_out = []
...@@ -128,7 +132,7 @@ def parse_result(words, crf_decode, dataset): ...@@ -128,7 +132,7 @@ def parse_result(words, crf_decode, dataset):
continue continue
# for the beginning of word # for the beginning of word
if tag.endswith("-B") or (tag == "O" and tags[ind-1]!="O"): if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
sent_out.append(parital_word) sent_out.append(parital_word)
tags_out.append(tag.split('-')[0]) tags_out.append(tag.split('-')[0])
parital_word = sent[ind] parital_word = sent[ind]
...@@ -137,12 +141,13 @@ def parse_result(words, crf_decode, dataset): ...@@ -137,12 +141,13 @@ def parse_result(words, crf_decode, dataset):
parital_word += sent[ind] parital_word += sent[ind]
# append the last word, except for len(tags)=0 # append the last word, except for len(tags)=0
if len(sent_out)<len(tags_out): if len(sent_out) < len(tags_out):
sent_out.append(parital_word) sent_out.append(parital_word)
batch_out.append([sent_out,tags_out]) batch_out.append([sent_out, tags_out])
return batch_out return batch_out
def init_checkpoint(exe, init_checkpoint_path, main_program): def init_checkpoint(exe, init_checkpoint_path, main_program):
""" """
Init CheckPoint Init CheckPoint
...@@ -165,6 +170,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program): ...@@ -165,6 +170,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
predicate=existed_persitables) predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path)) print("Load model from {}".format(init_checkpoint_path))
def init_pretraining_params(exe, def init_pretraining_params(exe,
pretraining_params_path, pretraining_params_path,
main_program, main_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册