From 48f01f52cf3d2ee6abd3880d83530fe52e8bf8fe Mon Sep 17 00:00:00 2001 From: bbking Date: Thu, 17 Oct 2019 19:38:21 +0800 Subject: [PATCH] [PaddleNLP] emotion_detection add download.py (#3649) * emotion-detection => 1.6 * ERNIE => 1.6 * [PaddleNLP] update emotion_detection readme * [PaddleNLP] emotion_detection add download.py for windows user --- PaddleNLP/emotion_detection/README.md | 17 +-- PaddleNLP/emotion_detection/download.py | 153 ++++++++++++++++++++++++ PaddleNLP/models/model_check.py | 16 --- 3 files changed, 163 insertions(+), 23 deletions(-) create mode 100644 PaddleNLP/emotion_detection/download.py diff --git a/PaddleNLP/emotion_detection/README.md b/PaddleNLP/emotion_detection/README.md index bc0ee7c4..95280492 100644 --- a/PaddleNLP/emotion_detection/README.md +++ b/PaddleNLP/emotion_detection/README.md @@ -56,7 +56,8 @@ . ├── config.json # 配置文件 ├── config.py # 配置文件读取接口 -├── inference_model.py # 保存 inference_model 的脚本 +├── download.py # 下载数据及预训练模型脚本 +├── inference_model.py # 保存 inference_model 的脚本 ├── reader.py # 数据读取接口 ├── run_classifier.py # 项目的主程序入口,包括训练、预测、评估 ├── run.sh # 训练、预测、评估运行脚本 @@ -86,15 +87,15 @@ python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.ut #### 公开数据集 -这里我们提供一份已标注的、经过分词预处理的机器人聊天数据集,只需运行数据下载脚本 ```sh download_data.sh```,运行成功后,会生成文件夹 ```data```,其目录结构如下: +这里我们提供一份已标注的、经过分词预处理的机器人聊天数据集,只需运行数据下载脚本 ```sh download_data.sh```,或者 ```python download.py dataset``` 运行成功后,会生成文件夹 ```data```,其目录结构如下: ```text . -├── train.tsv # 训练集 -├── dev.tsv # 验证集 -├── test.tsv # 测试集 -├── infer.tsv # 待预测数据 -├── vocab.txt # 词典 +├── train.tsv # 训练集 +├── dev.tsv # 验证集 +├── test.tsv # 测试集 +├── infer.tsv # 待预测数据 +├── vocab.txt # 词典 ``` ### 单机训练 @@ -181,6 +182,8 @@ tar xvf emotion_detection_ernie_finetune-1.0.0.tar.gz ```shell sh download_model.sh +# 或者 +python download.py model ``` 以上两种方式会将预训练的 TextCNN 模型和 ERNIE模型,保存在```pretrain_models```目录下,可直接修改```run.sh```脚本中的```init_checkpoint```参数进行评估、预测。 diff --git a/PaddleNLP/emotion_detection/download.py b/PaddleNLP/emotion_detection/download.py new file mode 100644 index 00000000..419a16db --- /dev/null +++ b/PaddleNLP/emotion_detection/download.py @@ -0,0 +1,153 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Download script, download dataset and pretrain models. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import io +import os +import sys +import time +import hashlib +import tarfile +import requests + + +def usage(): + desc = ("\nDownload datasets and pretrained models for EmotionDetection task.\n" + "Usage:\n" + " 1. python download.py dataset\n" + " 2. python download.py model\n") + print(desc) + + +def md5file(fname): + hash_md5 = hashlib.md5() + with io.open(fname, "rb") as fin: + for chunk in iter(lambda: fin.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def extract(fname, dir_path): + """ + Extract tar.gz file + """ + try: + tar = tarfile.open(fname, "r:gz") + file_names = tar.getnames() + for file_name in file_names: + tar.extract(file_name, dir_path) + print(file_name) + tar.close() + except Exception as e: + raise e + + +def download(url, filename, md5sum): + """ + Download file and check md5 + """ + retry = 0 + retry_limit = 3 + chunk_size = 4096 + while not (os.path.exists(filename) and md5file(filename) == md5sum): + if retry < retry_limit: + retry += 1 + else: + raise RuntimeError("Cannot download dataset ({0}) with retry {1} times.". + format(url, retry_limit)) + try: + start = time.time() + size = 0 + res = requests.get(url, stream=True) + filesize = int(res.headers['content-length']) + if res.status_code == 200: + print("[Filesize]: %0.2f MB" % (filesize / 1024 / 1024)) + # save by chunk + with io.open(filename, "wb") as fout: + for chunk in res.iter_content(chunk_size=chunk_size): + if chunk: + fout.write(chunk) + size += len(chunk) + pr = '>' * int(size * 50 / filesize) + print('\r[Process ]: %s%.2f%%' % (pr, float(size / filesize*100)), end='') + end = time.time() + print("\n[CostTime]: %.2f s" % (end - start)) + except Exception as e: + print(e) + + +def download_dataset(dir_path): + BASE_URL = "https://baidu-nlp.bj.bcebos.com/" + DATASET_NAME = "emotion_detection-dataset-1.0.0.tar.gz" + DATASET_MD5 = "512d256add5f9ebae2c101b74ab053e9" + file_path = os.path.join(dir_path, DATASET_NAME) + url = BASE_URL + DATASET_NAME + + if not os.path.exists(dir_path): + os.makedirs(dir_path) + # download dataset + print("Downloading dataset: %s" % url) + download(url, file_path, DATASET_MD5) + # extract dataset + print("Extracting dataset: %s" % file_path) + extract(file_path, dir_path) + os.remove(file_path) + + +def download_model(dir_path): + MODELS = {} + BASE_URL = "https://baidu-nlp.bj.bcebos.com/" + CNN_NAME = "emotion_detection_textcnn-1.0.0.tar.gz" + CNN_MD5 = "b7ee648fcd108835c880a5f5fce0d8ab" + ERNIE_NAME = "emotion_detection_ernie_finetune-1.0.0.tar.gz" + ERNIE_MD5 = "dfeb68ddbbc87f466d3bb93e7d11c03a" + MODELS[CNN_NAME] = CNN_MD5 + MODELS[ERNIE_NAME] = ERNIE_MD5 + + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + for model in MODELS: + url = BASE_URL + model + model_path = os.path.join(dir_path, model) + print("Downloading model: %s" % url) + # download model + download(url, model_path, MODELS[model]) + # extract model.tar.gz + print("Extracting model: %s" % model_path) + extract(model_path, dir_path) + os.remove(model_path) + + +if __name__ == '__main__': + if len(sys.argv) != 2: + usage() + sys.exit(1) + + if sys.argv[1] == "dataset": + pwd = os.path.join(os.path.dirname(__file__), './') + download_dataset(pwd) + elif sys.argv[1] == "model": + pwd = os.path.join(os.path.dirname(__file__), './pretrain_models') + download_model(pwd) + else: + usage() + diff --git a/PaddleNLP/models/model_check.py b/PaddleNLP/models/model_check.py index 8d261153..4469be4c 100644 --- a/PaddleNLP/models/model_check.py +++ b/PaddleNLP/models/model_check.py @@ -50,22 +50,6 @@ def check_version(): sys.exit(1) -def check_version(): - """ - Log error and exit when the installed version of paddlepaddle is - not satisfied. - """ - err = "PaddlePaddle version 1.6 or higher is required, " \ - "or a suitable develop version is satisfied as well. \n" \ - "Please make sure the version is good with your code." \ - - try: - fluid.require_version('1.6.0') - except Exception as e: - print(err) - sys.exit(1) - - if __name__ == "__main__": check_cuda(True) -- GitLab