Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
b296c718
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b296c718
编写于
4月 14, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add choices for scripts
上级
da2db0f3
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
22 addition
and
30 deletion
+22
-30
demo/sequence-labeling/sequence_label.py
demo/sequence-labeling/sequence_label.py
+4
-3
demo/text-classification/README.md
demo/text-classification/README.md
+7
-5
demo/text-classification/run_classifier.sh
demo/text-classification/run_classifier.sh
+2
-2
demo/text-classification/text_classifier.py
demo/text-classification/text_classifier.py
+1
-2
paddlehub/dataset/chnsenticorp.py
paddlehub/dataset/chnsenticorp.py
+2
-2
paddlehub/dataset/lcqmc.py
paddlehub/dataset/lcqmc.py
+2
-2
paddlehub/dataset/msra_ner.py
paddlehub/dataset/msra_ner.py
+2
-8
paddlehub/dataset/nlpcc_dbqa.py
paddlehub/dataset/nlpcc_dbqa.py
+2
-2
paddlehub/reader/nlp_reader.py
paddlehub/reader/nlp_reader.py
+0
-4
未找到文件。
demo/sequence-labeling/sequence_label.py
浏览文件 @
b296c718
...
...
@@ -21,12 +21,13 @@ import paddlehub as hub
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--num_epoch"
,
type
=
int
,
default
=
3
,
help
=
"Number of epoches for fine-tuning."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
"Whether use GPU for finetuning, input should be True or False"
)
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
5e-5
,
help
=
"Learning rate used to train with warmup."
)
parser
.
add_argument
(
"--weight_decay"
,
type
=
float
,
default
=
0.01
,
help
=
"Weight decay rate for L2 regularizer."
)
parser
.
add_argument
(
"--
checkpoint_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory to model checkpoint
"
)
parser
.
add_argument
(
"--
warmup_proportion"
,
type
=
float
,
default
=
0.0
,
help
=
"Warmup proportion params for warmup strategy
"
)
parser
.
add_argument
(
"--max_seq_len"
,
type
=
int
,
default
=
512
,
help
=
"Number of words of the longest seqence."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
32
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory to model checkpoint"
)
args
=
parser
.
parse_args
()
# yapf: enable.
...
...
@@ -76,7 +77,7 @@ if __name__ == '__main__':
# Setup runing config for PaddleHub Finetune API
config
=
hub
.
RunConfig
(
use_cuda
=
True
,
use_cuda
=
args
.
use_gpu
,
num_epoch
=
args
.
num_epoch
,
batch_size
=
args
.
batch_size
,
checkpoint_dir
=
args
.
checkpoint_dir
,
...
...
demo/text-classification/README.md
浏览文件 @
b296c718
...
...
@@ -149,8 +149,10 @@ python cls_predict.py --checkpoint_dir $CKPT_DIR --max_seq_len 128
```
其中CKPT_DIR为Finetune API保存最佳模型的路径, max_seq_len是ERNIE模型的最大序列长度,
*请与训练时配置的参数保持一致*
参数配置正确后,请执行脚本
`sh run_predict.sh`
,即可看到以下文本分类预测结果。如需了解更多预测步骤,请参考
`cls_predict.py`
参数配置正确后,请执行脚本
`sh run_predict.sh`
,即可看到以下文本分类预测结果, 以及最终准确率。
如需了解更多预测步骤,请参考
`cls_predict.py`
```
text=键盘缝隙大进灰,装系统自己不会装,屏幕有点窄玩游戏人物有点变形 label=0 predict=0
accuracy = 0.954267
```
demo/text-classification/run_classifier.sh
浏览文件 @
b296c718
...
...
@@ -15,5 +15,5 @@ python -u text_classifier.py \
--checkpoint_dir
=
${
CKPT_DIR
}
\
--learning_rate
=
5e-5
\
--weight_decay
=
0.01
\
--max_seq_len
=
128
--num_epoch
=
3
\
--max_seq_len
=
128
\
--num_epoch
=
3
demo/text-classification/text_classifier.py
浏览文件 @
b296c718
...
...
@@ -22,8 +22,7 @@ import paddlehub as hub
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--num_epoch"
,
type
=
int
,
default
=
3
,
help
=
"Number of epoches for fine-tuning."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use GPU for finetuning, input should be True or False"
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"senticorp"
,
help
=
"Directory to model checkpoint"
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"chnsenticorp"
,
help
=
"Directory to model checkpoint"
,
choices
=
[
"chnsenticorp"
,
"nlpcc_dbqa"
,
"lcqmc"
])
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
5e-5
,
help
=
"Learning rate used to train with warmup."
)
parser
.
add_argument
(
"--weight_decay"
,
type
=
float
,
default
=
0.01
,
help
=
"Weight decay rate for L2 regularizer."
)
parser
.
add_argument
(
"--warmup_proportion"
,
type
=
float
,
default
=
0.0
,
help
=
"Warmup proportion params for warmup strategy"
)
...
...
paddlehub/dataset/chnsenticorp.py
浏览文件 @
b296c718
...
...
@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
from
paddlehub.common.dir
import
DATA_HOME
from
paddlehub.common.logger
import
logger
DATA_URL
=
"https://paddlehub-dataset.bj.bcebos.com/chnsenticorp.tar.gz"
_
DATA_URL
=
"https://paddlehub-dataset.bj.bcebos.com/chnsenticorp.tar.gz"
class
ChnSentiCorp
(
HubDataset
):
...
...
@@ -38,7 +38,7 @@ class ChnSentiCorp(HubDataset):
self
.
dataset_dir
=
os
.
path
.
join
(
DATA_HOME
,
"chnsenticorp"
)
if
not
os
.
path
.
exists
(
self
.
dataset_dir
):
ret
,
tips
,
self
.
dataset_dir
=
default_downloader
.
download_file_and_uncompress
(
url
=
DATA_URL
,
save_path
=
DATA_HOME
,
print_progress
=
True
)
url
=
_
DATA_URL
,
save_path
=
DATA_HOME
,
print_progress
=
True
)
else
:
logger
.
info
(
"Dataset {} already cached."
.
format
(
self
.
dataset_dir
))
...
...
paddlehub/dataset/lcqmc.py
浏览文件 @
b296c718
...
...
@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
from
paddlehub.common.dir
import
DATA_HOME
from
paddlehub.common.logger
import
logger
DATA_URL
=
"https://paddlehub-dataset.bj.bcebos.com/lcqmc.tar.gz"
_
DATA_URL
=
"https://paddlehub-dataset.bj.bcebos.com/lcqmc.tar.gz"
class
LCQMC
(
HubDataset
):
...
...
@@ -33,7 +33,7 @@ class LCQMC(HubDataset):
self
.
dataset_dir
=
os
.
path
.
join
(
DATA_HOME
,
"lcqmc"
)
if
not
os
.
path
.
exists
(
self
.
dataset_dir
):
ret
,
tips
,
self
.
dataset_dir
=
default_downloader
.
download_file_and_uncompress
(
url
=
DATA_URL
,
save_path
=
DATA_HOME
,
print_progress
=
True
)
url
=
_
DATA_URL
,
save_path
=
DATA_HOME
,
print_progress
=
True
)
else
:
logger
.
info
(
"Dataset {} already cached."
.
format
(
self
.
dataset_dir
))
...
...
paddlehub/dataset/msra_ner.py
浏览文件 @
b296c718
...
...
@@ -26,7 +26,7 @@ from paddlehub.common.downloader import default_downloader
from
paddlehub.common.dir
import
DATA_HOME
from
paddlehub.common.logger
import
logger
DATA_URL
=
"https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
_
DATA_URL
=
"https://paddlehub-dataset.bj.bcebos.com/msra_ner.tar.gz"
class
MSRA_NER
(
HubDataset
):
...
...
@@ -41,20 +41,14 @@ class MSRA_NER(HubDataset):
self
.
dataset_dir
=
os
.
path
.
join
(
DATA_HOME
,
"msra_ner"
)
if
not
os
.
path
.
exists
(
self
.
dataset_dir
):
ret
,
tips
,
self
.
dataset_dir
=
default_downloader
.
download_file_and_uncompress
(
url
=
DATA_URL
,
save_path
=
DATA_HOME
,
print_progress
=
True
)
url
=
_
DATA_URL
,
save_path
=
DATA_HOME
,
print_progress
=
True
)
else
:
logger
.
info
(
"Dataset {} already cached."
.
format
(
self
.
dataset_dir
))
self
.
_load_label_map
()
self
.
_load_train_examples
()
self
.
_load_test_examples
()
self
.
_load_dev_examples
()
def
_load_label_map
(
self
):
self
.
label_map_file
=
os
.
path
.
join
(
self
.
dataset_dir
,
"label_map.json"
)
with
open
(
self
.
label_map_file
)
as
fi
:
self
.
label_map
=
json
.
load
(
fi
)
def
_load_train_examples
(
self
):
train_file
=
os
.
path
.
join
(
self
.
dataset_dir
,
"train.tsv"
)
self
.
train_examples
=
self
.
_read_tsv
(
train_file
)
...
...
paddlehub/dataset/nlpcc_dbqa.py
浏览文件 @
b296c718
...
...
@@ -25,7 +25,7 @@ from paddlehub.common.downloader import default_downloader
from
paddlehub.common.dir
import
DATA_HOME
from
paddlehub.common.logger
import
logger
DATA_URL
=
"https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
_
DATA_URL
=
"https://paddlehub-dataset.bj.bcebos.com/nlpcc-dbqa.tar.gz"
class
NLPCC_DBQA
(
HubDataset
):
...
...
@@ -39,7 +39,7 @@ class NLPCC_DBQA(HubDataset):
self
.
dataset_dir
=
os
.
path
.
join
(
DATA_HOME
,
"nlpcc-dbqa"
)
if
not
os
.
path
.
exists
(
self
.
dataset_dir
):
ret
,
tips
,
self
.
dataset_dir
=
default_downloader
.
download_file_and_uncompress
(
url
=
DATA_URL
,
save_path
=
DATA_HOME
,
print_progress
=
True
)
url
=
_
DATA_URL
,
save_path
=
DATA_HOME
,
print_progress
=
True
)
else
:
logger
.
info
(
"Dataset {} already cached."
.
format
(
self
.
dataset_dir
))
...
...
paddlehub/reader/nlp_reader.py
浏览文件 @
b296c718
...
...
@@ -76,10 +76,6 @@ class BaseReader(object):
"""Gets a collection of `InputExample`s for prediction."""
return
self
.
dataset
.
get_test_examples
()
def
get_labels
(
self
):
"""Gets the list of labels for this data set."""
return
self
.
dataset
.
get_labels
()
def
get_train_progress
(
self
):
"""Gets progress for training phase."""
return
self
.
current_example
,
self
.
current_epoch
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录