Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
22c4494f
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
22c4494f
编写于
5月 11, 2020
作者:
S
Steffy-zxf
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/PaddleHub
into add-preset-net
上级
b3b8cb0f
b7e8230f
变更
23
显示空白变更内容
内联
并排
Showing
23 changed file
with
826 addition
and
111 deletion
+826
-111
README.md
README.md
+18
-16
demo/image_classification/img_classifier_dygraph.py
demo/image_classification/img_classifier_dygraph.py
+89
-0
demo/sequence_labeling/sequence_label_dygraph.py
demo/sequence_labeling/sequence_label_dygraph.py
+107
-0
demo/text_classification/finetuned_model_to_module/module.py
demo/text_classification/finetuned_model_to_module/module.py
+12
-10
demo/text_classification/text_classifier_dygraph.py
demo/text_classification/text_classifier_dygraph.py
+98
-0
docs/pretrained_models.md
docs/pretrained_models.md
+178
-0
docs/reference/config.md
docs/reference/config.md
+4
-4
docs/reference/task/base_task.md
docs/reference/task/base_task.md
+0
-9
docs/tutorial/define_task_example.md
docs/tutorial/define_task_example.md
+84
-0
docs/tutorial/finetuned_model_to_module.md
docs/tutorial/finetuned_model_to_module.md
+63
-30
docs/tutorial/how_to_load_data.md
docs/tutorial/how_to_load_data.md
+3
-0
hub_module/scripts/configs/faster_rcnn_resnet50_fpn_venus.yml
...module/scripts/configs/faster_rcnn_resnet50_fpn_venus.yml
+5
-5
paddlehub/__init__.py
paddlehub/__init__.py
+1
-0
paddlehub/common/downloader.py
paddlehub/common/downloader.py
+74
-17
paddlehub/common/hub_server.py
paddlehub/common/hub_server.py
+2
-1
paddlehub/common/logger.py
paddlehub/common/logger.py
+2
-1
paddlehub/common/paddle_helper.py
paddlehub/common/paddle_helper.py
+4
-2
paddlehub/dataset/food101.py
paddlehub/dataset/food101.py
+3
-3
paddlehub/module/manager.py
paddlehub/module/manager.py
+9
-5
paddlehub/module/module.py
paddlehub/module/module.py
+7
-2
paddlehub/module/nlp_module.py
paddlehub/module/nlp_module.py
+61
-4
paddlehub/reader/cv_reader.py
paddlehub/reader/cv_reader.py
+1
-1
paddlehub/version.py
paddlehub/version.py
+1
-1
未找到文件。
README.md
浏览文件 @
22c4494f
...
...
@@ -8,18 +8,18 @@


