Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
28eb496f
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
283
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
28eb496f
编写于
5月 23, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add py_reader processing flow
上级
2620edc3
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
153 addition
and
7 deletion
+153
-7
paddlehub/finetune/config.py
paddlehub/finetune/config.py
+6
-0
paddlehub/finetune/task.py
paddlehub/finetune/task.py
+147
-7
未找到文件。
paddlehub/finetune/config.py
浏览文件 @
28eb496f
...
...
@@ -29,6 +29,7 @@ class RunConfig(object):
def
__init__
(
self
,
log_interval
=
10
,
eval_interval
=
100
,
use_pyreader
=
False
,
save_ckpt_interval
=
None
,
use_cuda
=
True
,
checkpoint_dir
=
None
,
...
...
@@ -44,6 +45,7 @@ class RunConfig(object):
self
.
_checkpoint_dir
=
checkpoint_dir
self
.
_num_epoch
=
num_epoch
self
.
_batch_size
=
batch_size
self
.
_use_pyreader
=
use_pyreader
if
strategy
is
None
:
self
.
_strategy
=
DefaultStrategy
()
else
:
...
...
@@ -93,3 +95,7 @@ class RunConfig(object):
@
property
def
enable_memory_optim
(
self
):
return
self
.
_enable_memory_optim
@
property
def
use_pyreader
(
self
):
return
self
.
_use_pyreader
paddlehub/finetune/task.py
浏览文件 @
28eb496f
...
...
@@ -26,6 +26,7 @@ import paddle.fluid as fluid
from
visualdl
import
LogWriter
import
paddlehub
as
hub
from
paddlehub.common.paddle_helper
import
dtype_map
from
paddlehub.common.utils
import
mkdir
from
paddlehub.common.logger
import
logger
from
paddlehub.finetune.checkpoint
import
load_checkpoint
,
save_checkpoint
...
...
@@ -77,6 +78,9 @@ class BasicTask(object):
self
.
config
)
self
.
exe
=
fluid
.
Executor
(
place
=
self
.
place
)
self
.
feed_list
=
feed_list
self
.
feed_variables
=
[
main_program
.
global_block
().
vars
[
var_name
]
for
var_name
in
feed_list
]
self
.
metrics
=
[]
self
.
is_inititalized
=
False
self
.
current_step
=
0
...
...
@@ -127,6 +131,63 @@ class BasicTask(object):
def
_add_metrics
(
self
):
raise
NotImplementedError
def
_add_py_reader
(
self
):
for
program
,
add_label
in
((
self
.
main_program
,
True
),
(
self
.
test_program
,
True
),
(
self
.
inference_program
,
False
)):
temp_program
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
temp_program
,
startup_program
):
feed_variables
=
self
.
feed_variables
if
add_label
:
feed_variables
=
feed_variables
+
[
self
.
label
]
feed_list
=
self
.
feed_list
if
add_label
:
feed_list
=
feed_list
+
[
self
.
label
.
name
]
py_reader
=
fluid
.
layers
.
py_reader
(
capacity
=
16
,
shapes
=
[
var
.
shape
for
var
in
feed_variables
],
lod_levels
=
[
var
.
lod_level
for
var
in
feed_variables
],
dtypes
=
[
dtype_map
[
var
.
dtype
]
for
var
in
feed_variables
],
use_double_buffer
=
True
)
feed_variables
=
fluid
.
layers
.
read_file
(
py_reader
)
input_dict
=
{
key
:
feed_variables
[
index
]
for
index
,
key
in
enumerate
(
feed_list
)
}
hub
.
connect_program
(
pre_program
=
temp_program
,
next_program
=
program
,
input_dict
=
input_dict
,
inplace
=
True
)
self
.
exe
.
run
(
startup_program
)
if
program
==
self
.
main_program
:
self
.
main_program
=
temp_program
self
.
loss
=
self
.
main_program
.
global_block
().
vars
[
self
.
loss
.
name
]
for
index
,
metric
in
enumerate
(
self
.
metrics
):
self
.
metrics
[
index
]
=
self
.
main_program
.
global_block
().
vars
[
metric
.
name
]
self
.
output
=
self
.
main_program
.
global_block
().
vars
[
self
.
output
.
name
]
self
.
loss
.
persistable
=
True
for
metric
in
self
.
metrics
:
metric
.
persistable
=
True
self
.
output
.
persistable
=
True
self
.
main_py_reader
=
py_reader
elif
program
==
self
.
test_program
:
self
.
test_program
=
temp_program
self
.
test_py_reader
=
py_reader
elif
program
==
self
.
inference_program
:
self
.
inference_program
=
temp_program
self
.
inference_py_reader
=
py_reader
def
_init_if_necessary
(
self
,
load_best_model
=
False
):
if
not
self
.
is_inititalized
:
self
.
_init_start_event
()
...
...
@@ -137,12 +198,17 @@ class BasicTask(object):
self
.
_add_loss
()
self
.
_add_metrics
()
self
.
test_program
=
self
.
main_program
.
clone
(
for_test
=
True
)
if
self
.
config
.
use_pyreader
:
self
.
_add_py_reader
()
with
fluid
.
program_guard
(
self
.
main_program
):
self
.
config
.
strategy
.
execute
(
self
.
loss
,
self
.
data_reader
,
self
.
config
)
self
.
loss
.
persistable
=
True
for
metric
s
in
self
.
metrics
:
metric
s
.
persistable
=
True
for
metric
in
self
.
metrics
:
metric
.
persistable
=
True
self
.
output
.
persistable
=
True
self
.
build_strategy
=
fluid
.
BuildStrategy
()
...
...
@@ -187,7 +253,8 @@ class BasicTask(object):
self
.
current_epoch
,
self
.
current_step
=
load_checkpoint
(
self
.
config
.
checkpoint_dir
,
self
.
exe
,
main_program
=
self
.
main_program
)
main_program
=
self
.
main_program
,
startup_program
=
self
.
startup_program
)
if
load_best_model
:
model_saved_dir
=
os
.
path
.
join
(
self
.
config
.
checkpoint_dir
,
...
...
@@ -245,11 +312,17 @@ class BasicTask(object):
test_reader
=
self
.
data_reader
.
data_generator
(
batch_size
=
self
.
config
.
batch_size
,
phase
=
phase
)
run_states
=
self
.
_run
(
test_reader
,
phase
=
phase
,
program_compiled
=
self
.
test_program
)
test_reader
,
phase
=
phase
,
program_compiled
=
self
.
test_program_compiled
)
self
.
_eval_end_event
(
phase
,
run_states
)
def
_run
(
self
,
reader
,
phase
,
do_eval
=
False
,
program_compiled
=
None
):
def
_run_with_data_feeder
(
self
,
reader
,
phase
,
do_eval
=
False
,
program_compiled
=
None
):
if
program_compiled
is
None
:
program_compiled
=
self
.
main_program_compiled
feed_list
=
self
.
get_feed_list
(
phase
=
phase
)
...
...
@@ -291,6 +364,73 @@ class BasicTask(object):
global_run_states
+=
period_run_states
return
global_run_states
def
_run_with_py_reader
(
self
,
reader
,
phase
,
do_eval
=
False
,
program_compiled
=
None
):
if
program_compiled
is
None
:
program_compiled
=
self
.
main_program_compiled
if
phase
==
"train"
:
py_reader
=
self
.
main_py_reader
elif
phase
in
[
"dev"
,
"val"
,
"test"
]:
py_reader
=
self
.
test_py_reader
elif
phase
==
"predict"
:
py_reader
=
self
.
inference_py_reader
py_reader
.
decorate_paddle_reader
(
reader
)
fetch_list
=
self
.
get_fetch_list
(
phase
=
phase
)
global_run_states
=
[]
period_run_states
=
[]
py_reader
.
start
()
try
:
while
True
:
num_batch_examples
=
self
.
config
.
batch_size
step_run_state
=
RunState
(
len
(
fetch_list
))
step_run_state
.
run_step
=
1
fetch_result
=
self
.
exe
.
run
(
program_compiled
,
fetch_list
=
fetch_list
)
for
index
,
result
in
enumerate
(
fetch_result
):
step_run_state
.
run_results
[
index
]
=
result
step_run_state
.
run_examples
+=
num_batch_examples
step_run_state
.
update
()
period_run_states
+=
[
step_run_state
]
if
phase
==
"train"
:
self
.
current_step
+=
1
if
self
.
current_step
%
self
.
config
.
log_interval
==
0
:
self
.
_log_interval_event
(
period_run_states
)
global_run_states
+=
period_run_states
period_run_states
=
[]
if
self
.
config
.
save_ckpt_interval
and
self
.
current_step
%
self
.
config
.
save_ckpt_interval
==
0
:
self
.
_save_ckpt_interval_event
()
if
do_eval
and
self
.
current_step
%
self
.
config
.
eval_interval
==
0
:
self
.
_eval_interval_event
()
self
.
_run_step_event
(
phase
,
step_run_state
)
except
fluid
.
core
.
EOFException
:
py_reader
.
reset
()
global_run_states
+=
period_run_states
return
global_run_states
def
_run
(
self
,
reader
,
phase
,
do_eval
=
False
,
program_compiled
=
None
):
if
self
.
config
.
use_pyreader
:
return
self
.
_run_with_py_reader
(
reader
,
phase
,
do_eval
=
do_eval
,
program_compiled
=
program_compiled
)
else
:
return
self
.
_run_with_data_feeder
(
reader
,
phase
,
do_eval
=
do_eval
,
program_compiled
=
program_compiled
)
def
predict
(
self
,
data
,
load_best_model
=
True
):
self
.
_init_if_necessary
(
load_best_model
=
load_best_model
)
with
fluid
.
program_guard
(
self
.
inference_program
):
...
...
@@ -299,7 +439,7 @@ class BasicTask(object):
for
run_state
in
self
.
_run
(
inference_reader
,
phase
=
'predict'
,
program_compiled
=
self
.
inference_program
):
program_compiled
=
self
.
inference_program
_compiled
):
yield
run_state
.
run_results
...
...
@@ -408,7 +548,7 @@ class ClassifierTask(BasicTask):
save_result
=
fluid
.
io
.
save_persistables
(
executor
=
self
.
exe
,
dirname
=
model_saved_dir
,
main_program
=
self
.
main
_program
)
main_program
=
self
.
test
_program
)
ImageClassifierTask
=
ClassifierTask
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录