Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
3e1aa4bd
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3e1aa4bd
编写于
11月 20, 2019
作者:
K
kinghuin
提交者:
wuzewu
11月 22, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support ernie-tiny cls
上级
271883bf
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
100 addition
and
44 deletion
+100
-44
demo/text-classification/text_classifier.py
demo/text-classification/text_classifier.py
+29
-42
paddlehub/module/module.py
paddlehub/module/module.py
+13
-0
paddlehub/reader/nlp_reader.py
paddlehub/reader/nlp_reader.py
+8
-2
paddlehub/reader/tokenization.py
paddlehub/reader/tokenization.py
+50
-0
未找到文件。
demo/text-classification/text_classifier.py
浏览文件 @
3e1aa4bd
...
@@ -33,7 +33,6 @@ parser.add_argument("--max_seq_len", type=int, default=512, help="Number of word
...
@@ -33,7 +33,6 @@ parser.add_argument("--max_seq_len", type=int, default=512, help="Number of word
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
32
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
32
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--use_pyreader"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use pyreader to feed data."
)
parser
.
add_argument
(
"--use_pyreader"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use pyreader to feed data."
)
parser
.
add_argument
(
"--use_data_parallel"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use data parallel."
)
parser
.
add_argument
(
"--use_data_parallel"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether use data parallel."
)
parser
.
add_argument
(
"--use_taskid"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"Whether to use taskid ,if yes to use ernie v2."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# yapf: enable.
# yapf: enable.
...
@@ -43,7 +42,7 @@ if __name__ == '__main__':
...
@@ -43,7 +42,7 @@ if __name__ == '__main__':
# Download dataset and use ClassifyReader to read dataset
# Download dataset and use ClassifyReader to read dataset
if
args
.
dataset
.
lower
()
==
"chnsenticorp"
:
if
args
.
dataset
.
lower
()
==
"chnsenticorp"
:
dataset
=
hub
.
dataset
.
ChnSentiCorp
()
dataset
=
hub
.
dataset
.
ChnSentiCorp
()
module
=
hub
.
Module
(
name
=
"
roberta_wwm_ext_chinese_L-24_H-1024_A-16
"
)
module
=
hub
.
Module
(
name
=
"
ernie_v2_chinese_tiny
"
)
metrics_choices
=
[
"acc"
]
metrics_choices
=
[
"acc"
]
elif
args
.
dataset
.
lower
()
==
"tnews"
:
elif
args
.
dataset
.
lower
()
==
"tnews"
:
dataset
=
hub
.
dataset
.
TNews
()
dataset
=
hub
.
dataset
.
TNews
()
...
@@ -75,60 +74,36 @@ if __name__ == '__main__':
...
@@ -75,60 +74,36 @@ if __name__ == '__main__':
metrics_choices
=
[
"acc"
,
"f1"
]
metrics_choices
=
[
"acc"
,
"f1"
]
elif
args
.
dataset
.
lower
()
==
"mrpc"
:
elif
args
.
dataset
.
lower
()
==
"mrpc"
:
dataset
=
hub
.
dataset
.
GLUE
(
"MRPC"
)
dataset
=
hub
.
dataset
.
GLUE
(
"MRPC"
)
if
args
.
use_taskid
:
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
else
:
module
=
hub
.
Module
(
name
=
"bert_uncased_L-12_H-768_A-12"
)
metrics_choices
=
[
"f1"
,
"acc"
]
metrics_choices
=
[
"f1"
,
"acc"
]
# The first metric will be choose to eval. Ref: task.py:799
# The first metric will be choose to eval. Ref: task.py:799
elif
args
.
dataset
.
lower
()
==
"qqp"
:
elif
args
.
dataset
.
lower
()
==
"qqp"
:
dataset
=
hub
.
dataset
.
GLUE
(
"QQP"
)
dataset
=
hub
.
dataset
.
GLUE
(
"QQP"
)
if
args
.
use_taskid
:
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
else
:
module
=
hub
.
Module
(
name
=
"bert_uncased_L-12_H-768_A-12"
)
metrics_choices
=
[
"f1"
,
"acc"
]
metrics_choices
=
[
"f1"
,
"acc"
]
elif
args
.
dataset
.
lower
()
==
"sst-2"
:
elif
args
.
dataset
.
lower
()
==
"sst-2"
:
dataset
=
hub
.
dataset
.
GLUE
(
"SST-2"
)
dataset
=
hub
.
dataset
.
GLUE
(
"SST-2"
)
if
args
.
use_taskid
:
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
else
:
module
=
hub
.
Module
(
name
=
"bert_uncased_L-12_H-768_A-12"
)
metrics_choices
=
[
"acc"
]
metrics_choices
=
[
"acc"
]
elif
args
.
dataset
.
lower
()
==
"cola"
:
elif
args
.
dataset
.
lower
()
==
"cola"
:
dataset
=
hub
.
dataset
.
GLUE
(
"CoLA"
)
dataset
=
hub
.
dataset
.
GLUE
(
"CoLA"
)
if
args
.
use_taskid
:
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
else
:
module
=
hub
.
Module
(
name
=
"bert_uncased_L-12_H-768_A-12"
)
metrics_choices
=
[
"matthews"
,
"acc"
]
metrics_choices
=
[
"matthews"
,
"acc"
]
elif
args
.
dataset
.
lower
()
==
"qnli"
:
elif
args
.
dataset
.
lower
()
==
"qnli"
:
dataset
=
hub
.
dataset
.
GLUE
(
"QNLI"
)
dataset
=
hub
.
dataset
.
GLUE
(
"QNLI"
)
if
args
.
use_taskid
:
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
else
:
module
=
hub
.
Module
(
name
=
"bert_uncased_L-12_H-768_A-12"
)
metrics_choices
=
[
"acc"
]
metrics_choices
=
[
"acc"
]
elif
args
.
dataset
.
lower
()
==
"rte"
:
elif
args
.
dataset
.
lower
()
==
"rte"
:
dataset
=
hub
.
dataset
.
GLUE
(
"RTE"
)
dataset
=
hub
.
dataset
.
GLUE
(
"RTE"
)
if
args
.
use_taskid
:
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
else
:
module
=
hub
.
Module
(
name
=
"bert_uncased_L-12_H-768_A-12"
)
metrics_choices
=
[
"acc"
]
metrics_choices
=
[
"acc"
]
elif
args
.
dataset
.
lower
()
==
"mnli"
or
args
.
dataset
.
lower
()
==
"mnli"
:
elif
args
.
dataset
.
lower
()
==
"mnli"
or
args
.
dataset
.
lower
()
==
"mnli"
:
dataset
=
hub
.
dataset
.
GLUE
(
"MNLI_m"
)
dataset
=
hub
.
dataset
.
GLUE
(
"MNLI_m"
)
if
args
.
use_taskid
:
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
else
:
module
=
hub
.
Module
(
name
=
"bert_uncased_L-12_H-768_A-12"
)
metrics_choices
=
[
"acc"
]
metrics_choices
=
[
"acc"
]
elif
args
.
dataset
.
lower
()
==
"mnli_mm"
:
elif
args
.
dataset
.
lower
()
==
"mnli_mm"
:
dataset
=
hub
.
dataset
.
GLUE
(
"MNLI_mm"
)
dataset
=
hub
.
dataset
.
GLUE
(
"MNLI_mm"
)
if
args
.
use_taskid
:
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
module
=
hub
.
Module
(
name
=
"ernie_v2_eng_base"
)
else
:
module
=
hub
.
Module
(
name
=
"bert_uncased_L-12_H-768_A-12"
)
metrics_choices
=
[
"acc"
]
metrics_choices
=
[
"acc"
]
elif
args
.
dataset
.
lower
().
startswith
(
"xnli"
):
elif
args
.
dataset
.
lower
().
startswith
(
"xnli"
):
dataset
=
hub
.
dataset
.
XNLI
(
language
=
args
.
dataset
.
lower
()[
-
2
:])
dataset
=
hub
.
dataset
.
XNLI
(
language
=
args
.
dataset
.
lower
()[
-
2
:])
...
@@ -137,19 +112,22 @@ if __name__ == '__main__':
...
@@ -137,19 +112,22 @@ if __name__ == '__main__':
else
:
else
:
raise
ValueError
(
"%s dataset is not defined"
%
args
.
dataset
)
raise
ValueError
(
"%s dataset is not defined"
%
args
.
dataset
)
# Check metric
support_metrics
=
[
"acc"
,
"f1"
,
"matthews"
]
support_metrics
=
[
"acc"
,
"f1"
,
"matthews"
]
for
metric
in
metrics_choices
:
for
metric
in
metrics_choices
:
if
metric
not
in
support_metrics
:
if
metric
not
in
support_metrics
:
raise
ValueError
(
"
\"
%s
\"
metric is not defined"
%
metric
)
raise
ValueError
(
"
\"
%s
\"
metric is not defined"
%
metric
)
# Start preparing parameters for reader and task accoring to module
# For ernie_v2, it has an addition embedding named task_id
# For ernie_v2_chinese_tiny, it use an addition sentence_piece_vocab to tokenize
if
module
.
name
.
startswith
(
"ernie_v2"
):
use_taskid
=
True
else
:
use_taskid
=
False
inputs
,
outputs
,
program
=
module
.
context
(
inputs
,
outputs
,
program
=
module
.
context
(
trainable
=
True
,
max_seq_len
=
args
.
max_seq_len
)
trainable
=
True
,
max_seq_len
=
args
.
max_seq_len
)
reader
=
hub
.
reader
.
ClassifyReader
(
dataset
=
dataset
,
vocab_path
=
module
.
get_vocab_path
(),
max_seq_len
=
args
.
max_seq_len
,
use_task_id
=
args
.
use_taskid
)
# Construct transfer learning network
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
# Use "sequence_output" for token-level output.
...
@@ -163,9 +141,18 @@ if __name__ == '__main__':
...
@@ -163,9 +141,18 @@ if __name__ == '__main__':
inputs
[
"segment_ids"
].
name
,
inputs
[
"segment_ids"
].
name
,
inputs
[
"input_mask"
].
name
,
inputs
[
"input_mask"
].
name
,
]
]
if
use_taskid
:
if
args
.
use_taskid
:
feed_list
.
append
(
inputs
[
"task_ids"
].
name
)
feed_list
.
append
(
inputs
[
"task_ids"
].
name
)
# Finish preparing parameter for reader and task accoring to modul
# Define reader
reader
=
hub
.
reader
.
ClassifyReader
(
dataset
=
dataset
,
vocab_path
=
module
.
get_vocab_path
(),
max_seq_len
=
args
.
max_seq_len
,
use_task_id
=
use_taskid
,
sp_model_path
=
module
.
get_spm_path
(),
word_dict_path
=
module
.
get_word_dict_path
())
# Select finetune strategy, setup config and finetune
# Select finetune strategy, setup config and finetune
strategy
=
hub
.
AdamWeightDecayStrategy
(
strategy
=
hub
.
AdamWeightDecayStrategy
(
...
...
paddlehub/module/module.py
浏览文件 @
3e1aa4bd
...
@@ -320,6 +320,19 @@ class Module(object):
...
@@ -320,6 +320,19 @@ class Module(object):
for
assets_file
in
self
.
assets
:
for
assets_file
in
self
.
assets
:
if
"vocab.txt"
in
assets_file
:
if
"vocab.txt"
in
assets_file
:
return
assets_file
return
assets_file
return
None
def
get_word_dict_path
(
self
):
for
assets_file
in
self
.
assets
:
if
"dict.wordseg.pickle"
in
assets_file
:
return
assets_file
return
None
def
get_spm_path
(
self
):
for
assets_file
in
self
.
assets
:
if
"spm_cased_simp_sampled.model"
in
assets_file
:
return
assets_file
return
None
def
_recover_from_desc
(
self
):
def
_recover_from_desc
(
self
):
# recover signature
# recover signature
...
...
paddlehub/reader/nlp_reader.py
浏览文件 @
3e1aa4bd
...
@@ -44,10 +44,16 @@ class BaseReader(object):
...
@@ -44,10 +44,16 @@ class BaseReader(object):
do_lower_case
=
True
,
do_lower_case
=
True
,
random_seed
=
None
,
random_seed
=
None
,
use_task_id
=
False
,
use_task_id
=
False
,
sp_model_path
=
None
,
word_dict_path
=
None
,
in_tokens
=
False
):
in_tokens
=
False
):
self
.
max_seq_len
=
max_seq_len
self
.
max_seq_len
=
max_seq_len
self
.
tokenizer
=
tokenization
.
FullTokenizer
(
if
sp_model_path
and
word_dict_path
:
vocab_file
=
vocab_path
,
do_lower_case
=
do_lower_case
)
self
.
tokenizer
=
tokenization
.
WSSPTokenizer
(
vocab_path
,
sp_model_path
,
word_dict_path
,
ws
=
True
,
lower
=
True
)
else
:
self
.
tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
vocab_path
,
do_lower_case
=
do_lower_case
)
self
.
vocab
=
self
.
tokenizer
.
vocab
self
.
vocab
=
self
.
tokenizer
.
vocab
self
.
dataset
=
dataset
self
.
dataset
=
dataset
self
.
pad_id
=
self
.
vocab
[
"[PAD]"
]
self
.
pad_id
=
self
.
vocab
[
"[PAD]"
]
...
...
paddlehub/reader/tokenization.py
浏览文件 @
3e1aa4bd
...
@@ -22,6 +22,8 @@ import collections
...
@@ -22,6 +22,8 @@ import collections
import
io
import
io
import
unicodedata
import
unicodedata
import
six
import
six
import
sentencepiece
as
spm
import
pickle
def
convert_to_unicode
(
text
):
def
convert_to_unicode
(
text
):
...
@@ -154,6 +156,54 @@ class CharTokenizer(object):
...
@@ -154,6 +156,54 @@ class CharTokenizer(object):
return
convert_by_vocab
(
self
.
inv_vocab
,
ids
)
return
convert_by_vocab
(
self
.
inv_vocab
,
ids
)
class
WSSPTokenizer
(
object
):
def
__init__
(
self
,
vocab_file
,
sp_model_dir
,
word_dict
,
ws
=
True
,
lower
=
True
):
self
.
vocab
=
load_vocab
(
vocab_file
)
self
.
inv_vocab
=
{
v
:
k
for
k
,
v
in
self
.
vocab
.
items
()}
self
.
ws
=
ws
self
.
lower
=
lower
self
.
dict
=
pickle
.
load
(
open
(
word_dict
,
'rb'
),
encoding
=
'utf8'
)
self
.
sp_model
=
spm
.
SentencePieceProcessor
()
self
.
window_size
=
5
self
.
sp_model
.
Load
(
sp_model_dir
)
def
cut
(
self
,
chars
):
words
=
[]
idx
=
0
while
idx
<
len
(
chars
):
matched
=
False
for
i
in
range
(
self
.
window_size
,
0
,
-
1
):
cand
=
chars
[
idx
:
idx
+
i
]
if
cand
in
self
.
dict
:
words
.
append
(
cand
)
matched
=
True
break
if
not
matched
:
i
=
1
words
.
append
(
chars
[
idx
])
idx
+=
i
return
words
def
tokenize
(
self
,
text
):
sen
=
text
.
decode
(
'utf8'
)
if
self
.
ws
:
sen
=
[
s
for
s
in
self
.
cut
(
sen
)
if
s
!=
' '
]
else
:
sen
=
sen
.
split
(
' '
)
if
self
.
lower
:
sen
=
[
s
.
lower
()
for
s
in
sen
]
sen
=
' '
.
join
(
sen
)
ret
=
self
.
sp_model
.
EncodeAsPieces
(
sen
)
return
ret
def
convert_tokens_to_ids
(
self
,
tokens
):
return
convert_by_vocab
(
self
.
vocab
,
tokens
)
def
convert_ids_to_tokens
(
self
,
ids
):
return
convert_by_vocab
(
self
.
inv_vocab
,
ids
)
class
BasicTokenizer
(
object
):
class
BasicTokenizer
(
object
):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录