未验证 提交 4bb3163d 编写于 作者: L LiuHao 提交者: GitHub

update (#3822)

上级 dc8ecdf1
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
1. PaddlePaddle 安装 1. PaddlePaddle 安装
本项目依赖于 PaddlePaddle Fluid 1.3.2 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装 本项目依赖于 PaddlePaddle Fluid 1.6 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
2. 代码安装 2. 代码安装
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
3. 环境依赖 3. 环境依赖
请参考 PaddlePaddle [安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html) 部分的内容 Python 2 的版本要求 2.7.15+,Python 3 的版本要求 3.5.1+/3.6/3.7,其它环境请参考 PaddlePaddle [安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html) 部分的内容
### 代码结构说明 ### 代码结构说明
...@@ -118,10 +118,10 @@ Running type options: ...@@ -118,10 +118,10 @@ Running type options:
Model config options: Model config options:
--model_type {bow_net,cnn_net,lstm_net,bilstm_net,gru_net,textcnn_net} --model_type {bow_net,cnn_net,lstm_net,bilstm_net,gru_net,textcnn_net}
Model type to run the task. Default: textcnn_net. Model type to run the task. Default: bilstm_net.
--init_checkpoint INIT_CHECKPOINT --init_checkpoint INIT_CHECKPOINT
Init checkpoint to resume training from. Default: . Init checkpoint to resume training from. Default: .
--save_checkpoint_dir SAVE_CHECKPOINT_DIR --checkpoints SAVE_CHECKPOINT_DIR
Directory path to save checkpoints Default: . Directory path to save checkpoints Default: .
... ...
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Download script, download dataset and pretrain models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import sys
import time
import hashlib
import tarfile
import requests
def usage():
desc = ("\nDownload datasets and pretrained models for Sentiment Classification task.\n"
"Usage:\n"
" 1. python download.py dataset\n"
" 2. python download.py model\n")
print(desc)
def extract(fname, dir_path):
"""
Extract tar.gz file
"""
try:
tar = tarfile.open(fname, "r:gz")
file_names = tar.getnames()
for file_name in file_names:
tar.extract(file_name, dir_path)
print(file_name)
tar.close()
except Exception as e:
raise e
def download(url, filename):
"""
Download file
"""
retry = 0
retry_limit = 3
chunk_size = 4096
while not (os.path.exists(filename):
if retry < retry_limit:
retry += 1
else:
raise RuntimeError("Cannot download dataset ({0}) with retry {1} times.".
format(url, retry_limit))
try:
start = time.time()
size = 0
res = requests.get(url, stream=True)
filesize = int(res.headers['content-length'])
if res.status_code == 200:
print("[Filesize]: %0.2f MB" % (filesize / 1024 / 1024))
# save by chunk
with io.open(filename, "wb") as fout:
for chunk in res.iter_content(chunk_size=chunk_size):
if chunk:
fout.write(chunk)
size += len(chunk)
pr = '>' * int(size * 50 / filesize)
print('\r[Process ]: %s%.2f%%' % (pr, float(size / filesize*100)), end='')
end = time.time()
print("\n[CostTime]: %.2f s" % (end - start))
except Exception as e:
print(e)
def download_dataset(dir_path):
BASE_URL = "https://baidu-nlp.bj.bcebos.com/"
DATASET_NAME = "sentiment_classification-dataset-1.0.0.tar.gz"
file_path = os.path.join(dir_path, DATASET_NAME)
url = BASE_URL + DATASET_NAME
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# download dataset
print("Downloading dataset: %s" % url)
download(url, file_path)
# extract dataset
print("Extracting dataset: %s" % file_path)
extract(file_path, dir_path)
os.remove(file_path)
def download_model(dir_path):
BASE_URL = "https://baidu-nlp.bj.bcebos.com/"
MODEL_NAME = "sentiment_classification-1.0.0.tar.gz"
if not os.path.exists(dir_path):
os.makedirs(dir_path)
url = BASE_URL + MODEL_NAME
model_path = os.path.join(dir_path, model)
print("Downloading model: %s" % url)
# download model
download(url, model_path, MODEL_NAME)
# extract model.tar.gz
print("Extracting model: %s" % model_path)
extract(model_path, dir_path)
os.remove(model_path)
if __name__ == "__main__":
if len(sys) != 2:
usage()
sys.exit(1)
if sys.argv[1] == "dataset":
pwd = os.path.join(os.path.dirname(__file__), "./")
download_dataset(pwd)
elif sys.argv[1] == "model":
pwd = os.path.join(os.path.dirname(__file__), "./models")
download_model(pwd)
else:
usage()
...@@ -14,7 +14,7 @@ import utils ...@@ -14,7 +14,7 @@ import utils
import reader import reader
from run_ernie_classifier import ernie_pyreader from run_ernie_classifier import ernie_pyreader
from models.representation.ernie import ErnieConfig from models.representation.ernie import ErnieConfig
from models.representation.ernie import ernie_encoder from models.representation.ernie import ernie_encoder, ernie_encoder_with_paddle_hub
from preprocess.ernie import task_reader from preprocess.ernie import task_reader
def do_save_inference_model(args): def do_save_inference_model(args):
...@@ -39,8 +39,11 @@ def do_save_inference_model(args): ...@@ -39,8 +39,11 @@ def do_save_inference_model(args):
infer_pyreader, ernie_inputs, labels = ernie_pyreader( infer_pyreader, ernie_inputs, labels = ernie_pyreader(
args, args,
pyreader_name="infer_reader") pyreader_name="infer_reader")
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config) if args.use_paddle_hub:
embeddings = ernie_encoder_with_paddle_hub(ernie_inputs, args.max_seq_len)
else:
embeddings = ernie_encoder(ernie_inputs, ernie_config=ernie_config)
probs = create_model(args, probs = create_model(args,
embeddings, embeddings,
......
...@@ -72,12 +72,12 @@ class SentaProcessor(object): ...@@ -72,12 +72,12 @@ class SentaProcessor(object):
Generate data for train, dev or infer Generate data for train, dev or infer
""" """
if phase == "train": if phase == "train":
return paddle.batch(self.get_train_examples(self.data_dir, epoch, self.max_seq_len), batch_size) return fluid.io.batch(self.get_train_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
#return self.get_train_examples(self.data_dir, epoch, self.max_seq_len) #return self.get_train_examples(self.data_dir, epoch, self.max_seq_len)
elif phase == "dev": elif phase == "dev":
return paddle.batch(self.get_dev_examples(self.data_dir, epoch, self.max_seq_len), batch_size) return fluid.io.batch(self.get_dev_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
elif phase == "infer": elif phase == "infer":
return paddle.batch(self.get_test_examples(self.data_dir, epoch, self.max_seq_len), batch_size) return fluid.io.batch(self.get_test_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
else: else:
raise ValueError( raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'infer'].") "Unknown phase, which should be in ['train', 'dev', 'infer'].")
...@@ -21,6 +21,7 @@ from nets import cnn_net ...@@ -21,6 +21,7 @@ from nets import cnn_net
from nets import bilstm_net from nets import bilstm_net
from nets import gru_net from nets import gru_net
from models.model_check import check_cuda from models.model_check import check_cuda
from models.model_check import check_version
from config import PDConfig from config import PDConfig
import paddle import paddle
...@@ -39,11 +40,11 @@ def create_model(args, ...@@ -39,11 +40,11 @@ def create_model(args,
""" """
data = fluid.layers.data( data = fluid.layers.data(
name="src_ids", shape=[-1, args.max_seq_len, 1], dtype='int64') name="src_ids", shape=[-1, args.max_seq_len], dtype='int64')
label = fluid.layers.data( label = fluid.layers.data(
name="label", shape=[-1, 1], dtype="int64") name="label", shape=[-1, 1], dtype="int64")
seq_len = fluid.layers.data( seq_len = fluid.layers.data(
name="seq_len", shape=[-1, 1], dtype="int64") name="seq_len", shape=[-1], dtype="int64")
data_reader = fluid.io.PyReader(feed_list=[data, label, seq_len], data_reader = fluid.io.PyReader(feed_list=[data, label, seq_len],
capacity=4, iterable=False) capacity=4, iterable=False)
......
...@@ -48,16 +48,21 @@ def ernie_pyreader(args, pyreader_name): ...@@ -48,16 +48,21 @@ def ernie_pyreader(args, pyreader_name):
labels = fluid.layers.data( labels = fluid.layers.data(
name="labels", shape=[-1, 1], dtype="int64") name="labels", shape=[-1, 1], dtype="int64")
seq_lens = fluid.layers.data( seq_lens = fluid.layers.data(
name="seq_lens", shape=[-1, 1], dtype="int64") name="seq_lens", shape=[-1], dtype="int64")
pyreader = fluid.io.PyReader(feed_list=[src_ids, sent_ids, pos_ids, input_mask, labels, seq_lens], pyreader = fluid.io.DataLoader.from_generator(
capacity=4, iterable=False) feed_list=[src_ids, sent_ids, pos_ids, input_mask, labels, seq_lens],
capacity=50,
iterable=False,
use_double_buffer=True)
ernie_inputs = { ernie_inputs = {
"src_ids": src_ids, "src_ids": src_ids,
"sent_ids": sent_ids, "sent_ids": sent_ids,
"pos_ids": pos_ids, "pos_ids": pos_ids,
"input_mask": input_mask, "input_mask": input_mask,
"seq_lens": seq_lens} "seq_lens": seq_lens}
return pyreader, ernie_inputs, labels return pyreader, ernie_inputs, labels
def create_model(args, def create_model(args,
...@@ -299,15 +304,15 @@ def main(args): ...@@ -299,15 +304,15 @@ def main(args):
if args.do_train: if args.do_train:
train_exe = exe train_exe = exe
train_pyreader.decorate_batch_generator(train_data_generator) train_pyreader.set_batch_generator(train_data_generator)
else: else:
train_exe = None train_exe = None
if args.do_val: if args.do_val:
test_exe = exe test_exe = exe
test_pyreader.decorate_batch_generator(test_data_generator) test_pyreader.set_batch_generator(test_data_generator)
if args.do_infer: if args.do_infer:
test_exe = exe test_exe = exe
infer_pyreader.decorate_batch_generator(infer_data_generator) infer_pyreader.set_batch_generator(infer_data_generator)
if args.do_train: if args.do_train:
train_pyreader.start() train_pyreader.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册