Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
67c508e4
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
67c508e4
编写于
4月 22, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix senta to senta_bilstm
上级
fcab4a90
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
10 addition
and
12 deletion
+10
-12
demo/senta/cli_demo.sh
demo/senta/cli_demo.sh
+1
-1
demo/senta/predict.py
demo/senta/predict.py
+4
-6
demo/senta/run_finetune.sh
demo/senta/run_finetune.sh
+2
-2
demo/senta/senta_finetune.py
demo/senta/senta_finetune.py
+3
-3
未找到文件。
demo/senta/cli_demo.sh
浏览文件 @
67c508e4
python ../../paddlehub/commands/hub.py run senta
--input_file
test
/test.txt
python ../../paddlehub/commands/hub.py run senta
_bilstm
--input_file
test
/test.txt
demo/senta/predict.py
浏览文件 @
67c508e4
...
...
@@ -21,7 +21,7 @@ args = parser.parse_args()
if
__name__
==
'__main__'
:
# loading Paddlehub senta pretrained model
module
=
hub
.
Module
(
name
=
"senta"
)
module
=
hub
.
Module
(
name
=
"senta
_bilstm
"
)
inputs
,
outputs
,
program
=
module
.
context
(
trainable
=
True
)
# Sentence classification dataset reader
...
...
@@ -32,13 +32,11 @@ if __name__ == '__main__':
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
with
fluid
.
program_guard
(
program
):
# Use "sequence_output" for classification tasks on an entire sentence.
# Use "sequence_outputs" for token-level output.
sequence_output
=
outputs
[
"sequence_output"
]
sent_feature
=
outputs
[
"sentence_feature"
]
# Define a classfication finetune task by PaddleHub's API
cls_task
=
hub
.
create_text_cls_task
(
feature
=
se
quence_output
,
num_classes
=
dataset
.
num_labels
)
feature
=
se
nt_feature
,
num_classes
=
dataset
.
num_labels
)
# Setup feed list for data feeder
# Must feed all the tensor of senta's module need
...
...
@@ -69,4 +67,4 @@ if __name__ == '__main__':
correct
+=
1
acc
=
1.0
*
correct
/
total
print
(
"%s
\t
predict=%s"
%
(
test_examples
[
index
],
pred_v
[
0
][
0
]))
print
(
"accuracy = %f"
%
acc
)
print
(
"accuracy = %f"
%
acc
)
demo/senta/run_finetune.sh
浏览文件 @
67c508e4
export
CUDA_VISIBLE_DEVICES
=
0
export
CUDA_VISIBLE_DEVICES
=
5
DATASET
=
"chnsenticorp"
CKPT_DIR
=
"./ckpt_
${
DATASET
}
"
python
-u
senta_finetune.py
\
--batch_size
=
24
\
--use_gpu
=
Fals
e
\
--use_gpu
=
Tru
e
\
--checkpoint_dir
=
${
CKPT_DIR
}
\
--num_epoch
=
3
demo/senta/senta_finetune.py
浏览文件 @
67c508e4
...
...
@@ -15,7 +15,7 @@ args = parser.parse_args()
if
__name__
==
'__main__'
:
# Step1: load Paddlehub senta pretrained model
module
=
hub
.
Module
(
name
=
"senta"
)
module
=
hub
.
Module
(
name
=
"senta
_bilstm
"
)
inputs
,
outputs
,
program
=
module
.
context
(
trainable
=
True
)
# Step2: Download dataset and use TextClassificationReader to read dataset
...
...
@@ -24,7 +24,7 @@ if __name__ == '__main__':
reader
=
hub
.
reader
.
LACClassifyReader
(
dataset
=
dataset
,
vocab_path
=
module
.
get_vocab_path
())
sent_feature
=
outputs
[
"se
quence_output
"
]
sent_feature
=
outputs
[
"se
ntence_feature
"
]
# Define a classfication finetune task by PaddleHub's API
cls_task
=
hub
.
create_text_cls_task
(
...
...
@@ -35,7 +35,7 @@ if __name__ == '__main__':
feed_list
=
[
inputs
[
"words"
].
name
,
cls_task
.
variable
(
'label'
).
name
]
strategy
=
hub
.
finetune
.
strategy
.
AdamWeightDecayStrategy
(
learning_rate
=
1e-
3
,
weight_decay
=
0.01
,
warmup_proportion
=
0.01
)
learning_rate
=
1e-
4
,
weight_decay
=
0.01
,
warmup_proportion
=
0.05
)
config
=
hub
.
RunConfig
(
use_cuda
=
args
.
use_gpu
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录