提交 48f01f52 编写于 作者: B bbking 提交者: pkpk

[PaddleNLP] emotion_detection add download.py (#3649)

* emotion-detection => 1.6

* ERNIE => 1.6

* [PaddleNLP] update emotion_detection readme

* [PaddleNLP] emotion_detection add download.py for windows user
上级 e2432ea4
...@@ -56,7 +56,8 @@ ...@@ -56,7 +56,8 @@
. .
├── config.json # 配置文件 ├── config.json # 配置文件
├── config.py # 配置文件读取接口 ├── config.py # 配置文件读取接口
├── inference_model.py # 保存 inference_model 的脚本 ├── download.py # 下载数据及预训练模型脚本
├── inference_model.py # 保存 inference_model 的脚本
├── reader.py # 数据读取接口 ├── reader.py # 数据读取接口
├── run_classifier.py # 项目的主程序入口,包括训练、预测、评估 ├── run_classifier.py # 项目的主程序入口,包括训练、预测、评估
├── run.sh # 训练、预测、评估运行脚本 ├── run.sh # 训练、预测、评估运行脚本
...@@ -86,15 +87,15 @@ python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.ut ...@@ -86,15 +87,15 @@ python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.ut
#### 公开数据集 #### 公开数据集
这里我们提供一份已标注的、经过分词预处理的机器人聊天数据集,只需运行数据下载脚本 ```sh download_data.sh```,运行成功后,会生成文件夹 ```data```,其目录结构如下: 这里我们提供一份已标注的、经过分词预处理的机器人聊天数据集,只需运行数据下载脚本 ```sh download_data.sh```或者 ```python download.py dataset``` 运行成功后,会生成文件夹 ```data```,其目录结构如下:
```text ```text
. .
├── train.tsv # 训练集 ├── train.tsv # 训练集
├── dev.tsv # 验证集 ├── dev.tsv # 验证集
├── test.tsv # 测试集 ├── test.tsv # 测试集
├── infer.tsv # 待预测数据 ├── infer.tsv # 待预测数据
├── vocab.txt # 词典 ├── vocab.txt # 词典
``` ```
### 单机训练 ### 单机训练
...@@ -181,6 +182,8 @@ tar xvf emotion_detection_ernie_finetune-1.0.0.tar.gz ...@@ -181,6 +182,8 @@ tar xvf emotion_detection_ernie_finetune-1.0.0.tar.gz
```shell ```shell
sh download_model.sh sh download_model.sh
# 或者
python download.py model
``` ```
以上两种方式会将预训练的 TextCNN 模型和 ERNIE模型,保存在```pretrain_models```目录下,可直接修改```run.sh```脚本中的```init_checkpoint```参数进行评估、预测。 以上两种方式会将预训练的 TextCNN 模型和 ERNIE模型,保存在```pretrain_models```目录下,可直接修改```run.sh```脚本中的```init_checkpoint```参数进行评估、预测。
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Download script, download dataset and pretrain models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import sys
import time
import hashlib
import tarfile
import requests
def usage():
desc = ("\nDownload datasets and pretrained models for EmotionDetection task.\n"
"Usage:\n"
" 1. python download.py dataset\n"
" 2. python download.py model\n")
print(desc)
def md5file(fname):
hash_md5 = hashlib.md5()
with io.open(fname, "rb") as fin:
for chunk in iter(lambda: fin.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def extract(fname, dir_path):
"""
Extract tar.gz file
"""
try:
tar = tarfile.open(fname, "r:gz")
file_names = tar.getnames()
for file_name in file_names:
tar.extract(file_name, dir_path)
print(file_name)
tar.close()
except Exception as e:
raise e
def download(url, filename, md5sum):
"""
Download file and check md5
"""
retry = 0
retry_limit = 3
chunk_size = 4096
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if retry < retry_limit:
retry += 1
else:
raise RuntimeError("Cannot download dataset ({0}) with retry {1} times.".
format(url, retry_limit))
try:
start = time.time()
size = 0
res = requests.get(url, stream=True)
filesize = int(res.headers['content-length'])
if res.status_code == 200:
print("[Filesize]: %0.2f MB" % (filesize / 1024 / 1024))
# save by chunk
with io.open(filename, "wb") as fout:
for chunk in res.iter_content(chunk_size=chunk_size):
if chunk:
fout.write(chunk)
size += len(chunk)
pr = '>' * int(size * 50 / filesize)
print('\r[Process ]: %s%.2f%%' % (pr, float(size / filesize*100)), end='')
end = time.time()
print("\n[CostTime]: %.2f s" % (end - start))
except Exception as e:
print(e)
def download_dataset(dir_path):
BASE_URL = "https://baidu-nlp.bj.bcebos.com/"
DATASET_NAME = "emotion_detection-dataset-1.0.0.tar.gz"
DATASET_MD5 = "512d256add5f9ebae2c101b74ab053e9"
file_path = os.path.join(dir_path, DATASET_NAME)
url = BASE_URL + DATASET_NAME
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# download dataset
print("Downloading dataset: %s" % url)
download(url, file_path, DATASET_MD5)
# extract dataset
print("Extracting dataset: %s" % file_path)
extract(file_path, dir_path)
os.remove(file_path)
def download_model(dir_path):
MODELS = {}
BASE_URL = "https://baidu-nlp.bj.bcebos.com/"
CNN_NAME = "emotion_detection_textcnn-1.0.0.tar.gz"
CNN_MD5 = "b7ee648fcd108835c880a5f5fce0d8ab"
ERNIE_NAME = "emotion_detection_ernie_finetune-1.0.0.tar.gz"
ERNIE_MD5 = "dfeb68ddbbc87f466d3bb93e7d11c03a"
MODELS[CNN_NAME] = CNN_MD5
MODELS[ERNIE_NAME] = ERNIE_MD5
if not os.path.exists(dir_path):
os.makedirs(dir_path)
for model in MODELS:
url = BASE_URL + model
model_path = os.path.join(dir_path, model)
print("Downloading model: %s" % url)
# download model
download(url, model_path, MODELS[model])
# extract model.tar.gz
print("Extracting model: %s" % model_path)
extract(model_path, dir_path)
os.remove(model_path)
if __name__ == '__main__':
if len(sys.argv) != 2:
usage()
sys.exit(1)
if sys.argv[1] == "dataset":
pwd = os.path.join(os.path.dirname(__file__), './')
download_dataset(pwd)
elif sys.argv[1] == "model":
pwd = os.path.join(os.path.dirname(__file__), './pretrain_models')
download_model(pwd)
else:
usage()
...@@ -50,22 +50,6 @@ def check_version(): ...@@ -50,22 +50,6 @@ def check_version():
sys.exit(1) sys.exit(1)
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
print(err)
sys.exit(1)
if __name__ == "__main__": if __name__ == "__main__":
check_cuda(True) check_cuda(True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册