PaddleHub是飞桨生态的预训练模型应用工具,开发者可以便捷地使用高质量的预训练模型结合Fine-tune API快速完成模型迁移到部署的全流程工作。PaddleHub提供的预训练模型涵盖了图像分类、目标检测、词法分析、语义模型、情感分析、视频分类、图像生成、图像分割、文本审核、关键点检测等主流模型。更多详情可查看官网:https://www.paddlepaddle.org.cn/hu
PaddleHub是飞桨生态的预训练模型应用工具,开发者可以便捷地使用高质量的预训练模型结合Fine-tune API快速完成模型迁移到部署的全流程工作。PaddleHub提供的预训练模型涵盖了图像分类、目标检测、词法分析、语义模型、情感分析、视频分类、图像生成、图像分割、文本审核、关键点检测等主流模型。更多详情可查看官网:https://www.paddlepaddle.org.cn/hu
b
PaddleHub以预训练模型应用为核心具备以下特点:
*
**[模型即软件](#模型即软件)**
,通过Python API或命令行实现模型调用,可快速体验或集成飞桨特色预训练模型。
*
**[易用的迁移学习](#迁移学习)**
,通过Fine-tune API,内置多种优化策略,只需少量代码即可完成预训练模型的Fine-tuning。
*
**[易用的迁移学习](#
易用的
迁移学习)**
,通过Fine-tune API,内置多种优化策略,只需少量代码即可完成预训练模型的Fine-tuning。
*
**[一键模型转服务](#
服务化部署paddlehub-serving
)**
,简单一行命令即可搭建属于自己的深度学习模型API服务完成部署。
*
**[一键模型转服务](#
一键模型转服务
)**
,简单一行命令即可搭建属于自己的深度学习模型API服务完成部署。
*
**[自动超参优化](#
超参优化autodl-finetuner
)**
,内置AutoDL Finetuner能力,一键启动自动化超参搜索。
*
**[自动超参优化](#
自动超参优化
)**
,内置AutoDL Finetuner能力,一键启动自动化超参搜索。
<p
align=
"center"
>
...
...
@@ -66,7 +66,7 @@ PaddleHub采用模型即软件的设计理念,所有的预训练模型与Pytho
安装PaddleHub后,执行命令
[
hub run
](
./docs/tutorial/cmdintro.md
)
,即可快速体验无需代码、一键预测的功能:
*
使用
[
目标检测
](
http
://www.paddlepaddle.org.cn/hub?filter=
category&value=ObjectDetection
)
模型pyramidbox_lite_mobile_mask对图片进行口罩检测
*
使用
[
目标检测
](
http
s://www.paddlepaddle.org.cn/hublist?filter=en_
category&value=ObjectDetection
)
模型pyramidbox_lite_mobile_mask对图片进行口罩检测
```
shell
$
wget https://paddlehub.bj.bcebos.com/resources/test_mask_detection.jpg
$
hub run pyramidbox_lite_mobile_mask
--input_path
test_mask_detection.jpg
...
...
@@ -75,19 +75,22 @@ $ hub run pyramidbox_lite_mobile_mask --input_path test_mask_detection.jpg
<img src="./docs/imgs/test_mask_detection_result.jpg" align="middle"
</p>
*
使用
[
词法分析
](
http
://www.paddlepaddle.org.cn/hub?filter=
category&value=LexicalAnalysis
)
模型LAC进行分词
*
使用
[
词法分析
](
http
s://www.paddlepaddle.org.cn/hublist?filter=en_
category&value=LexicalAnalysis
)
模型LAC进行分词
```
shell
$
hub run lac
--input_text
"今天是个好日子"
[{
'word'
:
[
'今天'
,
'是'
,
'个'
,
'好日子'
]
,
'tag'
:
[
'TIME'
,
'v'
,
'q'
,
'n'
]}]
$
hub run lac
--input_text
"现在,慕尼黑再保险公司不仅是此类行动的倡议者,更是将其大量气候数据整合进保险产品中,并与公众共享大量天气信息,参与到新能源领域的保障中。"
[{
'word'
:
[
'现在'
,
','
,
'慕尼黑再保险公司'
,
'不仅'
,
'是'
,
'此类'
,
'行动'
,
'的'
,
'倡议者'
,
','
,
'更是'
,
'将'
,
'其'
,
'大量'
,
'气候'
,
'数据'
,
'整合'
,
'进'
,
'保险'
,
'产品'
,
'中'
,
','
,
'并'
,
'与'
,
'公众'
,
'共享'
,
'大量'
,
'天气'
,
'信息'
,
','
,
'参与'
,
'到'
,
'新能源'
,
'领域'
,
'的'
,
'保障'
,
'中'
,
'。'
]
,
'tag'
:
[
'TIME'
,
'w'
,
'ORG'
,
'c'
,
'v'
,
'r'
,
'n'
,
'u'
,
'n'
,
'w'
,
'd'
,
'p'
,
'r'
,
'a'
,
'n'
,
'n'
,
'v'
,
'v'
,
'n'
,
'n'
,
'f'
,
'w'
,
'c'
,
'p'
,
'n'
,
'v'
,
'a'
,
'n'
,
'n'
,
'w'
,
'v'
,
'v'
,
'n'
,
'n'
,
'u'
,
'vn'
,
'f'
,
'w'
]
}]
```
*
使用
[
情感分析
](
http
://www.paddlepaddle.org.cn/hub?filter=
category&value=SentimentAnalysis
)
模型Senta对句子进行情感预测
*
使用
[
情感分析
](
http
s://www.paddlepaddle.org.cn/hublist?filter=en_
category&value=SentimentAnalysis
)
模型Senta对句子进行情感预测
```
shell
$
hub run senta_bilstm
--input_text
"今天天气真好"
{
'text'
:
'今天天气真好'
,
'sentiment_label'
: 1,
'sentiment_key'
:
'positive'
,
'positive_probs'
: 0.9798,
'negative_probs'
: 0.0202
}]
```
*
使用
[
目标检测
](
http
://www.paddlepaddle.org.cn/hub?filter=
category&value=ObjectDetection
)
模型Ultra-Light-Fast-Generic-Face-Detector-1MB对图片进行人脸识别
*
使用
[
目标检测
](
http
s://www.paddlepaddle.org.cn/hublist?filter=en_
category&value=ObjectDetection
)
模型Ultra-Light-Fast-Generic-Face-Detector-1MB对图片进行人脸识别
```
shell
$
wget https://paddlehub.bj.bcebos.com/resources/test_image.jpg
$
hub run ultra_light_fast_generic_face_detector_1mb_640
--input_path
test_image.jpg
...
...
@@ -110,11 +113,11 @@ $ hub run deeplabv3p_xception65_humanseg --input_path test_image.jpg
</p>
<p
align=
'center'
>
         
ace2p分割结果展示
                
humanseg分割结果展示
   
         
ACE2P人体部件分割
                
HumanSeg人像分割
   
</p>
PaddleHub还提供图像分类、语义模型、视频分类、图像生成、图像分割、文本审核、关键点检测等主流模型,更多模型介绍,请前往
[
https://www.paddlepaddle.org.cn/hub
](
https://www.paddlepaddle.org.cn/hub
)
查看
PaddleHub还提供图像分类、语义模型、视频分类、图像生成、图像分割、文本审核、关键点检测等主流模型,更多模型介绍,请前往
[
预训练模型介绍
](
./docs/pretrained_models.md
)
或者PaddleHub官网
[
https://www.paddlepaddle.org.cn/hub
](
https://www.paddlepaddle.org.cn/hub
)
查看
### 易用的迁移学习
...
...
@@ -189,6 +192,5 @@ $ hub uninstall ernie
## 更新历史
PaddleHub v1.6.0已发布!
详情参考
[
更新历史
](
./RELEASE.md
)
PaddleHub v1.6 已发布!
更多升级详情参考
[
更新历史
](
./RELEASE.md
)
demo/image_classification/img_classifier_dygraph.py
0 → 100644
浏览文件 @
22c4494f
#coding:utf-8
import
argparse
import
os
import
numpy
as
np
import
paddlehub
as
hub
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.optimizer
import
AdamOptimizer
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--num_epoch"
,
type
=
int
,
default
=
1
,
help
=
"Number of epoches for fine-tuning."
)
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
"paddlehub_finetune_ckpt_dygraph"
,
help
=
"Path to save log data."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--log_interval"
,
type
=
int
,
default
=
10
,
help
=
"log interval."
)
parser
.
add_argument
(
"--save_interval"
,
type
=
int
,
default
=
10
,
help
=
"save interval."
)
# yapf: enable.
class
ResNet50
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
,
backbone
):
super
(
ResNet50
,
self
).
__init__
()
self
.
fc
=
Linear
(
input_dim
=
2048
,
output_dim
=
num_classes
)
self
.
backbone
=
backbone
def
forward
(
self
,
imgs
):
feature_map
=
self
.
backbone
(
imgs
)
feature_map
=
fluid
.
layers
.
reshape
(
feature_map
,
shape
=
[
-
1
,
2048
])
pred
=
self
.
fc
(
feature_map
)
return
fluid
.
layers
.
softmax
(
pred
)
def
finetune
(
args
):
with
fluid
.
dygraph
.
guard
():
resnet50_vd_10w
=
hub
.
Module
(
name
=
"resnet50_vd_10w"
)
dataset
=
hub
.
dataset
.
Flowers
()
resnet
=
ResNet50
(
num_classes
=
dataset
.
num_labels
,
backbone
=
resnet50_vd_10w
)
adam
=
AdamOptimizer
(
learning_rate
=
0.001
,
parameter_list
=
resnet
.
parameters
())
state_dict_path
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
'dygraph_state_dict'
)
if
os
.
path
.
exists
(
state_dict_path
+
'.pdparams'
):
state_dict
,
_
=
fluid
.
load_dygraph
(
state_dict_path
)
resnet
.
load_dict
(
state_dict
)
reader
=
hub
.
reader
.
ImageClassificationReader
(
image_width
=
resnet50_vd_10w
.
get_expected_image_width
(),
image_height
=
resnet50_vd_10w
.
get_expected_image_height
(),
images_mean
=
resnet50_vd_10w
.
get_pretrained_images_mean
(),
images_std
=
resnet50_vd_10w
.
get_pretrained_images_std
(),
dataset
=
dataset
)
train_reader
=
reader
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'train'
)
loss_sum
=
acc_sum
=
cnt
=
0
# 执行epoch_num次训练
for
epoch
in
range
(
args
.
num_epoch
):
# 读取训练数据进行训练
for
batch_id
,
data
in
enumerate
(
train_reader
()):
imgs
=
np
.
array
(
data
[
0
][
0
])
labels
=
np
.
array
(
data
[
0
][
1
])
pred
=
resnet
(
imgs
)
acc
=
fluid
.
layers
.
accuracy
(
pred
,
to_variable
(
labels
))
loss
=
fluid
.
layers
.
cross_entropy
(
pred
,
to_variable
(
labels
))
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
avg_loss
.
backward
()
# 参数更新
adam
.
minimize
(
avg_loss
)
loss_sum
+=
avg_loss
.
numpy
()
*
imgs
.
shape
[
0
]
acc_sum
+=
acc
.
numpy
()
*
imgs
.
shape
[
0
]
cnt
+=
imgs
.
shape
[
0
]
if
batch_id
%
args
.
log_interval
==
0
:
print
(
'epoch {}: loss {}, acc {}'
.
format
(
epoch
,
loss_sum
/
cnt
,
acc_sum
/
cnt
))
loss_sum
=
acc_sum
=
cnt
=
0
if
batch_id
%
args
.
save_interval
==
0
:
state_dict
=
resnet
.
state_dict
()
fluid
.
save_dygraph
(
state_dict
,
state_dict_path
)
if
__name__
==
"__main__"
:
args
=
parser
.
parse_args
()
finetune
(
args
)
demo/sequence_labeling/sequence_label_dygraph.py
0 → 100644
浏览文件 @
22c4494f
#coding:utf-8
import
argparse
import
os
import
numpy
as
np
import
paddlehub
as
hub
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.optimizer
import
AdamOptimizer
from
paddlehub.finetune.evaluate
import
chunk_eval
,
calculate_f1
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--num_epoch"
,
type
=
int
,
default
=
1
,
help
=
"Number of epoches for fine-tuning."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--log_interval"
,
type
=
int
,
default
=
10
,
help
=
"log interval."
)
parser
.
add_argument
(
"--save_interval"
,
type
=
int
,
default
=
10
,
help
=
"save interval."
)
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
"paddlehub_finetune_ckpt_dygraph"
,
help
=
"Path to save log data."
)
parser
.
add_argument
(
"--max_seq_len"
,
type
=
int
,
default
=
512
,
help
=
"Number of words of the longest seqence."
)
# yapf: enable.
class
TransformerSequenceLabelLayer
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
,
transformer
):
super
(
TransformerSequenceLabelLayer
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
transformer
=
transformer
self
.
fc
=
Linear
(
input_dim
=
768
,
output_dim
=
num_classes
)
def
forward
(
self
,
input_ids
,
position_ids
,
segment_ids
,
input_mask
):
result
=
self
.
transformer
(
input_ids
,
position_ids
,
segment_ids
,
input_mask
)
pred
=
self
.
fc
(
result
[
'sequence_output'
])
ret_infers
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
argmax
(
pred
,
axis
=
2
),
shape
=
[
-
1
,
1
])
pred
=
fluid
.
layers
.
reshape
(
pred
,
shape
=
[
-
1
,
self
.
num_classes
])
return
fluid
.
layers
.
softmax
(
pred
),
ret_infers
def
finetune
(
args
):
ernie
=
hub
.
Module
(
name
=
"ernie"
,
max_seq_len
=
args
.
max_seq_len
)
with
fluid
.
dygraph
.
guard
():
dataset
=
hub
.
dataset
.
MSRA_NER
()
ts
=
TransformerSequenceLabelLayer
(
num_classes
=
dataset
.
num_labels
,
transformer
=
ernie
)
adam
=
AdamOptimizer
(
learning_rate
=
1e-5
,
parameter_list
=
ts
.
parameters
())
state_dict_path
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
'dygraph_state_dict'
)
if
os
.
path
.
exists
(
state_dict_path
+
'.pdparams'
):
state_dict
,
_
=
fluid
.
load_dygraph
(
state_dict_path
)
ts
.
load_dict
(
state_dict
)
reader
=
hub
.
reader
.
SequenceLabelReader
(
dataset
=
dataset
,
vocab_path
=
ernie
.
get_vocab_path
(),
max_seq_len
=
args
.
max_seq_len
,
sp_model_path
=
ernie
.
get_spm_path
(),
word_dict_path
=
ernie
.
get_word_dict_path
())
train_reader
=
reader
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'train'
)
loss_sum
=
total_infer
=
total_label
=
total_correct
=
cnt
=
0
# 执行epoch_num次训练
for
epoch
in
range
(
args
.
num_epoch
):
# 读取训练数据进行训练
for
batch_id
,
data
in
enumerate
(
train_reader
()):
input_ids
=
np
.
array
(
data
[
0
][
0
]).
astype
(
np
.
int64
)
position_ids
=
np
.
array
(
data
[
0
][
1
]).
astype
(
np
.
int64
)
segment_ids
=
np
.
array
(
data
[
0
][
2
]).
astype
(
np
.
int64
)
input_mask
=
np
.
array
(
data
[
0
][
3
]).
astype
(
np
.
float32
)
labels
=
np
.
array
(
data
[
0
][
4
]).
astype
(
np
.
int64
).
reshape
(
-
1
,
1
)
seq_len
=
np
.
squeeze
(
np
.
array
(
data
[
0
][
5
]).
astype
(
np
.
int64
),
axis
=
1
)
pred
,
ret_infers
=
ts
(
input_ids
,
position_ids
,
segment_ids
,
input_mask
)
loss
=
fluid
.
layers
.
cross_entropy
(
pred
,
to_variable
(
labels
))
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
avg_loss
.
backward
()
# 参数更新
adam
.
minimize
(
avg_loss
)
loss_sum
+=
avg_loss
.
numpy
()
*
labels
.
shape
[
0
]
label_num
,
infer_num
,
correct_num
=
chunk_eval
(
labels
,
ret_infers
.
numpy
(),
seq_len
,
dataset
.
num_labels
,
1
)
cnt
+=
labels
.
shape
[
0
]
total_infer
+=
infer_num
total_label
+=
label_num
total_correct
+=
correct_num
if
batch_id
%
args
.
log_interval
==
0
:
precision
,
recall
,
f1
=
calculate_f1
(
total_label
,
total_infer
,
total_correct
)
print
(
'epoch {}: loss {}, f1 {} recall {} precision {}'
.
format
(
epoch
,
loss_sum
/
cnt
,
f1
,
recall
,
precision
))
loss_sum
=
total_infer
=
total_label
=
total_correct
=
cnt
=
0
if
batch_id
%
args
.
save_interval
==
0
:
state_dict
=
ts
.
state_dict
()
fluid
.
save_dygraph
(
state_dict
,
state_dict_path
)
if
__name__
==
"__main__"
:
args
=
parser
.
parse_args
()
finetune
(
args
)
demo/text_classification/finetuned_model_to_module/module.py
浏览文件 @
22c4494f
...
...
@@ -94,6 +94,7 @@ class ERNIETinyFinetuned(hub.Module):
config
=
config
,
metrics_choices
=
metrics_choices
)
@
serving
def
predict
(
self
,
data
,
return_result
=
False
,
accelerate_mode
=
True
):
"""
Get prediction results
...
...
@@ -102,7 +103,14 @@ class ERNIETinyFinetuned(hub.Module):
data
=
data
,
return_result
=
return_result
,
accelerate_mode
=
accelerate_mode
)
return
run_states
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
prediction
=
[]
for
batch_result
in
results
:
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
batch_result
=
batch_result
.
tolist
()
prediction
+=
batch_result
return
prediction
if
__name__
==
"__main__"
:
...
...
@@ -113,12 +121,6 @@ if __name__ == "__main__":
data
=
[[
"这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"
],
[
"交通方便;环境很好;服务态度很好 房间较小"
],
[
"19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"
]]
index
=
0
run_states
=
ernie_tiny
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
batch_result
in
results
:
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
for
result
in
batch_result
:
print
(
"%s
\t
predict=%s"
%
(
data
[
index
][
0
],
result
))
index
+=
1
predictions
=
ernie_tiny
.
predict
(
data
=
data
)
for
index
,
text
in
enumerate
(
data
):
print
(
"%s
\t
predict=%s"
%
(
data
[
index
][
0
],
predictions
[
index
]))
demo/text_classification/text_classifier_dygraph.py
0 → 100644
浏览文件 @
22c4494f
#coding:utf-8
import
argparse
import
os
import
numpy
as
np
import
paddlehub
as
hub
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Linear
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.optimizer
import
AdamOptimizer
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--num_epoch"
,
type
=
int
,
default
=
1
,
help
=
"Number of epoches for fine-tuning."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--log_interval"
,
type
=
int
,
default
=
10
,
help
=
"log interval."
)
parser
.
add_argument
(
"--save_interval"
,
type
=
int
,
default
=
10
,
help
=
"save interval."
)
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
"paddlehub_finetune_ckpt_dygraph"
,
help
=
"Path to save log data."
)
parser
.
add_argument
(
"--max_seq_len"
,
type
=
int
,
default
=
512
,
help
=
"Number of words of the longest seqence."
)
# yapf: enable.
class
TransformerClassifier
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
,
transformer
):
super
(
TransformerClassifier
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
transformer
=
transformer
self
.
fc
=
Linear
(
input_dim
=
768
,
output_dim
=
num_classes
)
def
forward
(
self
,
input_ids
,
position_ids
,
segment_ids
,
input_mask
):
result
=
self
.
transformer
(
input_ids
,
position_ids
,
segment_ids
,
input_mask
)
cls_feats
=
fluid
.
layers
.
dropout
(
result
[
'pooled_output'
],
dropout_prob
=
0.1
,
dropout_implementation
=
"upscale_in_train"
)
cls_feats
=
fluid
.
layers
.
reshape
(
cls_feats
,
shape
=
[
-
1
,
768
])
pred
=
self
.
fc
(
cls_feats
)
return
fluid
.
layers
.
softmax
(
pred
)
def
finetune
(
args
):
ernie
=
hub
.
Module
(
name
=
"ernie"
,
max_seq_len
=
args
.
max_seq_len
)
with
fluid
.
dygraph
.
guard
():
dataset
=
hub
.
dataset
.
ChnSentiCorp
()
tc
=
TransformerClassifier
(
num_classes
=
dataset
.
num_labels
,
transformer
=
ernie
)
adam
=
AdamOptimizer
(
learning_rate
=
1e-5
,
parameter_list
=
tc
.
parameters
())
state_dict_path
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
'dygraph_state_dict'
)
if
os
.
path
.
exists
(
state_dict_path
+
'.pdparams'
):
state_dict
,
_
=
fluid
.
load_dygraph
(
state_dict_path
)
tc
.
load_dict
(
state_dict
)
reader
=
hub
.
reader
.
ClassifyReader
(
dataset
=
dataset
,
vocab_path
=
ernie
.
get_vocab_path
(),
max_seq_len
=
args
.
max_seq_len
,
sp_model_path
=
ernie
.
get_spm_path
(),
word_dict_path
=
ernie
.
get_word_dict_path
())
train_reader
=
reader
.
data_generator
(
batch_size
=
args
.
batch_size
,
phase
=
'train'
)
loss_sum
=
acc_sum
=
cnt
=
0
# 执行epoch_num次训练
for
epoch
in
range
(
args
.
num_epoch
):
# 读取训练数据进行训练
for
batch_id
,
data
in
enumerate
(
train_reader
()):
input_ids
=
np
.
array
(
data
[
0
][
0
]).
astype
(
np
.
int64
)
position_ids
=
np
.
array
(
data
[
0
][
1
]).
astype
(
np
.
int64
)
segment_ids
=
np
.
array
(
data
[
0
][
2
]).
astype
(
np
.
int64
)
input_mask
=
np
.
array
(
data
[
0
][
3
]).
astype
(
np
.
float32
)
labels
=
np
.
array
(
data
[
0
][
4
]).
astype
(
np
.
int64
)
pred
=
tc
(
input_ids
,
position_ids
,
segment_ids
,
input_mask
)
acc
=
fluid
.
layers
.
accuracy
(
pred
,
to_variable
(
labels
))
loss
=
fluid
.
layers
.
cross_entropy
(
pred
,
to_variable
(
labels
))
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
avg_loss
.
backward
()
# 参数更新
adam
.
minimize
(
avg_loss
)
loss_sum
+=
avg_loss
.
numpy
()
*
labels
.
shape
[
0
]
acc_sum
+=
acc
.
numpy
()
*
labels
.
shape
[
0
]
cnt
+=
labels
.
shape
[
0
]
if
batch_id
%
args
.
log_interval
==
0
:
print
(
'epoch {}: loss {}, acc {}'
.
format
(
epoch
,
loss_sum
/
cnt
,
acc_sum
/
cnt
))
loss_sum
=
acc_sum
=
cnt
=
0
if
batch_id
%
args
.
save_interval
==
0
:
state_dict
=
tc
.
state_dict
()
fluid
.
save_dygraph
(
state_dict
,
state_dict_path
)
if
__name__
==
"__main__"
:
args
=
parser
.
parse_args
()
finetune
(
args
)
docs/pretrained_models.md
0 → 100644
浏览文件 @
22c4494f
# PaddleHub 预训练模型介绍
PaddlePaddle 提供了丰富的模型,使得用户可以采用模块化的方法解决各种学习问题。
*
如果是想了解具体预训练模型的使用和间接可以继续学习本课程,也可以参考
[
PaddleHub预训练模型库
](
https://www.paddlepaddle.org.cn/hublist
)
*
如果想了解更多模型组网网络结构源代码请参考
[
飞桨官方模型库
](
https://github.com/PaddlePaddle/models
)
### PaddleHub预训练模型
*
[
飞桨优势特色模型
](
#飞桨优势特色模型
)
*
[
图像
](
#图像
)
*
[
图像分类
](
#图像分类
)
*
[
目标检测
](
#目标检测
)
*
[
图像分割
](
#图像分割
)
*
[
关键点检测
](
#关键点检测
)
*
[
图像生成
](
#图像生成
)
*
[
文本
](
#文本
)
*
[
中文词法分析与词向量
](
#中文词法分析与词向量
)
*
[
情感分析
](
#情感分析
)
*
[
文本相似度计算
](
#文本相似度计算
)
*
[
语义表示
](
#语义表示
)
*
[
视频
](
#视频
)
## 百度飞桨独有优势特色模型
| 任务 |
**模型名称**
|
**Master模型推荐辞**
|
| ---------- | :----------------------------------------------------------- | ---------------------------------------------------------- |
| 目标检测 |
[
YOLOv3
](
https://www.paddlepaddle.org.cn/hubdetail?name=yolov3_darknet53_coco2017&en_category=ObjectDetection
)
| 实现精度相比原作者
**提高5.9 个绝对百分点**
,性能极致优化。 |
| 目标检测 |
[
人脸检测
](
https://www.paddlepaddle.org.cn/hubdetail?name=pyramidbox_lite_server&en_category=ObjectDetection
)
| 百度自研,18年3月WIDER Face 数据集
**冠军模型**
, |
| 目标检测 |
[
口罩人脸检测与识别
](
https://github.com/PaddlePaddle/PaddleDetection
)
| 业界
**首个开源口罩人脸检测与识别模型**
,引起广泛关注。 |
| 语义分割 |
[
HumanSeg
](
https://www.paddlepaddle.org.cn/hubdetail?name=deeplabv3p_xception65_humanseg&en_category=ImageSegmentation
)
| 百度
**自建数据集**
训练,人像分割效果卓越。 |
| 语义分割 |
[
ACE2P
](
https://www.paddlepaddle.org.cn/hubdetail?name=ace2p&en_category=ImageSegmentation
)
| CVPR2019 LIP挑战赛中
**满贯三冠王**
。人体解析任务必选。 |
| 语义分割 |
[
Pneumonia_CT_LKM_PP
](
https://www.paddlepaddle.org.cn/hubdetail?name=Pneumonia_CT_LKM_PP&en_category=ImageSegmentation
)
| 助力连心医疗开源
**业界首个**
肺炎CT影像分析模型 |
| GAN |
[
stylepro_artistic
](
https://www.paddlepaddle.org.cn/hubdetail?name=stylepro_artistic&en_category=GANs
)
| 百度自研风格迁移模型,趣味模型,
**推荐尝试**
|
| 词法分析 |
[
LAC
](
https://www.paddlepaddle.org.cn/hubdetail?name=lac&en_category=LexicalAnalysis
)
| 百度
**自研中文特色**
模型词法分析任务。 |
| 情感分析 |
[
Senta
](
https://www.paddlepaddle.org.cn/hubdetail?name=lac&en_category=LexicalAnalysis
)
| 百度自研情感分析模型,海量中文数据训练。 |
| 情绪识别 |
[
emotion_detection
](
https://www.paddlepaddle.org.cn/hubdetail?name=emotion_detection_textcnn&en_category=SentimentAnalysis
)
| 百度自研对话识别模型,海量中文数据训练。 |
| 文本相似度 |
[
simnet
](
https://www.paddlepaddle.org.cn/hubdetail?name=simnet_bow&en_category=SemanticModel
)
| 百度自研短文本相似度模型,海量中文数据训练。 |
| 文本审核 |
[
porn_detection
](
https://www.paddlepaddle.org.cn/hubdetail?name=porn_detection_gru&en_category=TextCensorship
)
| 百度自研色情文本审核模型,海量中文数据训练。 |
| 语义模型 |
[
ERNIE
](
https://www.paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel
)
|
**SOTA 语义模型,中文任务全面优于BERT**
。 |
| 图像分类 |
[
菜品识别
](
https://www.paddlepaddle.org.cn/hubdetail?name=resnet50_vd_dishes&en_category=ImageClassification
)
| 私有数据集训练,适合进一步菜品方向微调。 |
| 图像分类 |
[
动物识别
](
https://www.paddlepaddle.org.cn/hubdetail?name=resnet50_vd_animals&en_category=ImageClassification
)
| 私有数据集训练,适合进一步动物方向微调。 |
| | | |
| 目标检测 | 行人检测(即将开源) | 百度自研模型,海量私有数据集训练。 |
| 目标检测 | 行人检测(即将开源) | 百度自研模型,海量私有数据集训练。 |
| OCR | 中文OCR(即将开源) | 开源模型基础上性能优化,增加私有数据集训练。 |
| 语音合成 | WaveFlow(即将开源) | 百度自研模型,海量私有数据集训练。 |
## 图像
#### 图像分类
图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉中重要的基础问题,是物体检测、图像分割、物体跟踪、行为分析、人脸识别等其他高层视觉任务的基础,在许多领域都有着广泛的应用。如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。
**注:**
**如果你是资深开发者,那可以随意按需使用**
,
**假如你是新手,服务器端优先选择Resnet50,移动端优先选择MobileNetV2**
。
|
**模型名称**
|
**模型简介**
|
| - | - |
|
[
AlexNet
](
https://www.paddlepaddle.org.cn/hubdetail?name=alexnet_imagenet&en_category=ImageClassification
)
| 首次在 CNN 中成功的应用了 ReLU, Dropout 和 LRN,并使用 GPU 进行运算加速 |
|
[
VGG19
](
https://www.paddlepaddle.org.cn/hubdetail?name=vgg19_imagenet&en_category=ImageClassification
)
| 在 AlexNet 的基础上使用 3
*
3 小卷积核,增加网络深度,具有很好的泛化能力 |
|
[
GoogLeNet
](
https://github.com/PaddlePaddle/models/tree/release/1.7/PaddleCV/image_classification
)
| 在不增加计算负载的前提下增加了网络的深度和宽度,性能更加优越 |
|
[
ResNet50
](
https://www.paddlepaddle.org.cn/hubdetail?name=resnet_v2_50_imagenet&en_category=ImageClassification
)
| Residual Network,引入了新的残差结构,解决了随着网络加深,准确率下降的问题 |
|
[
Inceptionv4
](
https://www.paddlepaddle.org.cn/hubdetail?name=inception_v4_imagenet&en_category=ImageClassification
)
| 将 Inception 模块与 Residual Connection 进行结合,通过ResNet的结构极大地加速训练并获得性能的提升 |
|
[
MobileNetV2
](
https://www.paddlepaddle.org.cn/hubdetail?name=mobilenet_v2_imagenet&en_category=ImageClassification
)
| MobileNet结构的微调,直接在 thinner 的 bottleneck层上进行 skip learning 连接以及对 bottleneck layer 不进行 ReLu 非线性处理可取得更好的结果 |
|
[
se_resnext50
](
https://www.paddlepaddle.org.cn/hubdetail?name=se_resnext50_32x4d_imagenet&en_category=ImageClassification
)
| 在ResNeXt 基础、上加入了 SE(Sequeeze-and-Excitation) 模块,提高了识别准确率,在 ILSVRC 2017 的分类项目中取得了第一名 |
|
[
ShuffleNetV2
](
https://www.paddlepaddle.org.cn/hubdetail?name=shufflenet_v2_imagenet&en_category=ImageClassification
)
| ECCV2018,轻量级 CNN 网络,在速度和准确度之间做了很好地平衡。在同等复杂度下,比 ShuffleNet 和 MobileNetv2 更准确,更适合移动端以及无人车领域 |
|
[
efficientNetb7
](
https://www.paddlepaddle.org.cn/hubdetail?name=efficientnetb7_imagenet&en_category=ImageClassification
)
| 同时对模型的分辨率,通道数和深度进行缩放,用极少的参数就可以达到SOTA的精度。 |
|
[
xception71
](
https://www.paddlepaddle.org.cn/hubdetail?name=xception71_imagenet&en_category=ImageClassification
)
| 对inception-v3的改进,用深度可分离卷积代替普通卷积,降低参数量同时提高了精度。 |
|
[
dpn107
](
https://www.paddlepaddle.org.cn/hubdetail?name=dpn107_imagenet&en_category=ImageClassification
)
| 融合了densenet和resnext的特点。 |
|
[
DarkNet53
](
https://www.paddlepaddle.org.cn/hubdetail?name=darknet53_imagenet&en_category=ImageClassification
)
| 检测框架yolov3使用的backbone,在分类和检测任务上都有不错表现。 |
|
[
DenseNet161
](
https://www.paddlepaddle.org.cn/hubdetail?name=densenet161_imagenet&en_category=ImageClassification
)
| 提出了密集连接的网络结构,更加有利于信息流的传递。 |
|
[
ResNeXt152_vd
](
https://www.paddlepaddle.org.cn/hubdetail?name=resnext152_64x4d_imagenet&en_category=ImageClassification
)
| 提出了cardinatity的概念,用于作为模型复杂度的另外一个度量,有效地提升模型精度。 |
#### 目标检测
目标检测任务的目标是给定一张图像或是一个视频帧,让计算机找出其中所有目标的位置,并给出每个目标的具体类别。对于计算机而言,能够“看到”的是图像被编码之后的数字,但很难解图像或是视频帧中出现了人或是物体这样的高层语义概念,也就更加难以定位目标出现在图像中哪个区域。目标检测模型请参考目标检测库
[
PaddleDetection
](
https://github.com/PaddlePaddle/PaddleDetection
)
| 模型名称 | 模型简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
SSD
](
https://www.paddlepaddle.org.cn/hubdetail?name=ssd_mobilenet_v1_pascal&en_category=ObjectDetection
)
| 很好的继承了 MobileNet 预测速度快,易于部署的特点,能够很好的在多种设备上完成图像目标检测任务 |
|
[
Faster-RCNN
](
https://www.paddlepaddle.org.cn/hubdetail?name=faster_rcnn_coco2017&en_category=ObjectDetection
)
| 创造性地采用卷积网络自行产生建议框,并且和目标检测网络共享卷积网络,建议框数目减少,质量提高 |
|
[
YOLOv3
](
https://www.paddlepaddle.org.cn/hubdetail?name=yolov3_darknet53_coco2017&en_category=ObjectDetection
)
| 速度和精度均衡的目标检测网络,相比于原作者 darknet 中的 YOLO v3 实现,PaddlePaddle 实现增加了 mixup,label_smooth 等处理,精度 (mAP(0.50: 0.95)) 相比于原作者提高了 4.7 个绝对百分点,在此基础上加入 synchronize batch normalization, 最终精度相比原作者提高 5.9 个绝对百分点。 |
|
[
PyramidBox人脸检测
](
https://www.paddlepaddle.org.cn/hubdetail?name=pyramidbox_lite_server&en_category=ObjectDetection
)
|
**PyramidBox**
**模型是百度自主研发的人脸检测模型**
,利用上下文信息解决困难人脸的检测问题,网络表达能力高,鲁棒性强。于18年3月份在 WIDER Face 数据集上取得第一名 |
|
[
超轻量人脸检测
](
https://www.paddlepaddle.org.cn/hubdetail?name=ultra_light_fast_generic_face_detector_1mb_640&en_category=ObjectDetection
)
| Ultra-Light-Fast-Generic-Face-Detector-1MB是针对边缘计算设备或低算力设备(如用ARM推理)设计的实时超轻量级通用人脸检测模型,可以在低算力设备中如用ARM进行实时的通用场景的人脸检测推理。该PaddleHub Module的预训练数据集为WIDER FACE数据集,可支持预测,在预测时会将图片输入缩放为640
*
480。 |
|
[
人脸口罩检测
](
https://www.paddlepaddle.org.cn/hubdetail?name=pyramidbox_lite_server_mask&en_category=ObjectDetection
)
| 基于PyramidBox而研发的轻量级模型,对于光照、口罩遮挡、表情变化、尺度变化等常见问题具有很强的鲁棒性。基于WIDER FACE数据集和百度自采人脸数据集进行训练,支持预测,可用于检测人脸是否佩戴口罩。 |
#### 图像分割
图像语义分割顾名思义是将图像像素按照表达的语义含义的不同进行分组/分割,图像语义是指对图像内容的理解,例如,能够描绘出什么物体在哪里做了什么事情等,分割是指对图片中的每个像素点进行标注,标注属于哪一类别。近年来用在无人车驾驶技术中分割街景来避让行人和车辆、医疗影像分析中辅助诊断等。
图像语义分割模型请参考语义分割库
[
PaddleSeg
](
https://github.com/PaddlePaddle/PaddleSeg
)
| 模型名称 | 模型简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
deeplabv3
](
https://www.paddlepaddle.org.cn/hubdetail?name=deeplabv3p_xception65_humanseg&en_category=ImageSegmentation
)
|DeepLabv3+ 作者通过encoder-decoder进行多尺度信息的融合,同时保留了原来的空洞卷积和ASSP层, 其骨干网络使用了Xception模型,提高了语义分割的健壮性和运行速率,在 PASCAL VOC 2012 dataset取得新的state-of-art performance。本Module使用百度自建数据集进行训练,可用于人像分割,支持任意大小的图片输入。|
|
[
ACE2P
](
https://www.paddlepaddle.org.cn/hubdetail?name=ace2p&en_category=ImageSegmentation
)
| 人体解析(Human Parsing)是细粒度的语义分割任务,其旨在识别像素级别的人类图像的组成部分(例如,身体部位和服装)。ACE2P通过融合底层特征,全局上下文信息和边缘细节,端到端地训练学习人体解析任务。该结构针对Intersection over Union指标进行针对性的优化学习,提升准确率。以ACE2P单人人体解析网络为基础的解决方案在CVPR2019第三届LIP挑战赛中赢得了全部三个人体解析任务的第一名。该PaddleHub Module采用ResNet101作为骨干网络,接受输入图片大小为473x473x3。 |
|
[
Pneumonia_CT_LKM_PP
](
https://www.paddlepaddle.org.cn/hubdetail?name=Pneumonia_CT_LKM_PP&en_category=ImageSegmentation
)
| 肺炎CT影像分析模型(Pneumonia-CT-LKM-PP)可以高效地完成对患者CT影像的病灶检测识别、病灶轮廓勾画,通过一定的后处理代码,可以分析输出肺部病灶的数量、体积、病灶占比等全套定量指标。值得强调的是,该系统采用的深度学习算法模型充分训练了所收集到的高分辨率和低分辨率的CT影像数据,能极好地适应不同等级CT影像设备采集的检查数据,有望为医疗资源受限和医疗水平偏低的基层医院提供有效的肺炎辅助诊断工具。 |
#### 关键点检测
人体骨骼关键点检测 (Pose Estimation) 主要检测人体的一些关键点,如关节,五官等,通过关键点描述人体骨骼信息。人体骨骼关键点检测对于描述人体姿态,预测人体行为至关重要。是诸多计算机视觉任务的基础,例如动作分类,异常行为检测,以及自动驾驶等等。
| 模型名称 | 模型简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
Pose Estimation
](
https://www.paddlepaddle.org.cn/hubdetail?name=pose_resnet50_mpii&en_category=KeyPointDetection
)
| 人体骨骼关键点检测(Pose Estimation) 是计算机视觉的基础性算法之一,在诸多计算机视觉任务起到了基础性的作用,如行为识别、人物跟踪、步态识别等相关领域。具体应用主要集中在智能视频监控,病人监护系统,人机交互,虚拟现实,人体动画,智能家居,智能安防,运动员辅助训练等等。 该模型的论文《Simple Baselines for Human Pose Estimation and Tracking》由 MSRA 发表于 ECCV18,使用 MPII 数据集训练完成。 |
#### 图像生成
图像生成是指根据输入向量,生成目标图像。这里的输入向量可以是随机的噪声或用户指定的条件向量。具体的应用场景有:手写体生成、人脸合成、风格迁移、图像修复等。
[
gan
](
https://github.com/PaddlePaddle/models/tree/release/1.7/PaddleCV/gan
)
包含和图像生成相关的多个模型。
| 模型名称 | 模型简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
CycleGAN
](
https://www.paddlepaddle.org.cn/hubdetail?name=cyclegan_cityscapes&en_category=GANs
)
| 图像翻译,可以通过非成对的图片将某一类图片转换成另外一类图片,可用于风格迁移,支持图片从实景图转换为语义分割结果,也支持从语义分割结果转换为实景图。 |
|
[
StarGAN
](
https://www.paddlepaddle.org.cn/hubdetail?name=stargan_celeba&en_category=GANs
)
| 多领域属性迁移,引入辅助分类帮助单个判别器判断多个属性,可用于人脸属性转换。该 PaddleHub Module 使用 Celeba 数据集训练完成,目前支持 "Black_Hair", "Blond_Hair", "Brown_Hair", "Female", "Male", "Aged" 这六种人脸属性转换。 |
|
[
AttGAN
](
https://www.paddlepaddle.org.cn/hubdetail?name=attgan_celeba&en_category=GANs
)
| 利用分类损失和重构损失来保证改变特定的属性,可用于13种人脸特定属性转换。 |
|
[
STGAN
](
https://www.paddlepaddle.org.cn/hubdetail?name=stgan_celeba&en_category=GANs
)
| 人脸特定属性转换,只输入有变化的标签,引入 GRU 结构,更好的选择变化的属性,支持13种属性转换。 |
## 文本
PaddleNLP 是基于 PaddlePaddle 深度学习框架开发的自然语言处理 (NLP) 工具,算法,模型和数据的开源项目。百度在 NLP 领域十几年的深厚积淀为 PaddleNLP 提供了强大的核心动力。
#### 中文词法分析与词向量
| 模型名称 | 简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
LAC 中文词法分析
](
https://www.paddlepaddle.org.cn/hubdetail?name=lac&en_category=LexicalAnalysis
)
| Lexical Analysis of Chinese,简称LAC,是百度自主研发中文特色模型词法分析任务,集成了中文分词、词性标注和命名实体识别任务。输入是一个字符串,而输出是句子中的词边界和词性、实体类别。 |
|
[
word2vec词向量
](
https://www.paddlepaddle.org.cn/hubdetail?name=word2vec_skipgram&en_category=SemanticModel
)
| Word2vec是常用的词嵌入(word embedding)模型。该PaddleHub Module基于Skip-gram模型,在海量百度搜索数据集下预训练得到中文单词预训练词嵌入。其支持Fine-tune。Word2vec的预训练数据集的词汇表大小为1700249,word embedding维度为128。 |
#### 情感分析
| 模型名称 | 简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
Senta
](
https://www.paddlepaddle.org.cn/hubdetail?name=lac&en_category=LexicalAnalysis
)
| 情感倾向分析(Sentiment Classification,简称Senta)针对带有主观描述的中文文本,可自动判断该文本的情感极性类别并给出相应的置信度,能够帮助企业理解用户消费习惯、分析热点话题和危机舆情监控,为企业提供有利的决策支持。该模型基于一个双向LSTM结构,情感类型分为积极、消极。该PaddleHub Module支持预测和Fine-tune。 |
|
[
emotion_detection
](
https://www.paddlepaddle.org.cn/hubdetail?name=emotion_detection_textcnn&en_category=SentimentAnalysis
)
| 对话情绪识别(Emotion Detection,简称EmoTect)专注于识别智能对话场景中用户的情绪,针对智能对话场景中的用户文本,自动判断该文本的情绪类别并给出相应的置信度,情绪类型分为积极、消极、中性。该模型基于TextCNN(多卷积核CNN模型),能够更好地捕捉句子局部相关性。该PaddleHub Module预训练数据集为百度自建数据集,支持预测和Fine-tune。 |
#### 文本相似度计算
[
SimNet
](
https://github.com/PaddlePaddle/models/tree/release/1.7/PaddleNLP/similarity_net
)
(
Similarity
Net) 是一个计算短文本相似度的框架,主要包括 BOW、CNN、RNN、MMDNN 等核心网络结构形式。SimNet 框架在百度各产品上广泛应用,提供语义相似度计算训练和预测框架,适用于信息检索、新闻推荐、智能客服等多个应用场景,帮助企业解决语义匹配问题。
| 模型 | 简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
simnet_bow
](
https://www.paddlepaddle.org.cn/hubdetail?name=simnet_bow&en_category=SemanticModel
)
| SimNet是一个计算短文本相似度的模型,可以根据用户输入的两个文本,计算出相似度得分。该PaddleHub Module基于百度海量搜索数据进行训练,支持命令行和Python接口进行预测 |
#### 文本审核
文本审核也是NLP方向的一个常用任务,可以广泛应用在各种信息分发平台、论坛、讨论区的文本审核中。
| 模型 | 简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
porn_detection_gru
](
https://www.paddlepaddle.org.cn/hubdetail?name=porn_detection_gru&en_category=TextCensorship
)
| 色情检测模型可自动判别文本是否涉黄并给出相应的置信度,对文本中的色情描述、低俗交友、污秽文爱进行识别。porn_detection_gru采用GRU网络结构并按字粒度进行切词。该模型最大句子长度为256字,仅支持预测。 |
#### 语义表示
[
PaddleLARK
](
https://github.com/PaddlePaddle/models/tree/release/1.7/PaddleNLP/pretrain_language_models
)
通过在大规模语料上训练得到的通用的语义表示模型,可以助益其他自然语言处理任务,是通用预训练 + 特定任务精调范式的体现。PaddleLARK 集成了 ELMO,BERT,ERNIE 1.0,ERNIE 2.0,XLNet 等热门中英文预训练模型。
| 模型 | 简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
ERNIE
](
https://www.paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel
)
| ERNIE通过建模海量数据中的词、实体及实体关系,学习真实世界的语义知识。相较于BERT学习原始语言信号,ERNIE直接对先验语义知识单元进行建模,增强了模型语义表示能力,以Transformer为网络基本组件,以Masked Bi-Language Model和Next Sentence Prediction为训练目标,通过预训练得到通用语义表示,再结合简单的输出层,应用到下游的NLP任务,在多个任务上取得了SOTA的结果。其可用于文本分类、序列标注、阅读理解等任务。预训练数据集为百科类、资讯类、论坛对话类数据等中文语料。该PaddleHub Module可支持Fine-tune。 |
|
[
BERT
](
https://www.paddlepaddle.org.cn/hubdetail?name=bert_multi_uncased_L-12_H-768_A-12&en_category=SemanticModel
)
| 一个迁移能力很强的通用语义表示模型, 以 Transformer 为网络基本组件,以双向 Masked Language Model和 Next Sentence Prediction 为训练目标,通过预训练得到通用语义表示,再结合简单的输出层,应用到下游的 NLP 任务,在多个任务上取得了 SOTA 的结果。 |
|
[
RoBERTa
](
https://www.paddlepaddle.org.cn/hubdetail?name=rbtl3&en_category=SemanticModel
)
| RoBERTa (a Robustly Optimized BERT Pretraining Approach) 是BERT通用语义表示模型的一个优化版,它在BERT模型的基础上提出了Dynamic Masking方法、去除了Next Sentence Prediction目标,同时在更多的数据上采用更大的batch size训练更长的时间,在多个任务中做到了SOTA。rbtl3以roberta_wwm_ext_chinese_L-24_H-1024_A-16模型参数初始化前三层Transformer以及词向量层并在此基础上继续训练了1M步,在仅损失少量效果的情况下大幅减少参数量,得到推断速度的进一步提升。当该PaddleHub Module用于Fine-tune时,其输入是单文本(如Fine-tune的任务为情感分类等)或文本对(如Fine-tune任务为文本语义相似度匹配等),可用于文本分类、序列标注、阅读理解等任务。 |
|
[
chinese-bert
](
https://www.paddlepaddle.org.cn/hubdetail?name=chinese-bert-wwm&en_category=SemanticModel
)
| chinese_bert_wwm是支持中文的BERT模型,它采用全词遮罩(Whole Word Masking)技术,考虑到了中文分词问题。预训练数据集为中文维基百科。该PaddleHub Module只支持Fine-tune。当该PaddleHub Module用于Fine-tune时,其输入是单文本(如Fine-tune的任务为情感分类等)或文本对(如Fine-tune任务为文本语义相似度匹配等),可用于文本分类、序列标注、阅读理解等任务。 |
## 视频
视频数据包含语音、图像等多种信息,因此理解视频任务不仅需要处理语音和图像,还需要提取视频帧时间序列中的上下文信息。视频分类模型提供了提取全局时序特征的方法,主要方式有卷积神经网络 (C3D, I3D, C2D等),神经网络和传统图像算法结合 (VLAD 等),循环神经网络等建模方法。视频动作定位模型需要同时识别视频动作的类别和起止时间点,通常采用类似于图像目标检测中的算法在时间维度上进行建模。
| 模型名称 | 模型简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
[
TSN
](
https://www.paddlepaddle.org.cn/hubdetail?name=tsn_kinetics400&en_category=VideoClassification
)
| TSN(Temporal Segment Network)是视频分类领域经典的基于2D-CNN的解决方案。该方法主要解决视频的长时间行为判断问题,通过稀疏采样视频帧的方式代替稠密采样,既能捕获视频全局信息,也能去除冗余,降低计算量。最终将每帧特征平均融合后得到视频的整体特征,并用于分类。TSN的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。该PaddleHub Module可支持预测。 |
|
[
Non-Local
](
https://www.paddlepaddle.org.cn/hubdetail?name=tsn_kinetics400&en_category=VideoClassification
)
| Non-local Neural Networks是由Xiaolong Wang等研究者在2017年提出的模型,主要特点是通过引入Non-local操作来描述距离较远的像素点之间的关联关系。其借助于传统计算机视觉中的non-local mean的思想,并将该思想扩展到神经网络中,通过定义输出位置和所有输入位置之间的关联函数,建立全局关联特性。Non-local模型的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。该PaddleHub Module可支持预测。 |
|
[
StNet
](
https://www.paddlepaddle.org.cn/hubdetail?name=stnet_kinetics400&en_category=VideoClassification
)
| StNet模型框架为ActivityNet Kinetics Challenge 2018中夺冠的基础网络框架,是基于ResNet50实现的。该模型提出super-image的概念,在super-image上进行2D卷积,建模视频中局部时空相关性。另外通过temporal modeling block建模视频的全局时空依赖,最后用一个temporal Xception block对抽取的特征序列进行长时序建模。StNet的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。该PaddleHub Module可支持预测。 |
|
[
TSM
](
https://www.paddlepaddle.org.cn/hubdetail?name=tsm_kinetics400&en_category=VideoClassification
)
| TSM(Temporal Shift Module)是由MIT和IBM Watson AI Lab的JiLin,ChuangGan和SongHan等人提出的通过时间位移来提高网络视频理解能力的模块。TSM的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。该PaddleHub Module可支持预测。 |
docs/reference/config.md
浏览文件 @
22c4494f
...
...
@@ -8,8 +8,8 @@
hub
.
RunConfig
(
log_interval
=
10
,
eval_interval
=
100
,
use_pyreader
=
Fals
e
,
use_data_parallel
=
Fals
e
,
use_pyreader
=
Tru
e
,
use_data_parallel
=
Tru
e
,
save_ckpt_interval
=
None
,
use_cuda
=
False
,
checkpoint_dir
=
None
,
...
...
@@ -22,8 +22,8 @@ hub.RunConfig(
*
`log_interval`
: 打印训练日志的周期,默认为10。
*
`eval_interval`
: 进行评估的周期,默认为100。
*
`use_pyreader`
: 是否使用pyreader,默认
Fals
e。
*
`use_data_parallel`
: 是否使用并行计算,默认
Fals
e。打开该功能依赖nccl库。
*
`use_pyreader`
: 是否使用pyreader,默认
Tru
e。
*
`use_data_parallel`
: 是否使用并行计算,默认
Tru
e。打开该功能依赖nccl库。
*
`save_ckpt_interval`
: 保存checkpoint的周期,默认为None。
*
`use_cuda`
: 是否使用GPU训练和评估,默认为False。
*
`checkpoint_dir`
: checkpoint的保存目录,默认为None,此时会在工作目录下根据时间戳生成一个临时目录。
...
...
docs/reference/task/base_task.md
浏览文件 @
22c4494f
...
...
@@ -169,15 +169,6 @@ import paddlehub as hub
task
.
predict
()
```
## Func `predict`
根据config配置进行predict
**示例**
```
python
import
paddlehub
as
hub
...
task
.
predict
()
```
## Property `is_train_phase`
判断是否处于训练阶段
...
...
docs/tutorial/define_task_example.md
0 → 100644
浏览文件 @
22c4494f
# 如何修改Task中的模型网络
在应用中,用户需要更换迁移网络结构以调整模型在数据集上的性能。根据
[
如何自定义Task
](
./how_to_define_task.md
)
,本教程展示如何修改Task中的默认网络。
以序列标注任务为例,本教程展示如何修改默认网络结构。SequenceLabelTask提供了两种网络选择,一种是FC网络,一种是FC+CRF网络。
此时如果想在这基础之上,添加LSTM网络,组成BiLSTM+CRF的一种序列标注任务常用网络结构。
此时,需要定义一个Task,继承自SequenceLabelTask,并改写其中build_net()方法。
下方代码示例写了一个BiLSTM+CRF的网络。代码如下:
```
python
class
SequenceLabelTask_BiLSTMCRF
(
SequenceLabelTask
):
def
_build_net
(
self
):
"""
自定义序列标注迁移网络结构BiLSTM+CRF
"""
self
.
seq_len
=
fluid
.
layers
.
data
(
name
=
"seq_len"
,
shape
=
[
1
],
dtype
=
'int64'
,
lod_level
=
0
)
if
version_compare
(
paddle
.
__version__
,
"1.6"
):
self
.
seq_len_used
=
fluid
.
layers
.
squeeze
(
self
.
seq_len
,
axes
=
[
1
])
else
:
self
.
seq_len_used
=
self
.
seq_len
if
self
.
add_crf
:
# 迁移网络为BiLSTM+CRF
# 去padding
unpad_feature
=
fluid
.
layers
.
sequence_unpad
(
self
.
feature
,
length
=
self
.
seq_len_used
)
# bilstm层
hid_dim
=
128
fc0
=
fluid
.
layers
.
fc
(
input
=
unpad_feature
,
size
=
hid_dim
*
4
)
rfc0
=
fluid
.
layers
.
fc
(
input
=
unpad_feature
,
size
=
hid_dim
*
4
)
lstm_h
,
c
=
fluid
.
layers
.
dynamic_lstm
(
input
=
fc0
,
size
=
hid_dim
*
4
,
is_reverse
=
False
)
rlstm_h
,
c
=
fluid
.
layers
.
dynamic_lstm
(
input
=
rfc0
,
size
=
hid_dim
*
4
,
is_reverse
=
True
)
# 拼接lstm
lstm_concat
=
fluid
.
layers
.
concat
(
input
=
[
lstm_h
,
rlstm_h
],
axis
=
1
)
self
.
emission
=
fluid
.
layers
.
fc
(
size
=
self
.
num_classes
,
input
=
lstm_concat
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Uniform
(
low
=-
0.1
,
high
=
0.1
),
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
1e-4
)))
size
=
self
.
emission
.
shape
[
1
]
fluid
.
layers
.
create_parameter
(
shape
=
[
size
+
2
,
size
],
dtype
=
self
.
emission
.
dtype
,
name
=
'crfw'
)
# CRF层
self
.
ret_infers
=
fluid
.
layers
.
crf_decoding
(
input
=
self
.
emission
,
param_attr
=
fluid
.
ParamAttr
(
name
=
'crfw'
))
ret_infers
=
fluid
.
layers
.
assign
(
self
.
ret_infers
)
# 返回预测值,list类型
return
[
ret_infers
]
else
:
# 迁移网络为FC
self
.
logits
=
fluid
.
layers
.
fc
(
input
=
self
.
feature
,
size
=
self
.
num_classes
,
num_flatten_dims
=
2
,
param_attr
=
fluid
.
ParamAttr
(
name
=
"cls_seq_label_out_w"
,
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
scale
=
0.02
)),
bias_attr
=
fluid
.
ParamAttr
(
name
=
"cls_seq_label_out_b"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.
)))
self
.
ret_infers
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
argmax
(
self
.
logits
,
axis
=
2
),
shape
=
[
-
1
,
1
])
logits
=
self
.
logits
logits
=
fluid
.
layers
.
flatten
(
logits
,
axis
=
2
)
logits
=
fluid
.
layers
.
softmax
(
logits
)
self
.
num_labels
=
logits
.
shape
[
1
]
# 返回预测值,list类型
return
[
logits
]
```
以上代码通过继承PaddleHub已经内置的Task,改写其中_build_net方法即可实现自定义迁移网络结构。
docs/tutorial/finetuned_model_to_module.md
浏览文件 @
22c4494f
...
...
@@ -148,7 +148,9 @@ def _initialize(self,
初始化过程即为Fine-tune时创建Task的过程。
**NOTE:**
执行类的初始化不能使用默认的__init__接口,而是应该重载实现_initialize接口。对象默认内置了directory属性,可以直接获取到Module所在路径
**NOTE:**
1.
执行类的初始化不能使用默认的__init__接口,而是应该重载实现_initialize接口。对象默认内置了directory属性,可以直接获取到Module所在路径。
2.
使用Fine-tune保存的模型预测时,无需加载数据集Dataset,即Reader中的dataset参数可为None。
#### step 3_4. 完善预测逻辑
```
python
...
...
@@ -160,7 +162,14 @@ def predict(self, data, return_result=False, accelerate_mode=True):
data
=
data
,
return_result
=
return_result
,
accelerate_mode
=
accelerate_mode
)
return
run_states
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
prediction
=
[]
for
batch_result
in
results
:
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
batch_result
=
batch_result
.
tolist
()
prediction
+=
batch_result
return
prediction
```
#### step 3_5. 支持serving调用
...
...
@@ -179,7 +188,14 @@ def predict(self, data, return_result=False, accelerate_mode=True):
data
=
data
,
return_result
=
return_result
,
accelerate_mode
=
accelerate_mode
)
return
run_states
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
prediction
=
[]
for
batch_result
in
results
:
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
batch_result
=
batch_result
.
tolist
()
prediction
+=
batch_result
return
prediction
```
### 完整代码
...
...
@@ -214,15 +230,9 @@ ernie_tiny = hub.Module(name="ernie_tiny_finetuned")
data
=
[[
"这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"
],
[
"交通方便;环境很好;服务态度很好 房间较小"
],
[
"19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"
]]
index
=
0
run_states
=
ernie_tiny
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
batch_result
in
results
:
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
for
result
in
batch_result
:
print
(
"%s
\t
predict=%s"
%
(
data
[
index
][
0
],
result
))
index
+=
1
predictions
=
ernie_tiny
.
predict
(
data
=
data
)
for
index
,
text
in
enumerate
(
data
):
print
(
"%s
\t
predict=%s"
%
(
data
[
index
][
0
],
predictions
[
index
]))
```
### 调用方法2
...
...
@@ -238,15 +248,9 @@ ernie_tiny_finetuned = hub.Module(directory="finetuned_model_to_module/")
data
=
[[
"这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"
],
[
"交通方便;环境很好;服务态度很好 房间较小"
],
[
"19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"
]]
index
=
0
run_states
=
ernie_tiny
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
batch_result
in
results
:
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
for
result
in
batch_result
:
print
(
"%s
\t
predict=%s"
%
(
data
[
index
][
0
],
result
))
index
+=
1
predictions
=
ernie_tiny
.
predict
(
data
=
data
)
for
index
,
text
in
enumerate
(
data
):
print
(
"%s
\t
predict=%s"
%
(
data
[
index
][
0
],
predictions
[
index
]))
```
### 调用方法3
...
...
@@ -263,13 +267,42 @@ import numpy as np
data
=
[[
"这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"
],
[
"交通方便;环境很好;服务态度很好 房间较小"
],
[
"19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"
]]
run_states
=
ERNIETinyFinetuned
.
predict
(
data
=
data
)
index
=
0
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
batch_result
in
results
:
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
for
result
in
batch_result
:
print
(
"%s
\t
predict=%s"
%
(
data
[
index
][
0
],
result
))
index
+=
1
predictions
=
ERNIETinyFinetuned
.
predict
(
data
=
data
)
for
index
,
text
in
enumerate
(
data
):
print
(
"%s
\t
predict=%s"
%
(
data
[
index
][
0
],
predictions
[
index
]))
```
### PaddleHub Serving调用方法
**第一步:启动预测服务**
```
shell
hub serving start
-m
ernie_tiny_finetuned
```
**第二步:发送请求,获取预测结果**
通过如下脚本既可以发送请求:
```
python
# coding: utf8
import
requests
import
json
# 待预测文本
texts
=
[[
"这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"
],
[
"交通方便;环境很好;服务态度很好 房间较小"
],
[
"19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"
]]
# key为'data', 对应着预测接口predict的参数data
data
=
{
'data'
:
texts
}
# 指定模型为ernie_tiny_finetuned并发送post请求,且请求的headers为application/json方式
url
=
"http://127.0.0.1:8866/predict/ernie_tiny_finetuned"
headers
=
{
"Content-Type"
:
"application/json"
}
r
=
requests
.
post
(
url
=
url
,
headers
=
headers
,
data
=
json
.
dumps
(
data
))
# 打印预测结果
print
(
json
.
dumps
(
r
.
json
(),
indent
=
4
,
ensure_ascii
=
False
))
```
关与PaddleHub Serving更多信息参见
[
Hub Serving教程
](
../../docs/tutorial/serving.md
)
以及
[
Demo
](
../../demo/serving
)
docs/tutorial/how_to_load_data.md
浏览文件 @
22c4494f
...
...
@@ -22,6 +22,7 @@
如果您有两个输入文本text_a、text_b,则第一列为第一个输入文本text_a, 第二列应为第二个输入文本text_b,第三列文本类别label。列与列之间以Tab键分隔。数据集第一行为
`text_a text_b label`
(中间以Tab键分隔)。
```
text
text_a label
15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错 1
...
...
@@ -36,6 +37,7 @@ text_a label
*
数据集文件编码格式建议为utf8格式。
*
如果相应的数据集文件没有上述的列说明,如train.tsv文件没有第一行的
`text_a label`
,则train_file_with_header=False。
*
如果您还有预测数据(没有文本类别),可以将预测数据存放在predict.tsv文件,文件格式和train.tsv类似。去掉label一列即可。
*
分类任务中,数据集的label必须从0开始计数
```
python
...
...
@@ -117,6 +119,7 @@ dog
*
训练/验证/测试集的数据列表文件中的图片路径需要相对于dataset_dir的相对路径,例如图片的实际位置为
`/test/data/dog/dog1.jpg`
。base_path为
`/test/data`
,则文件中填写的路径应该为
`dog/dog1.jpg`
。
*
如果您还有预测数据(没有文本类别),可以将预测数据存放在predict_list.txt文件,文件格式和train_list.txt类似。去掉label一列即可
*
如果您的数据集类别较少,可以不用定义label_list.txt,可以选择定义label_list=["数据集所有类别"]。
*
分类任务中,数据集的label必须从0开始计数
```
python
from
paddlehub.dataset.base_cv_dataset
import
BaseCVDataset
...
...
hub_module/scripts/configs/faster_rcnn_resnet50_fpn_venus.yml
浏览文件 @
22c4494f
name
:
faster_rcnn_resnet50_fpn_venus
dir
:
"
modules/image/object_detection/faster_rcnn_resnet50_fpn_venus"
#
resources:
#
-
# url: https://paddlehub.bj.bcebos.com/model/cv/faster_rcnn_resnet50_fpn
_model.tar.gz
#
dest: faster_rcnn_resnet50_fpn_model
#
uncompress: True
resources
:
-
url
:
https://paddlehub.bj.bcebos.com/model/cv/faster_rcnn_resnet50_fpn_venus
_model.tar.gz
dest
:
faster_rcnn_resnet50_fpn_model
uncompress
:
True
paddlehub/__init__.py
浏览文件 @
22c4494f
...
...
@@ -39,6 +39,7 @@ from .common.logger import logger
from
.common.paddle_helper
import
connect_program
from
.common.hub_server
import
HubServer
from
.common.hub_server
import
server_check
from
.common.downloader
import
download
,
ResourceNotFoundError
,
ServerConnectionError
from
.module.module
import
Module
from
.module.base_processor
import
BaseProcessor
...
...
paddlehub/common/downloader.py
浏览文件 @
22c4494f
#coding:utf-8
#
coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the
"License"
# Licensed under the Apache License, Version 2.0 (the
'License'
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an
"AS IS"
BASIS,
# distributed under the License is distributed on an
'AS IS'
BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...
...
@@ -28,6 +28,8 @@ import tarfile
from
paddlehub.common
import
utils
from
paddlehub.common.logger
import
logger
from
paddlehub.common
import
tmp_dir
import
paddlehub
as
hub
__all__
=
[
'Downloader'
,
'progress'
]
FLUSH_INTERVAL
=
0.1
...
...
@@ -38,10 +40,10 @@ lasttime = time.time()
def
progress
(
str
,
end
=
False
):
global
lasttime
if
end
:
str
+=
"
\n
"
str
+=
'
\n
'
lasttime
=
0
if
time
.
time
()
-
lasttime
>=
FLUSH_INTERVAL
:
sys
.
stdout
.
write
(
"
\r
%s"
%
str
)
sys
.
stdout
.
write
(
'
\r
%s'
%
str
)
lasttime
=
time
.
time
()
sys
.
stdout
.
flush
()
...
...
@@ -67,7 +69,7 @@ class Downloader(object):
if
retry_times
<
retry_limit
:
retry_times
+=
1
else
:
tips
=
"Cannot download {0} within retry limit {1}"
.
format
(
tips
=
'Cannot download {0} within retry limit {1}'
.
format
(
url
,
retry_limit
)
return
False
,
tips
,
None
r
=
requests
.
get
(
url
,
stream
=
True
)
...
...
@@ -82,19 +84,19 @@ class Downloader(object):
total_length
=
int
(
total_length
)
starttime
=
time
.
time
()
if
print_progress
:
print
(
"Downloading %s"
%
save_name
)
print
(
'Downloading %s'
%
save_name
)
for
data
in
r
.
iter_content
(
chunk_size
=
4096
):
dl
+=
len
(
data
)
f
.
write
(
data
)
if
print_progress
:
done
=
int
(
50
*
dl
/
total_length
)
progress
(
"[%-50s] %.2f%%"
%
'[%-50s] %.2f%%'
%
(
'='
*
done
,
float
(
dl
/
total_length
*
100
)))
if
print_progress
:
progress
(
"[%-50s] %.2f%%"
%
(
'='
*
50
,
100
),
end
=
True
)
progress
(
'[%-50s] %.2f%%'
%
(
'='
*
50
,
100
),
end
=
True
)
tips
=
"File %s download completed!"
%
(
file_name
)
tips
=
'File %s download completed!'
%
(
file_name
)
return
True
,
tips
,
file_name
def
uncompress
(
self
,
...
...
@@ -104,24 +106,25 @@ class Downloader(object):
print_progress
=
False
):
dirname
=
os
.
path
.
dirname
(
file
)
if
dirname
is
None
else
dirname
if
print_progress
:
print
(
"Uncompress %s"
%
file
)
with
tarfile
.
open
(
file
,
"r:gz"
)
as
tar
:
print
(
'Uncompress %s'
%
file
)
with
tarfile
.
open
(
file
,
'r:*'
)
as
tar
:
file_names
=
tar
.
getnames
()
size
=
len
(
file_names
)
-
1
module_dir
=
os
.
path
.
join
(
dirname
,
file_names
[
0
])
for
index
,
file_name
in
enumerate
(
file_names
):
if
print_progress
:
done
=
int
(
50
*
float
(
index
)
/
size
)
progress
(
"[%-50s] %.2f%%"
%
(
'='
*
done
,
progress
(
'[%-50s] %.2f%%'
%
(
'='
*
done
,
float
(
index
/
size
*
100
)))
tar
.
extract
(
file_name
,
dirname
)
if
print_progress
:
progress
(
"[%-50s] %.2f%%"
%
(
'='
*
50
,
100
),
end
=
True
)
progress
(
'[%-50s] %.2f%%'
%
(
'='
*
50
,
100
),
end
=
True
)
if
delete_file
:
os
.
remove
(
file
)
return
True
,
"File %s uncompress completed!"
%
file
,
module_dir
return
True
,
'File %s uncompress completed!'
%
file
,
module_dir
def
download_file_and_uncompress
(
self
,
url
,
...
...
@@ -147,8 +150,62 @@ class Downloader(object):
if
save_name
:
save_name
=
os
.
path
.
join
(
save_path
,
save_name
)
shutil
.
move
(
file
,
save_name
)
return
result
,
"%s
\n
%s"
%
(
tips_1
,
tips_2
),
save_name
return
result
,
"%s
\n
%s"
%
(
tips_1
,
tips_2
),
file
return
result
,
'%s
\n
%s'
%
(
tips_1
,
tips_2
),
save_name
return
result
,
'%s
\n
%s'
%
(
tips_1
,
tips_2
),
file
default_downloader
=
Downloader
()
class
ResourceNotFoundError
(
Exception
):
def
__init__
(
self
,
name
,
version
=
None
):
self
.
name
=
name
self
.
version
=
version
def
__str__
(
self
):
if
self
.
version
:
tips
=
'No resource named {} was found'
.
format
(
self
.
name
)
else
:
tips
=
'No resource named {}-{} was found'
.
format
(
self
.
name
,
self
.
version
)
return
tips
class
ServerConnectionError
(
Exception
):
def
__str__
(
self
):
tips
=
'Can
\'
t connect to Hub Server:{}'
.
format
(
hub
.
HubServer
().
server_url
[
0
])
return
tips
def
download
(
name
,
save_path
,
version
=
None
,
decompress
=
True
,
resource_type
=
'Model'
,
extra
=
None
):
file
=
os
.
path
.
join
(
save_path
,
name
)
file
=
os
.
path
.
realpath
(
file
)
if
os
.
path
.
exists
(
file
):
return
if
not
hub
.
HubServer
().
_server_check
():
raise
ServerConnectionError
search_result
=
hub
.
HubServer
().
get_resource_url
(
name
,
resource_type
=
resource_type
,
version
=
version
,
extra
=
extra
)
if
not
search_result
:
raise
ResourceNotFoundError
(
name
,
version
)
url
=
search_result
[
'url'
]
with
tmp_dir
()
as
_dir
:
if
not
os
.
path
.
exists
(
save_path
):
os
.
makedirs
(
save_path
)
_
,
_
,
savefile
=
default_downloader
.
download_file
(
url
=
url
,
save_path
=
_dir
,
print_progress
=
True
)
if
tarfile
.
is_tarfile
(
savefile
)
and
decompress
:
_
,
_
,
savefile
=
default_downloader
.
uncompress
(
file
=
savefile
,
print_progress
=
True
)
shutil
.
move
(
savefile
,
file
)
paddlehub/common/hub_server.py
浏览文件 @
22c4494f
...
...
@@ -46,7 +46,8 @@ class HubServer(object):
config_file_path
=
os
.
path
.
join
(
CONF_HOME
,
'config.json'
)
if
not
os
.
path
.
exists
(
CONF_HOME
):
utils
.
mkdir
(
CONF_HOME
)
if
not
os
.
path
.
exists
(
config_file_path
):
if
not
os
.
path
.
exists
(
config_file_path
)
or
0
==
os
.
path
.
getsize
(
config_file_path
):
with
open
(
config_file_path
,
'w+'
)
as
fp
:
lock
.
flock
(
fp
,
lock
.
LOCK_EX
)
fp
.
write
(
json
.
dumps
(
default_server_config
))
...
...
paddlehub/common/logger.py
浏览文件 @
22c4494f
...
...
@@ -62,7 +62,8 @@ class Logger(object):
self
.
logger
.
setLevel
(
logging
.
DEBUG
)
self
.
logger
.
propagate
=
False
if
os
.
path
.
exists
(
os
.
path
.
join
(
CONF_HOME
,
"config.json"
)):
config_path
=
os
.
path
.
join
(
CONF_HOME
,
"config.json"
)
if
os
.
path
.
exists
(
config_path
)
and
0
<
os
.
path
.
getsize
(
config_path
):
with
open
(
os
.
path
.
join
(
CONF_HOME
,
"config.json"
),
"r"
)
as
fp
:
level
=
json
.
load
(
fp
).
get
(
"log_level"
,
"DEBUG"
)
self
.
logLevel
=
level
...
...
paddlehub/common/paddle_helper.py
浏览文件 @
22c4494f
...
...
@@ -19,10 +19,11 @@ from __future__ import print_function
import
copy
import
paddle
import
paddle.fluid
as
fluid
from
paddlehub.module
import
module_desc_pb2
from
paddlehub.common.utils
import
from_pyobj_to_module_attr
,
from_module_attr_to_pyobj
from
paddlehub.common.utils
import
from_pyobj_to_module_attr
,
from_module_attr_to_pyobj
,
version_compare
from
paddlehub.common.logger
import
logger
dtype_map
=
{
...
...
@@ -62,6 +63,7 @@ def get_variable_info(var):
var_info
[
'trainable'
]
=
var
.
trainable
var_info
[
'optimize_attr'
]
=
var
.
optimize_attr
var_info
[
'regularizer'
]
=
var
.
regularizer
if
not
version_compare
(
paddle
.
__version__
,
'1.8'
):
var_info
[
'gradient_clip_attr'
]
=
var
.
gradient_clip_attr
var_info
[
'do_model_average'
]
=
var
.
do_model_average
else
:
...
...
paddlehub/dataset/food101.py
浏览文件 @
22c4494f
...
...
@@ -25,11 +25,11 @@ from paddlehub.dataset.base_cv_dataset import BaseCVDataset
class
Food101Dataset
(
BaseCVDataset
):
def
__init__
(
self
):
dataset_path
=
os
.
path
.
join
(
hub
.
common
.
dir
.
DATA_HOME
,
"food-101"
,
"images"
)
base_path
=
self
.
_download_dataset
(
dataset_path
=
os
.
path
.
join
(
hub
.
common
.
dir
.
DATA_HOME
,
"food-101"
)
dataset_path
=
self
.
_download_dataset
(
dataset_path
=
dataset_path
,
url
=
"https://bj.bcebos.com/paddlehub-dataset/Food101.tar.gz"
)
base_path
=
os
.
path
.
join
(
dataset_path
,
"images"
)
super
(
Food101Dataset
,
self
).
__init__
(
base_path
=
base_path
,
train_list_file
=
"train_list.txt"
,
...
...
paddlehub/module/manager.py
浏览文件 @
22c4494f
...
...
@@ -96,8 +96,10 @@ class LocalModuleManager(object):
for
sub_dir_name
in
os
.
listdir
(
self
.
local_modules_dir
):
sub_dir_path
=
os
.
path
.
join
(
self
.
local_modules_dir
,
sub_dir_name
)
if
os
.
path
.
isdir
(
sub_dir_path
):
if
"-"
in
sub_dir_path
:
new_sub_dir_path
=
sub_dir_path
.
replace
(
"-"
,
"_"
)
if
"-"
in
sub_dir_name
:
sub_dir_name
=
sub_dir_name
.
replace
(
"-"
,
"_"
)
new_sub_dir_path
=
os
.
path
.
join
(
self
.
local_modules_dir
,
sub_dir_name
)
shutil
.
move
(
sub_dir_path
,
new_sub_dir_path
)
sub_dir_path
=
new_sub_dir_path
valid
,
info
=
self
.
check_module_valid
(
sub_dir_path
)
...
...
@@ -180,11 +182,13 @@ class LocalModuleManager(object):
with
tarfile
.
open
(
module_package
,
"r:gz"
)
as
tar
:
file_names
=
tar
.
getnames
()
size
=
len
(
file_names
)
-
1
module_dir
=
os
.
path
.
join
(
_dir
,
file_names
[
0
])
module_name
=
file_names
[
0
]
module_dir
=
os
.
path
.
join
(
_dir
,
module_name
)
for
index
,
file_name
in
enumerate
(
file_names
):
tar
.
extract
(
file_name
,
_dir
)
if
"-"
in
module_dir
:
new_module_dir
=
module_dir
.
replace
(
"-"
,
"_"
)
if
"-"
in
module_name
:
module_name
=
module_name
.
replace
(
"-"
,
"_"
)
new_module_dir
=
os
.
path
.
join
(
_dir
,
module_name
)
shutil
.
move
(
module_dir
,
new_module_dir
)
module_dir
=
new_module_dir
module_name
=
hub
.
Module
(
directory
=
module_dir
).
name
...
...
paddlehub/module/module.py
浏览文件 @
22c4494f
...
...
@@ -89,7 +89,7 @@ def moduleinfo(name, version, author, author_email, summary, type):
return
_wrapper
class
Module
(
object
):
class
Module
(
fluid
.
dygraph
.
Layer
):
def
__new__
(
cls
,
name
=
None
,
directory
=
None
,
...
...
@@ -121,7 +121,7 @@ class Module(object):
module
=
Module
.
init_with_directory
(
directory
=
directory
,
**
kwargs
)
else
:
module
=
object
.
__new__
(
cls
)
module
=
fluid
.
dygraph
.
Layer
.
__new__
(
cls
)
return
module
...
...
@@ -135,6 +135,7 @@ class Module(object):
if
"_is_initialize"
in
self
.
__dict__
and
self
.
_is_initialize
:
return
super
(
Module
,
self
).
__init__
()
_run_func_name
=
self
.
_get_func_name
(
self
.
__class__
,
_module_runnable_func
)
self
.
_run_func
=
getattr
(
self
,
...
...
@@ -248,6 +249,10 @@ class Module(object):
def
_initialize
(
self
):
pass
def
forward
(
self
,
*
args
,
**
kwargs
):
raise
RuntimeError
(
'{} does not support dynamic graph mode yet.'
.
format
(
self
.
name
))
class
ModuleHelper
(
object
):
def
__init__
(
self
,
directory
):
...
...
paddlehub/module/nlp_module.py
浏览文件 @
22c4494f
...
...
@@ -24,13 +24,15 @@ import os
import
re
import
six
import
paddle
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddlehub.common
import
paddle_helper
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
import
paddlehub
as
hub
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddlehub.common
import
paddle_helper
,
tmp_dir
from
paddlehub.common.logger
import
logger
from
paddlehub.common.utils
import
sys_stdin_encoding
from
paddlehub.common.utils
import
sys_stdin_encoding
,
version_compare
from
paddlehub.io.parser
import
txt_parser
from
paddlehub.module.module
import
runnable
...
...
@@ -246,6 +248,45 @@ class TransformerModule(NLPBaseModule):
Tranformer Module base class can be used by BERT, ERNIE, RoBERTa and so on.
"""
def
__init__
(
self
,
name
=
None
,
directory
=
None
,
module_dir
=
None
,
version
=
None
,
max_seq_len
=
128
,
**
kwargs
):
if
not
directory
:
return
super
(
TransformerModule
,
self
).
__init__
(
name
=
name
,
directory
=
directory
,
module_dir
=
module_dir
,
version
=
version
,
**
kwargs
)
self
.
max_seq_len
=
max_seq_len
if
version_compare
(
paddle
.
__version__
,
'1.8.0'
):
with
tmp_dir
()
as
_dir
:
input_dict
,
output_dict
,
program
=
self
.
context
(
max_seq_len
=
max_seq_len
)
fluid
.
io
.
save_inference_model
(
dirname
=
_dir
,
main_program
=
program
,
feeded_var_names
=
[
input_dict
[
'input_ids'
].
name
,
input_dict
[
'position_ids'
].
name
,
input_dict
[
'segment_ids'
].
name
,
input_dict
[
'input_mask'
].
name
],
target_vars
=
[
output_dict
[
"pooled_output"
],
output_dict
[
"sequence_output"
]
],
executor
=
fluid
.
Executor
(
fluid
.
CPUPlace
()))
with
fluid
.
dygraph
.
guard
():
self
.
model_runner
=
fluid
.
dygraph
.
StaticModelRunner
(
_dir
)
def
init_pretraining_params
(
self
,
exe
,
pretraining_params_path
,
main_program
):
assert
os
.
path
.
exists
(
...
...
@@ -271,7 +312,7 @@ class TransformerModule(NLPBaseModule):
def
context
(
self
,
max_seq_len
=
128
,
max_seq_len
=
None
,
trainable
=
True
,
):
"""
...
...
@@ -287,6 +328,9 @@ class TransformerModule(NLPBaseModule):
"""
if
not
max_seq_len
:
max_seq_len
=
self
.
max_seq_len
assert
max_seq_len
<=
self
.
MAX_SEQ_LEN
and
max_seq_len
>=
1
,
"max_seq_len({}) should be in the range of [1, {}]"
.
format
(
max_seq_len
,
self
.
MAX_SEQ_LEN
)
...
...
@@ -431,3 +475,16 @@ class TransformerModule(NLPBaseModule):
"The module context has not been initialized. "
"Please call context() before using get_params_layer"
)
return
self
.
params_layer
def
forward
(
self
,
input_ids
,
position_ids
,
segment_ids
,
input_mask
):
if
version_compare
(
paddle
.
__version__
,
'1.8.0'
):
pooled_output
,
sequence_output
=
self
.
model_runner
(
input_ids
,
position_ids
,
segment_ids
,
input_mask
)
return
{
'pooled_output'
:
pooled_output
,
'sequence_output'
:
sequence_output
}
else
:
raise
RuntimeError
(
'{} only support dynamic graph mode in paddle >= 1.8.0'
.
format
(
self
.
name
))
paddlehub/reader/cv_reader.py
浏览文件 @
22c4494f
...
...
@@ -165,7 +165,7 @@ class ImageClassificationReader(BaseReader):
for
image_path
,
label
in
data
:
image
=
preprocess
(
image_path
)
images
.
append
(
image
.
astype
(
'float32'
))
labels
.
append
([
int
(
label
)])
labels
.
append
([
np
.
int64
(
label
)])
if
len
(
images
)
==
batch_size
:
if
return_list
:
...
...
paddlehub/version.py
浏览文件 @
22c4494f
...
...
@@ -13,5 +13,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PaddleHub version string """
hub_version
=
"1.6.
0
"
hub_version
=
"1.6.
2
"
module_proto_version
=
"1.0.0"
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录