Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
ead7f993
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
283
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ead7f993
编写于
5月 30, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update predict interface
上级
3b2cceb2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
49 addition
and
31 deletion
+49
-31
demo/image-classification/predict.py
demo/image-classification/predict.py
+12
-10
paddlehub/finetune/checkpoint.py
paddlehub/finetune/checkpoint.py
+16
-11
paddlehub/finetune/task.py
paddlehub/finetune/task.py
+21
-10
未找到文件。
demo/image-classification/predict.py
浏览文件 @
ead7f993
...
...
@@ -8,7 +8,7 @@ import numpy as np
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
bool
,
default
=
False
,
help
=
"Whether use GPU for predict."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
bool
,
default
=
True
,
help
=
"Whether use GPU for predict."
)
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
"paddlehub_finetune_ckpt"
,
help
=
"Path to save log data."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--module"
,
type
=
str
,
default
=
"resnet50"
,
help
=
"Module used as a feature extractor."
)
...
...
@@ -70,15 +70,17 @@ def predict(args):
data
=
[
"./test/test_img_daisy.jpg"
,
"./test/test_img_roses.jpg"
]
label_map
=
dataset
.
label_dict
()
for
result
in
task
.
predict
(
data
=
data
):
result
=
np
.
argmax
(
result
,
axis
=
2
)
index
=
0
for
batch
in
result
:
for
predict_result
in
batch
:
index
+=
1
predict_result
=
label_map
[
predict_result
]
print
(
"input %i is %s, and the predict result is %s"
%
(
index
,
data
[
index
-
1
],
predict_result
))
index
=
0
# get classification result
results
=
task
.
predict
(
data
=
data
)
for
batch_result
in
results
:
# get predict index
batch_result
=
np
.
argmax
(
batch_result
,
axis
=
2
)[
0
]
for
result
in
batch_result
:
index
+=
1
result
=
label_map
[
result
]
print
(
"input %i is %s, and the predict result is %s"
%
(
index
,
data
[
index
-
1
],
result
))
if
__name__
==
"__main__"
:
...
...
paddlehub/finetune/checkpoint.py
浏览文件 @
ead7f993
...
...
@@ -30,31 +30,36 @@ CKPT_FILE_NAME = "ckpt.meta"
def
load_checkpoint
(
checkpoint_dir
,
exe
,
main_program
=
fluid
.
default_main_program
(),
startup_program
=
fluid
.
default_startup_program
()):
startup_program
=
fluid
.
default_startup_program
(),
load_best_model
=
False
):
ckpt_meta_path
=
os
.
path
.
join
(
checkpoint_dir
,
CKPT_FILE_NAME
)
ckpt
=
checkpoint_pb2
.
CheckPoint
()
logger
.
info
(
"Try loading checkpoint from {}"
.
format
(
ckpt_meta_path
))
if
os
.
path
.
exists
(
ckpt_meta_path
):
ckpt
=
checkpoint_pb2
.
CheckPoint
()
with
open
(
ckpt_meta_path
,
"rb"
)
as
f
:
ckpt
.
ParseFromString
(
f
.
read
())
current_epoch
=
1
global_step
=
0
best_model_path
=
os
.
path
.
join
(
checkpoint_dir
,
"best_model"
)
if
load_best_model
and
os
.
path
.
exists
(
best_model_path
):
fluid
.
io
.
load_persistables
(
exe
,
best_model_path
,
main_program
)
logger
.
info
(
"PaddleHub model best model loaded."
)
return
current_epoch
,
global_step
elif
ckpt
.
latest_model_dir
:
fluid
.
io
.
load_persistables
(
exe
,
ckpt
.
latest_model_dir
,
main_program
)
logger
.
info
(
"PaddleHub model checkpoint loaded. current_epoch={}, "
"global_step={}"
.
format
(
ckpt
.
current_epoch
,
ckpt
.
global_step
))
return
ckpt
.
current_epoch
,
ckpt
.
global_step
else
:
current_epoch
=
1
global_step
=
0
latest_model_dir
=
None
logger
.
info
(
"PaddleHub model checkpoint not found, start training from scratch..."
)
exe
.
run
(
startup_program
)
return
current_epoch
,
global_step
logger
.
info
(
"PaddleHub model checkpoint not found, start training from scratch..."
)
exe
.
run
(
startup_program
)
return
current_epoch
,
global_step
def
save_checkpoint
(
checkpoint_dir
,
...
...
paddlehub/finetune/task.py
浏览文件 @
ead7f993
...
...
@@ -128,10 +128,11 @@ class BasicTask(object):
# run environment
self
.
_phases
=
[]
self
.
_envs
=
{}
self
.
_predict_data
=
None
def
init_if_necessary
(
self
):
def
init_if_necessary
(
self
,
load_best_model
=
False
):
if
not
self
.
_load_checkpoint
:
self
.
load_checkpoint
()
self
.
load_checkpoint
(
load_best_model
=
load_best_model
)
self
.
_load_checkpoint
=
True
@
contextlib
.
contextmanager
...
...
@@ -159,6 +160,11 @@ class BasicTask(object):
self
.
env
.
loss
=
self
.
_add_loss
()
self
.
env
.
metrics
=
self
.
_add_metrics
()
if
self
.
is_predict_phase
or
self
.
is_test_phase
:
self
.
env
.
main_program
=
self
.
env
.
main_program
.
clone
(
for_test
=
True
)
hub
.
common
.
paddle_helper
.
set_op_attr
(
self
.
env
.
main_program
,
is_test
=
True
)
if
self
.
config
.
use_pyreader
:
t_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
t_program
,
self
.
env
.
startup_program
):
...
...
@@ -291,8 +297,12 @@ class BasicTask(object):
@
property
def
reader
(
self
):
if
self
.
is_predict_phase
:
data
=
self
.
_predict_data
else
:
data
=
None
self
.
env
.
reader
=
self
.
_base_data_reader
.
data_generator
(
batch_size
=
self
.
config
.
batch_size
,
phase
=
self
.
phase
)
batch_size
=
self
.
config
.
batch_size
,
phase
=
self
.
phase
,
data
=
data
)
return
self
.
env
.
reader
@
property
...
...
@@ -315,8 +325,6 @@ class BasicTask(object):
@
property
def
output
(
self
):
if
self
.
is_predict_phase
:
raise
RuntimeError
()
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
output
...
...
@@ -412,7 +420,8 @@ class BasicTask(object):
self
.
config
.
checkpoint_dir
,
self
.
exe
,
main_program
=
self
.
main_program
,
startup_program
=
self
.
_base_startup_program
)
startup_program
=
self
.
_base_startup_program
,
load_best_model
=
load_best_model
)
if
load_best_model
:
model_saved_dir
=
os
.
path
.
join
(
self
.
config
.
checkpoint_dir
,
...
...
@@ -454,10 +463,12 @@ class BasicTask(object):
self
.
_eval_end_event
(
run_states
)
def
predict
(
self
,
data
,
load_best_model
=
True
):
with
self
.
phase_guard
(
phase
=
phase
):
self
.
init_if_necessary
()
for
run_state
in
self
.
_run
():
yield
run_state
.
run_results
with
self
.
phase_guard
(
phase
=
"predict"
):
self
.
_predict_data
=
data
self
.
init_if_necessary
(
load_best_model
=
load_best_model
)
run_states
=
self
.
_run
()
self
.
_predict_data
=
None
return
[
run_state
.
run_results
for
run_state
in
run_states
]
def
_run
(
self
,
do_eval
=
False
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录