Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
0399449f
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0399449f
编写于
6月 25, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update task finetune api to return run states
上级
93bad059
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
16 addition
and
8 deletion
+16
-8
demo/elmo/predict.py
demo/elmo/predict.py
+2
-1
demo/image-classification/predict.py
demo/image-classification/predict.py
+2
-1
demo/multi-label-classification/predict.py
demo/multi-label-classification/predict.py
+2
-1
demo/senta/predict.py
demo/senta/predict.py
+2
-1
demo/sequence-labeling/predict.py
demo/sequence-labeling/predict.py
+2
-1
demo/text-classification/predict.py
demo/text-classification/predict.py
+2
-1
paddlehub/finetune/task.py
paddlehub/finetune/task.py
+4
-2
未找到文件。
demo/elmo/predict.py
浏览文件 @
0399449f
...
@@ -168,7 +168,8 @@ if __name__ == '__main__':
...
@@ -168,7 +168,8 @@ if __name__ == '__main__':
]
]
index
=
0
index
=
0
results
=
elmo_task
.
predict
(
data
=
data
)
run_states
=
elmo_task
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
batch_result
in
results
:
for
batch_result
in
results
:
# get predict index
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
...
...
demo/image-classification/predict.py
浏览文件 @
0399449f
...
@@ -76,7 +76,8 @@ def predict(args):
...
@@ -76,7 +76,8 @@ def predict(args):
label_map
=
dataset
.
label_dict
()
label_map
=
dataset
.
label_dict
()
index
=
0
index
=
0
# get classification result
# get classification result
results
=
task
.
predict
(
data
=
data
)
run_states
=
task
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
batch_result
in
results
:
for
batch_result
in
results
:
# get predict index
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
...
...
demo/multi-label-classification/predict.py
浏览文件 @
0399449f
...
@@ -99,7 +99,8 @@ if __name__ == '__main__':
...
@@ -99,7 +99,8 @@ if __name__ == '__main__':
]
]
index
=
0
index
=
0
results
=
multi_label_cls_task
.
predict
(
data
=
data
)
run_states
=
multi_label_cls_task
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
result
in
results
:
for
result
in
results
:
# get predict index
# get predict index
label_ids
=
[]
label_ids
=
[]
...
...
demo/senta/predict.py
浏览文件 @
0399449f
...
@@ -60,7 +60,8 @@ if __name__ == '__main__':
...
@@ -60,7 +60,8 @@ if __name__ == '__main__':
data
=
[
"这家餐厅很好吃"
,
"这部电影真的很差劲"
]
data
=
[
"这家餐厅很好吃"
,
"这部电影真的很差劲"
]
results
=
cls_task
.
predict
(
data
=
data
)
run_states
=
cls_task
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
index
=
0
index
=
0
for
batch_result
in
results
:
for
batch_result
in
results
:
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
...
...
demo/sequence-labeling/predict.py
浏览文件 @
0399449f
...
@@ -96,7 +96,8 @@ if __name__ == '__main__':
...
@@ -96,7 +96,8 @@ if __name__ == '__main__':
[
"不过重在晋趣,略增明人气息,妙在集古有道、不露痕迹罢了。"
],
[
"不过重在晋趣,略增明人气息,妙在集古有道、不露痕迹罢了。"
],
]
]
results
=
seq_label_task
.
predict
(
data
=
data
)
run_states
=
seq_label_task
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
num_batch
,
batch_results
in
enumerate
(
results
):
for
num_batch
,
batch_results
in
enumerate
(
results
):
infers
=
batch_results
[
0
].
reshape
([
-
1
]).
astype
(
np
.
int32
).
tolist
()
infers
=
batch_results
[
0
].
reshape
([
-
1
]).
astype
(
np
.
int32
).
tolist
()
...
...
demo/text-classification/predict.py
浏览文件 @
0399449f
...
@@ -97,7 +97,8 @@ if __name__ == '__main__':
...
@@ -97,7 +97,8 @@ if __name__ == '__main__':
]
]
index
=
0
index
=
0
results
=
cls_task
.
predict
(
data
=
data
)
run_states
=
cls_task
.
predict
(
data
=
data
)
results
=
[
run_state
.
run_results
for
run_state
in
run_states
]
for
batch_result
in
results
:
for
batch_result
in
results
:
# get predict index
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
...
...
paddlehub/finetune/task.py
浏览文件 @
0399449f
...
@@ -498,7 +498,7 @@ class BasicTask(object):
...
@@ -498,7 +498,7 @@ class BasicTask(object):
self
.
exe
,
dirname
=
dirname
,
main_program
=
self
.
main_program
)
self
.
exe
,
dirname
=
dirname
,
main_program
=
self
.
main_program
)
def
finetune_and_eval
(
self
):
def
finetune_and_eval
(
self
):
self
.
finetune
(
do_eval
=
True
)
return
self
.
finetune
(
do_eval
=
True
)
def
finetune
(
self
,
do_eval
=
False
):
def
finetune
(
self
,
do_eval
=
False
):
# Start to finetune
# Start to finetune
...
@@ -519,6 +519,7 @@ class BasicTask(object):
...
@@ -519,6 +519,7 @@ class BasicTask(object):
self
.
eval
(
phase
=
"test"
)
self
.
eval
(
phase
=
"test"
)
self
.
_finetune_end_event
(
run_states
)
self
.
_finetune_end_event
(
run_states
)
return
run_states
def
eval
(
self
,
phase
=
"dev"
):
def
eval
(
self
,
phase
=
"dev"
):
with
self
.
phase_guard
(
phase
=
phase
):
with
self
.
phase_guard
(
phase
=
phase
):
...
@@ -526,6 +527,7 @@ class BasicTask(object):
...
@@ -526,6 +527,7 @@ class BasicTask(object):
self
.
_eval_start_event
()
self
.
_eval_start_event
()
run_states
=
self
.
_run
()
run_states
=
self
.
_run
()
self
.
_eval_end_event
(
run_states
)
self
.
_eval_end_event
(
run_states
)
return
run_states
def
predict
(
self
,
data
,
load_best_model
=
True
):
def
predict
(
self
,
data
,
load_best_model
=
True
):
with
self
.
phase_guard
(
phase
=
"predict"
):
with
self
.
phase_guard
(
phase
=
"predict"
):
...
@@ -539,7 +541,7 @@ class BasicTask(object):
...
@@ -539,7 +541,7 @@ class BasicTask(object):
run_states
=
self
.
_run
()
run_states
=
self
.
_run
()
self
.
_predict_end_event
(
run_states
)
self
.
_predict_end_event
(
run_states
)
self
.
_predict_data
=
None
self
.
_predict_data
=
None
return
[
run_state
.
run_results
for
run_state
in
run_states
]
return
run_states
def
_run
(
self
,
do_eval
=
False
):
def
_run
(
self
,
do_eval
=
False
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录