Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
3b2cceb2
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3b2cceb2
编写于
5月 30, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'task_update' into develop
上级
b71e406d
5e9eda4e
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
1023 addition
and
704 deletion
+1023
-704
demo/image-classification/img_classifier.py
demo/image-classification/img_classifier.py
+9
-6
demo/image-classification/predict.py
demo/image-classification/predict.py
+29
-28
paddlehub/__init__.py
paddlehub/__init__.py
+4
-5
paddlehub/common/paddle_helper.py
paddlehub/common/paddle_helper.py
+9
-3
paddlehub/finetune/checkpoint.py
paddlehub/finetune/checkpoint.py
+15
-5
paddlehub/finetune/config.py
paddlehub/finetune/config.py
+6
-0
paddlehub/finetune/evaluate.py
paddlehub/finetune/evaluate.py
+0
-89
paddlehub/finetune/finetune.py
paddlehub/finetune/finetune.py
+0
-313
paddlehub/finetune/strategy.py
paddlehub/finetune/strategy.py
+5
-4
paddlehub/finetune/task.py
paddlehub/finetune/task.py
+766
-170
paddlehub/reader/nlp_reader.py
paddlehub/reader/nlp_reader.py
+180
-81
未找到文件。
demo/image-classification/img_classifier.py
浏览文件 @
3b2cceb2
...
...
@@ -9,7 +9,7 @@ import numpy as np
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--num_epoch"
,
type
=
int
,
default
=
1
,
help
=
"Number of epoches for fine-tuning."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
bool
,
default
=
Fals
e
,
help
=
"Whether use GPU for fine-tuning."
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
bool
,
default
=
Tru
e
,
help
=
"Whether use GPU for fine-tuning."
)
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
"paddlehub_finetune_ckpt"
,
help
=
"Path to save log data."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--module"
,
type
=
str
,
default
=
"resnet50"
,
help
=
"Module used as feature extractor."
)
...
...
@@ -51,11 +51,9 @@ def finetune(args):
dataset
=
dataset
)
feature_map
=
output_dict
[
"feature_map"
]
task
=
hub
.
create_img_cls_task
(
feature
=
feature_map
,
num_classes
=
dataset
.
num_labels
)
img
=
input_dict
[
"image"
]
feed_list
=
[
img
.
name
,
task
.
variable
(
'label'
).
name
]
feed_list
=
[
img
.
name
]
config
=
hub
.
RunConfig
(
use_cuda
=
args
.
use_gpu
,
...
...
@@ -65,8 +63,13 @@ def finetune(args):
checkpoint_dir
=
args
.
checkpoint_dir
,
strategy
=
hub
.
finetune
.
strategy
.
DefaultFinetuneStrategy
())
hub
.
finetune_and_eval
(
task
,
feed_list
=
feed_list
,
data_reader
=
data_reader
,
config
=
config
)
task
=
hub
.
ImageClassifierTask
(
data_reader
=
data_reader
,
feed_list
=
feed_list
,
feature
=
feature_map
,
num_classes
=
dataset
.
num_labels
,
config
=
config
)
task
.
finetune_and_eval
()
if
__name__
==
"__main__"
:
...
...
demo/image-classification/predict.py
浏览文件 @
3b2cceb2
...
...
@@ -10,6 +10,7 @@ import numpy as np
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
"--use_gpu"
,
type
=
bool
,
default
=
False
,
help
=
"Whether use GPU for predict."
)
parser
.
add_argument
(
"--checkpoint_dir"
,
type
=
str
,
default
=
"paddlehub_finetune_ckpt"
,
help
=
"Path to save log data."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--module"
,
type
=
str
,
default
=
"resnet50"
,
help
=
"Module used as a feature extractor."
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"flowers"
,
help
=
"Dataset to finetune."
)
# yapf: enable.
...
...
@@ -25,6 +26,8 @@ module_map = {
def
predict
(
args
):
module
=
hub
.
Module
(
name
=
args
.
module
)
input_dict
,
output_dict
,
program
=
module
.
context
(
trainable
=
True
)
if
args
.
dataset
.
lower
()
==
"flowers"
:
dataset
=
hub
.
dataset
.
Flowers
()
...
...
@@ -39,45 +42,43 @@ def predict(args):
else
:
raise
ValueError
(
"%s dataset is not defined"
%
args
.
dataset
)
label_map
=
dataset
.
label_dict
()
num_labels
=
len
(
label_map
)
module
=
hub
.
Module
(
name
=
args
.
module
)
input_dict
,
output_dict
,
program
=
module
.
context
()
data_reader
=
hub
.
reader
.
ImageClassificationReader
(
image_width
=
module
.
get_expected_image_width
(),
image_height
=
module
.
get_expected_image_height
(),
images_mean
=
module
.
get_pretrained_images_mean
(),
images_std
=
module
.
get_pretrained_images_std
(),
dataset
=
None
)
dataset
=
dataset
)
img
=
input_dict
[
"image"
]
feature_map
=
output_dict
[
"feature_map"
]
task
=
hub
.
create_img_cls_task
(
feature
=
feature_map
,
num_classes
=
num_labels
)
img
=
input_dict
[
"image"
]
feed_list
=
[
img
.
name
]
with
fluid
.
program_guard
(
task
.
inference_program
()):
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
pretrained_model_dir
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
"best_model"
)
if
not
os
.
path
.
exists
(
pretrained_model_dir
):
hub
.
logger
.
error
(
"pretrained model dir %s didn't exist"
%
pretrained_model_dir
)
exit
(
1
)
fluid
.
io
.
load_persistables
(
exe
,
pretrained_model_dir
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
data
=
[
"test/test_img_roses.jpg"
,
"test/test_img_daisy.jpg"
]
config
=
hub
.
RunConfig
(
use_cuda
=
args
.
use_gpu
,
batch_size
=
args
.
batch_size
,
enable_memory_optim
=
False
,
checkpoint_dir
=
args
.
checkpoint_dir
,
strategy
=
hub
.
finetune
.
strategy
.
DefaultFinetuneStrategy
())
predict_reader
=
data_reader
.
data_generator
(
phase
=
"predict"
,
batch_size
=
1
,
data
=
data
)
for
index
,
batch
in
enumerate
(
predict_reader
()):
result
,
=
exe
.
run
(
feed
=
feeder
.
feed
(
batch
),
fetch_list
=
[
task
.
variable
(
'probs'
)])
predict_result
=
label_map
[
np
.
argsort
(
result
[
0
])[::
-
1
][
0
]]
print
(
"input %i is %s, and the predict result is %s"
%
(
index
,
data
[
index
],
predict_result
))
task
=
hub
.
ClassifierTask
(
data_reader
=
data_reader
,
feed_list
=
feed_list
,
feature
=
feature_map
,
num_classes
=
dataset
.
num_labels
,
config
=
config
)
data
=
[
"./test/test_img_daisy.jpg"
,
"./test/test_img_roses.jpg"
]
label_map
=
dataset
.
label_dict
()
for
result
in
task
.
predict
(
data
=
data
):
result
=
np
.
argmax
(
result
,
axis
=
2
)
index
=
0
for
batch
in
result
:
for
predict_result
in
batch
:
index
+=
1
predict_result
=
label_map
[
predict_result
]
print
(
"input %i is %s, and the predict result is %s"
%
(
index
,
data
[
index
-
1
],
predict_result
))
if
__name__
==
"__main__"
:
...
...
paddlehub/__init__.py
浏览文件 @
3b2cceb2
...
...
@@ -42,11 +42,10 @@ from .module.manager import default_module_manager
from
.io.type
import
DataType
from
.finetune.task
import
Task
from
.finetune.task
import
create_seq_label_task
from
.finetune.task
import
create_text_cls_task
from
.finetune.task
import
create_img_cls_task
from
.finetune.finetune
import
finetune_and_eval
from
.finetune.task
import
ClassifierTask
from
.finetune.task
import
TextClassifierTask
from
.finetune.task
import
ImageClassifierTask
from
.finetune.task
import
SequenceLabelTask
from
.finetune.config
import
RunConfig
from
.finetune.strategy
import
AdamWeightDecayStrategy
from
.finetune.strategy
import
DefaultStrategy
...
...
paddlehub/common/paddle_helper.py
浏览文件 @
3b2cceb2
...
...
@@ -143,7 +143,11 @@ def from_module_attr_to_param(module_attr):
return
param
def
connect_program
(
pre_program
,
next_program
,
input_dict
=
None
,
inplace
=
True
):
def
connect_program
(
pre_program
,
next_program
,
input_dict
=
None
,
inplace
=
True
,
need_log
=
True
):
def
_copy_vars_and_ops_in_blocks
(
from_block
,
to_block
):
for
var
in
from_block
.
vars
:
var
=
from_block
.
var
(
var
)
...
...
@@ -199,7 +203,8 @@ def connect_program(pre_program, next_program, input_dict=None, inplace=True):
outputs
=
{
'Out'
:
output_var
})
block_map
=
{
0
:
0
}
logger
.
info
(
"Connect program's input tensor"
)
if
need_log
:
logger
.
info
(
"Connect program's input tensor"
)
for
index
,
block
in
enumerate
(
next_program
.
blocks
):
if
block
.
idx
==
0
:
_copy_vars_and_ops_in_blocks
(
block
,
output_program
.
global_block
())
...
...
@@ -211,7 +216,8 @@ def connect_program(pre_program, next_program, input_dict=None, inplace=True):
new_block
=
output_program
.
_create_block
(
parent_idx
=
block_map
[
block
.
parent_idx
])
_copy_vars_and_ops_in_blocks
(
block
,
new_block
)
logger
.
info
(
"Connect program's input tensor done"
)
if
need_log
:
logger
.
info
(
"Connect program's input tensor done"
)
return
output_program
...
...
paddlehub/finetune/checkpoint.py
浏览文件 @
3b2cceb2
...
...
@@ -27,7 +27,11 @@ from paddlehub.common.logger import logger
CKPT_FILE_NAME
=
"ckpt.meta"
def
load_checkpoint
(
checkpoint_dir
,
exe
):
def
load_checkpoint
(
checkpoint_dir
,
exe
,
main_program
=
fluid
.
default_main_program
(),
startup_program
=
fluid
.
default_startup_program
()):
ckpt_meta_path
=
os
.
path
.
join
(
checkpoint_dir
,
CKPT_FILE_NAME
)
logger
.
info
(
"Try loading checkpoint from {}"
.
format
(
ckpt_meta_path
))
if
os
.
path
.
exists
(
ckpt_meta_path
):
...
...
@@ -35,7 +39,7 @@ def load_checkpoint(checkpoint_dir, exe):
with
open
(
ckpt_meta_path
,
"rb"
)
as
f
:
ckpt
.
ParseFromString
(
f
.
read
())
fluid
.
io
.
load_persistables
(
exe
,
ckpt
.
latest_model_dir
)
fluid
.
io
.
load_persistables
(
exe
,
ckpt
.
latest_model_dir
,
main_program
)
logger
.
info
(
"PaddleHub model checkpoint loaded. current_epoch={}, "
"global_step={}"
.
format
(
ckpt
.
current_epoch
,
...
...
@@ -48,18 +52,24 @@ def load_checkpoint(checkpoint_dir, exe):
logger
.
info
(
"PaddleHub model checkpoint not found, start training from scratch..."
)
exe
.
run
(
fluid
.
default_startup_program
()
)
exe
.
run
(
startup_program
)
return
current_epoch
,
global_step
def
save_checkpoint
(
checkpoint_dir
,
current_epoch
,
global_step
,
exe
):
def
save_checkpoint
(
checkpoint_dir
,
current_epoch
,
global_step
,
exe
,
main_program
=
fluid
.
default_main_program
()):
ckpt_meta_path
=
os
.
path
.
join
(
checkpoint_dir
,
CKPT_FILE_NAME
)
ckpt
=
checkpoint_pb2
.
CheckPoint
()
model_saved_dir
=
os
.
path
.
join
(
checkpoint_dir
,
"step_%d"
%
global_step
)
logger
.
info
(
"Saving model checkpoint to {}"
.
format
(
model_saved_dir
))
fluid
.
io
.
save_persistables
(
exe
,
dirname
=
model_saved_dir
)
fluid
.
io
.
save_persistables
(
exe
,
dirname
=
model_saved_dir
,
main_program
=
main_program
)
ckpt
.
current_epoch
=
current_epoch
ckpt
.
global_step
=
global_step
...
...
paddlehub/finetune/config.py
浏览文件 @
3b2cceb2
...
...
@@ -30,6 +30,7 @@ class RunConfig(object):
def
__init__
(
self
,
log_interval
=
10
,
eval_interval
=
100
,
use_pyreader
=
False
,
save_ckpt_interval
=
None
,
use_cuda
=
True
,
checkpoint_dir
=
None
,
...
...
@@ -45,6 +46,7 @@ class RunConfig(object):
self
.
_checkpoint_dir
=
checkpoint_dir
self
.
_num_epoch
=
num_epoch
self
.
_batch_size
=
batch_size
self
.
_use_pyreader
=
use_pyreader
if
strategy
is
None
:
self
.
_strategy
=
DefaultStrategy
()
else
:
...
...
@@ -94,3 +96,7 @@ class RunConfig(object):
@
property
def
enable_memory_optim
(
self
):
return
self
.
_enable_memory_optim
@
property
def
use_pyreader
(
self
):
return
self
.
_use_pyreader
paddlehub/finetune/evaluate.py
浏览文件 @
3b2cceb2
...
...
@@ -26,95 +26,6 @@ from paddlehub.common.logger import logger
import
paddlehub
as
hub
def
evaluate_cls_task
(
task
,
data_reader
,
feed_list
,
phase
=
"test"
,
config
=
None
):
logger
.
info
(
"Evaluation on {} dataset start"
.
format
(
phase
))
test_program
=
task
.
test_program
()
main_program
=
task
.
main_program
()
loss
=
task
.
variable
(
"loss"
)
accuracy
=
task
.
variable
(
"accuracy"
)
batch_size
=
config
.
batch_size
place
,
dev_count
=
hub
.
common
.
get_running_device_info
(
config
)
exe
=
fluid
.
Executor
(
place
=
place
)
with
fluid
.
program_guard
(
test_program
):
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
num_eval_examples
=
acc_sum
=
loss_sum
=
0
test_reader
=
data_reader
.
data_generator
(
batch_size
=
batch_size
,
phase
=
phase
)
eval_time_begin
=
time
.
time
()
eval_step
=
0
for
batch
in
test_reader
():
num_batch_examples
=
len
(
batch
)
eval_step
+=
1
loss_v
,
accuracy_v
=
exe
.
run
(
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
[
loss
.
name
,
accuracy
.
name
])
num_eval_examples
+=
num_batch_examples
if
num_eval_examples
%
10000
==
0
:
logger
.
info
(
"{} examples evaluated."
.
format
(
num_eval_examples
))
acc_sum
+=
accuracy_v
*
num_batch_examples
loss_sum
+=
loss_v
*
num_batch_examples
eval_time_used
=
time
.
time
()
-
eval_time_begin
avg_loss
=
loss_sum
/
num_eval_examples
avg_acc
=
acc_sum
/
num_eval_examples
eval_speed
=
eval_step
/
eval_time_used
logger
.
info
(
"[%s dataset evaluation result] loss=%.5f acc=%.5f [step/sec: %.2f]"
%
(
phase
,
avg_loss
,
avg_acc
,
eval_speed
))
return
avg_loss
,
avg_acc
,
eval_speed
def
evaluate_seq_label_task
(
task
,
data_reader
,
feed_list
,
phase
=
"test"
,
config
=
None
):
fetch_list
=
[
task
.
variable
(
"labels"
).
name
,
task
.
variable
(
"infers"
).
name
,
task
.
variable
(
"seq_len"
).
name
,
task
.
variable
(
"loss"
).
name
]
logger
.
info
(
"Evaluation on {} dataset start"
.
format
(
phase
))
test_program
=
task
.
test_program
()
batch_size
=
config
.
batch_size
place
,
dev_count
=
hub
.
common
.
get_running_device_info
(
config
)
exe
=
fluid
.
Executor
(
place
=
place
)
# calculate the num of label from probs variable shape
num_labels
=
task
.
variable
(
"probs"
).
shape
[
1
]
with
fluid
.
program_guard
(
test_program
):
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
num_eval_examples
=
acc_sum
=
loss_sum
=
0
test_reader
=
data_reader
.
data_generator
(
batch_size
=
batch_size
,
phase
=
phase
)
eval_time_begin
=
time
.
time
()
eval_step
=
0
total_label
,
total_infer
,
total_correct
=
0.0
,
0.0
,
0.0
for
batch
in
test_reader
():
num_batch_examples
=
len
(
batch
)
eval_step
+=
1
np_labels
,
np_infers
,
np_lens
,
_
=
exe
.
run
(
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
fetch_list
)
label_num
,
infer_num
,
correct_num
=
chunk_eval
(
np_labels
,
np_infers
,
np_lens
,
num_labels
,
dev_count
)
total_infer
+=
infer_num
total_label
+=
label_num
total_correct
+=
correct_num
precision
,
recall
,
f1
=
calculate_f1
(
total_label
,
total_infer
,
total_correct
)
eval_time_used
=
time
.
time
()
-
eval_time_begin
eval_speed
=
eval_step
/
eval_time_used
logger
.
info
(
"[%s evaluation] F1-Score=%f, precision=%f, recall=%f [step/sec: %.2f]"
%
(
phase
,
f1
,
precision
,
recall
,
eval_speed
))
return
f1
,
precision
,
recall
# Sequence label evaluation functions
def
chunk_eval
(
np_labels
,
np_infers
,
np_lens
,
tag_num
,
dev_count
=
1
):
def
extract_bio_chunk
(
seq
):
...
...
paddlehub/finetune/finetune.py
已删除
100644 → 0
浏览文件 @
b71e406d
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
time
import
paddle
import
paddle.fluid
as
fluid
import
numpy
as
np
from
visualdl
import
LogWriter
from
paddlehub.common.logger
import
logger
from
paddlehub.common.utils
import
mkdir
from
paddlehub.finetune.config
import
RunConfig
from
paddlehub.finetune.strategy
import
AdamWeightDecayStrategy
,
DefaultStrategy
from
paddlehub.finetune.checkpoint
import
load_checkpoint
,
save_checkpoint
from
paddlehub.finetune.evaluate
import
evaluate_cls_task
,
evaluate_seq_label_task
import
paddlehub
as
hub
def
_do_memory_optimization
(
task
,
config
):
if
config
.
enable_memory_optim
:
logger
.
info
(
"Memory optimization start..."
)
task_var_name
=
task
.
metric_variable_names
()
logger
.
info
(
"Skip memory optimization on variables: {}"
.
format
(
task_var_name
))
optimize_time_begin
=
time
.
time
()
fluid
.
memory_optimize
(
input_program
=
fluid
.
default_main_program
(),
# skip memory optimization on task metric variables
skip_opt_set
=
task_var_name
)
time_used
=
time
.
time
()
-
optimize_time_begin
logger
.
info
(
"Memory optimization done! Time elapsed %f sec"
%
time_used
)
# lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
# program=task.main_program(), batch_size=config.batch_size)
# logger.info("Theoretical memory usage in training: %.2f - %.2f %s" %
# (lower_mem, upper_mem, unit)),
def
_finetune_seq_label_task
(
task
,
data_reader
,
feed_list
,
config
=
None
,
do_eval
=
False
):
"""
Finetune sequence labeling task, evaluate metric is F1, precision and recall
"""
main_program
=
task
.
main_program
()
startup_program
=
task
.
startup_program
()
loss
=
task
.
variable
(
"loss"
)
seq_len
=
task
.
variable
(
"seq_len"
)
num_epoch
=
config
.
num_epoch
batch_size
=
config
.
batch_size
log_writer
=
LogWriter
(
os
.
path
.
join
(
config
.
checkpoint_dir
,
"vdllog"
),
sync_cycle
=
1
)
place
,
dev_count
=
hub
.
common
.
get_running_device_info
(
config
)
with
fluid
.
program_guard
(
main_program
,
startup_program
):
exe
=
fluid
.
Executor
(
place
=
place
)
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
# Select strategy
if
isinstance
(
config
.
strategy
,
hub
.
AdamWeightDecayStrategy
):
scheduled_lr
=
config
.
strategy
.
execute
(
loss
,
main_program
,
data_reader
,
config
)
elif
isinstance
(
config
.
strategy
,
hub
.
DefaultStrategy
):
config
.
strategy
.
execute
(
loss
)
#TODO: add more finetune strategy
_do_memory_optimization
(
task
,
config
)
# Try to restore model training checkpoint
current_epoch
,
global_step
=
load_checkpoint
(
config
.
checkpoint_dir
,
exe
)
best_eval_f1
=
0.0
train_time_used
=
0
logger
.
info
(
"PaddleHub finetune start"
)
exe
.
run
(
fluid
.
default_startup_program
())
# add visualdl scalar
with
log_writer
.
mode
(
"train"
)
as
logw
:
train_loss_scalar
=
logw
.
scalar
(
tag
=
"Loss [train]"
)
with
log_writer
.
mode
(
"evaluate"
)
as
logw
:
eval_f1_scalar
=
logw
.
scalar
(
tag
=
"F1 [eval]"
)
eval_precision_scalar
=
logw
.
scalar
(
tag
=
"Precision [eval]"
)
eval_recall_scalar
=
logw
.
scalar
(
tag
=
"Recall [eval]"
)
# Finetune loop
for
epoch
in
range
(
current_epoch
,
num_epoch
+
1
):
train_reader
=
data_reader
.
data_generator
(
batch_size
=
batch_size
,
phase
=
'train'
)
num_trained_examples
=
loss_sum
=
0
for
batch
in
train_reader
():
num_batch_examples
=
len
(
batch
)
train_time_begin
=
time
.
time
()
loss_v
=
exe
.
run
(
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
[
loss
.
name
])
train_time_used
+=
time
.
time
()
-
train_time_begin
global_step
+=
1
num_trained_examples
+=
num_batch_examples
loss_sum
+=
loss_v
[
0
]
*
num_batch_examples
# log finetune status
if
global_step
%
config
.
log_interval
==
0
:
avg_loss
=
loss_sum
/
num_trained_examples
speed
=
config
.
log_interval
/
train_time_used
logger
.
info
(
"step %d: loss=%.5f [step/sec: %.2f]"
%
(
global_step
,
avg_loss
,
speed
))
train_loss_scalar
.
add_record
(
global_step
,
avg_loss
)
train_time_used
=
0
num_trained_examples
=
0
loss_sum
=
0
if
config
.
save_ckpt_interval
and
global_step
%
config
.
save_ckpt_interval
==
0
:
# NOTE: the current checkpoint mechanism is incomplete;
# it can't restore the dataset training status correctly
save_checkpoint
(
checkpoint_dir
=
config
.
checkpoint_dir
,
current_epoch
=
epoch
,
global_step
=
global_step
,
exe
=
exe
)
if
do_eval
and
global_step
%
config
.
eval_interval
==
0
:
f1
,
precision
,
recall
=
evaluate_seq_label_task
(
task
,
data_reader
,
feed_list
,
phase
=
"dev"
,
config
=
config
)
eval_f1_scalar
.
add_record
(
global_step
,
f1
)
eval_precision_scalar
.
add_record
(
global_step
,
precision
)
eval_recall_scalar
.
add_record
(
global_step
,
recall
)
if
f1
>
best_eval_f1
:
best_eval_f1
=
f1
model_saved_dir
=
os
.
path
.
join
(
config
.
checkpoint_dir
,
"best_model"
)
logger
.
info
(
"best model saved to %s [best F1=%.5f]"
%
(
model_saved_dir
,
best_eval_f1
))
fluid
.
io
.
save_persistables
(
exe
,
dirname
=
model_saved_dir
)
# NOTE: the current checkpoint mechanism is incomplete; it can't
# restore the dataset training status
save_checkpoint
(
checkpoint_dir
=
config
.
checkpoint_dir
,
current_epoch
=
num_epoch
+
1
,
global_step
=
global_step
,
exe
=
exe
)
# Final evaluation
if
do_eval
:
evaluate_seq_label_task
(
task
,
data_reader
,
feed_list
,
phase
=
"dev"
,
config
=
config
)
evaluate_seq_label_task
(
task
,
data_reader
,
feed_list
,
phase
=
"test"
,
config
=
config
)
logger
.
info
(
"PaddleHub finetune finished."
)
def
_finetune_cls_task
(
task
,
data_reader
,
feed_list
,
config
=
None
,
do_eval
=
False
):
main_program
=
task
.
main_program
()
startup_program
=
task
.
startup_program
()
loss
=
task
.
variable
(
"loss"
)
accuracy
=
task
.
variable
(
"accuracy"
)
num_epoch
=
config
.
num_epoch
batch_size
=
config
.
batch_size
log_writer
=
LogWriter
(
os
.
path
.
join
(
config
.
checkpoint_dir
,
"vdllog"
),
sync_cycle
=
1
)
place
,
dev_count
=
hub
.
common
.
get_running_device_info
(
config
)
with
fluid
.
program_guard
(
main_program
,
startup_program
):
exe
=
fluid
.
Executor
(
place
=
place
)
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_list
,
place
=
place
)
# select strategy
if
isinstance
(
config
.
strategy
,
hub
.
AdamWeightDecayStrategy
):
scheduled_lr
=
config
.
strategy
.
execute
(
loss
,
main_program
,
data_reader
,
config
)
elif
isinstance
(
config
.
strategy
,
hub
.
DefaultStrategy
):
config
.
strategy
.
execute
(
loss
)
#TODO: add more finetune strategy
_do_memory_optimization
(
task
,
config
)
# Try to restore model training checkpoint
current_epoch
,
global_step
=
load_checkpoint
(
config
.
checkpoint_dir
,
exe
)
best_eval_acc
=
0.0
train_time_used
=
0
logger
.
info
(
"PaddleHub finetune start"
)
# add visualdl scalar
with
log_writer
.
mode
(
"train"
)
as
logw
:
train_loss_scalar
=
logw
.
scalar
(
tag
=
"Loss [train]"
)
train_acc_scalar
=
logw
.
scalar
(
tag
=
"Accuracy [train]"
)
with
log_writer
.
mode
(
"evaluate"
)
as
logw
:
eval_loss_scalar
=
logw
.
scalar
(
tag
=
"Loss [eval]"
)
eval_acc_scalar
=
logw
.
scalar
(
tag
=
"Accuracy [eval]"
)
exe
.
run
(
fluid
.
default_startup_program
())
# Finetune loop
for
epoch
in
range
(
current_epoch
,
num_epoch
+
1
):
train_reader
=
data_reader
.
data_generator
(
batch_size
=
batch_size
,
phase
=
'train'
)
num_trained_examples
=
acc_sum
=
loss_sum
=
0
for
batch
in
train_reader
():
num_batch_examples
=
len
(
batch
)
train_time_begin
=
time
.
time
()
loss_v
,
accuracy_v
=
exe
.
run
(
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
[
loss
.
name
,
accuracy
.
name
],
return_numpy
=
False
)
loss_v
=
np
.
array
(
loss_v
)
accuracy_v
=
np
.
array
(
accuracy_v
)
train_time_used
+=
time
.
time
()
-
train_time_begin
global_step
+=
1
num_trained_examples
+=
num_batch_examples
acc_sum
+=
accuracy_v
*
num_batch_examples
loss_sum
+=
loss_v
*
num_batch_examples
# log finetune status
if
global_step
%
config
.
log_interval
==
0
:
avg_loss
=
loss_sum
/
num_trained_examples
avg_acc
=
acc_sum
/
num_trained_examples
speed
=
config
.
log_interval
/
train_time_used
logger
.
info
(
"step %d: loss=%.5f acc=%.5f [step/sec: %.2f]"
%
(
global_step
,
avg_loss
,
avg_acc
,
speed
))
# record visualdl log
train_loss_scalar
.
add_record
(
global_step
,
avg_loss
)
train_acc_scalar
.
add_record
(
global_step
,
avg_acc
)
train_time_used
=
0
num_trained_examples
=
acc_sum
=
loss_sum
=
0
if
config
.
save_ckpt_interval
and
global_step
%
config
.
save_ckpt_interval
==
0
:
# NOTE: the current checkpoint mechanism is incomplete;
# it can't restore the dataset training status
save_checkpoint
(
checkpoint_dir
=
config
.
checkpoint_dir
,
current_epoch
=
epoch
,
global_step
=
global_step
,
exe
=
exe
)
if
do_eval
and
global_step
%
config
.
eval_interval
==
0
:
eval_loss
,
eval_acc
,
eval_perf
=
evaluate_cls_task
(
task
,
data_reader
,
feed_list
,
phase
=
"val"
,
config
=
config
)
eval_loss_scalar
.
add_record
(
global_step
,
eval_loss
)
eval_acc_scalar
.
add_record
(
global_step
,
eval_acc
)
if
eval_acc
>
best_eval_acc
:
best_eval_acc
=
eval_acc
model_saved_dir
=
os
.
path
.
join
(
config
.
checkpoint_dir
,
"best_model"
)
logger
.
info
(
"best model saved to %s [best accuracy=%.5f]"
%
(
model_saved_dir
,
best_eval_acc
))
fluid
.
io
.
save_persistables
(
exe
,
dirname
=
model_saved_dir
)
# NOTE: the current checkpoint mechanism is incomplete; it can't
# restore the dataset training status
save_checkpoint
(
checkpoint_dir
=
config
.
checkpoint_dir
,
current_epoch
=
num_epoch
+
1
,
global_step
=
global_step
,
exe
=
exe
)
# Final evaluation
if
do_eval
:
evaluate_cls_task
(
task
,
data_reader
,
feed_list
,
phase
=
"dev"
,
config
=
config
)
evaluate_cls_task
(
task
,
data_reader
,
feed_list
,
phase
=
"test"
,
config
=
config
)
logger
.
info
(
"PaddleHub finetune finished."
)
def
finetune_and_eval
(
task
,
data_reader
,
feed_list
,
config
=
None
):
if
config
is
None
:
config
=
RunConfig
()
if
not
os
.
path
.
exists
(
config
.
checkpoint_dir
):
mkdir
(
config
.
checkpoint_dir
)
if
task
.
task_type
==
"sequence_labeling"
:
_finetune_seq_label_task
(
task
,
data_reader
,
feed_list
,
config
,
do_eval
=
True
)
elif
task
.
task_type
==
"image_classification"
or
task
.
task_type
==
"text_classification"
:
_finetune_cls_task
(
task
,
data_reader
,
feed_list
,
config
,
do_eval
=
True
)
paddlehub/finetune/strategy.py
浏览文件 @
3b2cceb2
...
...
@@ -75,7 +75,7 @@ class DefaultStrategy(object):
self
.
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
self
.
learning_rate
)
def
execute
(
self
,
loss
):
def
execute
(
self
,
loss
,
data_reader
,
config
):
if
self
.
optimizer
is
not
None
:
self
.
optimizer
.
minimize
(
loss
)
else
:
...
...
@@ -115,7 +115,8 @@ class AdamWeightDecayStrategy(DefaultStrategy):
def
weight_decay
(
self
):
return
self
.
_weight_decay
def
execute
(
self
,
loss
,
main_program
,
data_reader
,
config
):
def
execute
(
self
,
loss
,
data_reader
,
config
):
main_program
=
loss
.
block
.
program
# calculate warmup step
dev_count
=
self
.
_get_dev_count
(
config
)
data_reader
.
data_generator
(
...
...
@@ -159,7 +160,7 @@ class DefaultFinetuneStrategy(DefaultStrategy):
self
.
_optimizer_name
=
optimizer_name
self
.
regularization_coeff
=
regularization_coeff
def
execute
(
self
,
loss
):
def
execute
(
self
,
loss
,
data_reader
,
config
):
# get pretrained parameters
program
=
loss
.
block
.
program
global_block
=
program
.
global_block
()
...
...
@@ -188,7 +189,7 @@ class L2SPFinetuneStrategy(DefaultStrategy):
self
.
_optimizer_name
=
optimizer_name
self
.
regularization_coeff
=
regularization_coeff
def
execute
(
self
,
loss
):
def
execute
(
self
,
loss
,
data_reader
,
config
):
# get pretrained parameters
program
=
loss
.
block
.
program
global_block
=
program
.
global_block
()
...
...
paddlehub/finetune/task.py
浏览文件 @
3b2cceb2
...
...
@@ -19,78 +19,566 @@ from __future__ import print_function
import
os
import
collections
import
contextlib
import
time
import
multiprocessing
import
copy
import
numpy
as
np
import
paddle.fluid
as
fluid
class
Task
(
object
):
"""
A simple transfer learning task definition,
including Paddle's main_program, startup_program and inference program
"""
from
visualdl
import
LogWriter
import
paddlehub
as
hub
from
paddlehub.common.paddle_helper
import
dtype_map
from
paddlehub.common.utils
import
mkdir
from
paddlehub.common.logger
import
logger
from
paddlehub.finetune.checkpoint
import
load_checkpoint
,
save_checkpoint
from
paddlehub.finetune.evaluate
import
chunk_eval
,
calculate_f1
from
paddlehub.finetune.config
import
RunConfig
__all__
=
[
"ClassifierTask"
,
"ImageClassifierTask"
,
"TextClassifierTask"
,
"SequenceLabelTask"
]
class RunState(object):
    """Running totals and timing for a sequence of executed steps.

    Attributes are plain public fields:
      run_time_begin -- ``time.time()`` captured when the state was created
      run_step       -- number of steps accumulated so far
      run_examples   -- number of examples accumulated so far
      run_results    -- list of ``length`` running sums, initialised to 0
      run_time_used  -- seconds elapsed, refreshed by ``update()``
      run_speed      -- steps per second, refreshed by ``update()``
    """

    def __init__(self, length):
        """Create an empty state holding ``length`` result accumulators."""
        self.run_time_begin = time.time()
        self.run_step = 0
        self.run_examples = 0
        self.run_results = [0] * length
        self.run_time_used = 0
        self.run_speed = 0.0

    def __add__(self, other):
        """Fold ``other`` into this state and return ``self``.

        NOTE: unlike a conventional ``__add__`` this mutates the left
        operand in place; no new ``RunState`` is created. ``other`` must
        provide at least ``len(self.run_results)`` results.
        """
        self.run_step += other.run_step
        self.run_examples += other.run_examples
        for index in range(len(self.run_results)):
            self.run_results[index] += other.run_results[index]
        return self

    def update(self):
        """Refresh ``run_time_used`` and ``run_speed``; return ``self``."""
        self.run_time_used = time.time() - self.run_time_begin
        # A coarse or backwards-stepping clock can report a zero or negative
        # elapsed time; report a speed of 0.0 instead of dividing by it.
        if self.run_time_used > 0:
            self.run_speed = self.run_step / self.run_time_used
        else:
            self.run_speed = 0.0
        return self
class RunEnv(object):
    """Mutable attribute bag describing one run environment.

    All fields are stored directly in the instance ``__dict__``; any
    attribute name may be assigned at any time.
    """

    def __init__(self):
        self.current_epoch = 0
        self.current_step = 0
        self.main_program = None
        self.start_program = None
        # The environment-building code in this file assigns and reads
        # ``startup_program`` (not ``start_program``); declare it as well so
        # reading it before the environment is built yields None.
        self.startup_program = None
        self.main_program_compiled = None
        self.py_reader = None
        self.reader = None
        self.loss = None
        self.label = None
        self.metrics = None
        # (sic) spelling is kept because other code reads this attribute name.
        self.is_inititalized = False
        # Private copy of the unique-name generator as it is at construction.
        self.UNG = copy.deepcopy(fluid.unique_name.generator)

    def __setattr__(self, key, value):
        """Store ``value`` under ``key`` in the instance dictionary."""
        self.__dict__[key] = value

    def __getattr__(self, key):
        """Look up ``key`` in the instance dictionary.

        Only invoked when normal attribute lookup fails. Raises
        AttributeError (not KeyError) for unknown names so that ``hasattr``,
        ``getattr`` with a default and ``copy.deepcopy`` behave as the
        attribute protocol requires.
        """
        try:
            return self.__dict__[key]
        except KeyError:
            raise AttributeError("'{}' object has no attribute '{}'".format(
                type(self).__name__, key))
class
BasicTask
(
object
):
def
__init__
(
self
,
task_type
,
graph_var_dict
,
main_program
,
startup_program
,
inference_program
=
None
):
self
.
task_type
=
task_type
self
.
graph_var_dict
=
graph_var_dict
self
.
_main_program
=
main_program
self
.
_startup_program
=
startup_program
self
.
_inference_program
=
inference_program
self
.
_test_program
=
main_program
.
clone
(
for_test
=
True
)
def
variable
(
self
,
var_name
):
if
var_name
in
self
.
graph_var_dict
:
return
self
.
graph_var_dict
[
var_name
]
raise
KeyError
(
"var_name {} not in task graph"
.
format
(
var_name
))
feed_list
,
data_reader
,
main_program
=
None
,
startup_program
=
None
,
config
=
None
):
# base item
self
.
_base_data_reader
=
data_reader
self
.
_base_feed_list
=
feed_list
if
main_program
is
None
:
self
.
_base_main_program
=
fluid
.
default_main_program
().
clone
()
else
:
self
.
_base_main_program
=
main_program
.
clone
()
if
startup_program
is
None
:
self
.
_base_startup_program
=
fluid
.
default_startup_program
().
clone
()
else
:
self
.
_base_startup_program
=
startup_program
.
clone
()
self
.
_load_checkpoint
=
False
self
.
_base_compile_program
=
None
# run config
self
.
config
=
config
if
config
else
RunConfig
()
self
.
place
,
self
.
device_count
=
hub
.
common
.
get_running_device_info
(
self
.
config
)
self
.
exe
=
fluid
.
Executor
(
place
=
self
.
place
)
self
.
build_strategy
=
fluid
.
BuildStrategy
()
if
self
.
config
.
enable_memory_optim
:
self
.
build_strategy
.
memory_optimize
=
True
else
:
self
.
build_strategy
.
memory_optimize
=
False
# log item
if
not
os
.
path
.
exists
(
self
.
config
.
checkpoint_dir
):
mkdir
(
self
.
config
.
checkpoint_dir
)
vdl_log_dir
=
os
.
path
.
join
(
self
.
config
.
checkpoint_dir
,
"vdllog"
)
self
.
log_writer
=
LogWriter
(
vdl_log_dir
,
sync_cycle
=
1
)
# run environment
self
.
_phases
=
[]
self
.
_envs
=
{}
def
init_if_necessary
(
self
):
if
not
self
.
_load_checkpoint
:
self
.
load_checkpoint
()
self
.
_load_checkpoint
=
True
@
contextlib
.
contextmanager
def
phase_guard
(
self
,
phase
):
if
phase
not
in
[
"train"
,
"val"
,
"dev"
,
"test"
,
"predict"
,
"inference"
]:
raise
RuntimeError
()
self
.
_phases
.
append
(
phase
)
yield
self
.
_phases
=
self
.
_phases
[:
-
1
]
def
_build_env
(
self
):
if
self
.
env
.
is_inititalized
:
return
self
.
_build_env_start_event
()
self
.
env
.
is_inititalized
=
True
self
.
env
.
main_program
=
self
.
_base_main_program
.
clone
()
self
.
env
.
startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
self
.
env
.
main_program
,
self
.
_base_startup_program
):
with
fluid
.
unique_name
.
guard
(
self
.
env
.
UNG
):
self
.
env
.
output
=
self
.
_build_net
()
if
self
.
is_train_phase
or
self
.
is_test_phase
:
self
.
env
.
label
=
self
.
_add_label
()
self
.
env
.
loss
=
self
.
_add_loss
()
self
.
env
.
metrics
=
self
.
_add_metrics
()
if
self
.
config
.
use_pyreader
:
t_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
t_program
,
self
.
env
.
startup_program
):
self
.
env
.
py_reader
=
fluid
.
layers
.
py_reader
(
capacity
=
64
,
shapes
=
[
var
.
shape
for
var
in
self
.
feed_var_list
],
dtypes
=
[
dtype_map
[
var
.
dtype
]
for
var
in
self
.
feed_var_list
],
lod_levels
=
[
var
.
lod_level
for
var
in
self
.
feed_var_list
],
use_double_buffer
=
False
)
feed_var_list
=
self
.
feed_var_list
py_vars
=
fluid
.
layers
.
read_file
(
self
.
env
.
py_reader
)
input_dict
=
{
feed_var_list
[
index
].
name
:
py_var
for
index
,
py_var
in
enumerate
(
py_vars
)
}
hub
.
connect_program
(
pre_program
=
t_program
,
next_program
=
self
.
env
.
main_program
,
input_dict
=
input_dict
,
need_log
=
False
)
self
.
env
.
main_program
=
t_program
self
.
env
.
loss
=
self
.
env
.
main_program
.
global_block
().
vars
[
self
.
env
.
loss
.
name
]
self
.
env
.
output
=
self
.
env
.
main_program
.
global_block
().
vars
[
self
.
env
.
output
.
name
]
metrics_name
=
[
var
.
name
for
var
in
self
.
env
.
metrics
]
self
.
env
.
metrics
=
[
self
.
env
.
main_program
.
global_block
().
vars
[
name
]
for
name
in
metrics_name
]
if
self
.
config
.
enable_memory_optim
:
for
var_name
in
self
.
fetch_list
:
var
=
self
.
env
.
main_program
.
global_block
().
vars
[
var_name
]
var
.
persistable
=
True
if
self
.
is_train_phase
:
with
fluid
.
program_guard
(
self
.
env
.
main_program
,
self
.
_base_startup_program
):
with
fluid
.
unique_name
.
guard
(
self
.
env
.
UNG
):
self
.
config
.
strategy
.
execute
(
self
.
loss
,
self
.
_base_data_reader
,
self
.
config
)
if
self
.
is_train_phase
:
loss_name
=
self
.
env
.
loss
.
name
share_vars_from
=
None
else
:
loss_name
=
None
if
self
.
_base_compile_program
is
None
:
share_vars_from
=
None
else
:
share_vars_from
=
self
.
_base_compile_program
self
.
env
.
main_program_compiled
=
fluid
.
CompiledProgram
(
self
.
env
.
main_program
).
with_data_parallel
(
loss_name
=
loss_name
,
share_vars_from
=
share_vars_from
,
build_strategy
=
self
.
build_strategy
)
if
self
.
_base_compile_program
is
None
:
self
.
_base_compile_program
=
self
.
env
.
main_program_compiled
self
.
exe
.
run
(
self
.
env
.
startup_program
)
self
.
_build_env_end_event
()
@
property
def
is_train_phase
(
self
):
return
self
.
phase
in
[
"train"
]
@
property
def
is_test_phase
(
self
):
return
self
.
phase
in
[
"val"
,
"dev"
,
"test"
]
@
property
def
is_predict_phase
(
self
):
return
self
.
phase
in
[
"predict"
,
"inference"
]
@
property
def
phase
(
self
):
return
self
.
_phases
[
-
1
]
@
property
def
env
(
self
):
phase
=
self
.
phase
if
phase
in
[
"val"
,
"dev"
,
"test"
]:
phase
=
"val"
if
not
phase
in
self
.
_envs
:
self
.
_envs
[
phase
]
=
RunEnv
()
return
self
.
_envs
[
phase
]
@
property
def
py_reader
(
self
):
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
py_reader
@
property
def
current_step
(
self
):
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
current_step
@
property
def
current_epoch
(
self
):
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
current_epoch
@
property
def
main_program
(
self
):
return
self
.
_main_program
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
main_program
@
property
def
startup_program
(
self
):
return
self
.
_startup_program
def
inference_program
(
self
):
return
self
.
_inference_program
def
test_program
(
self
):
return
self
.
_test_program
def
metric_variable_names
(
self
):
metric_variable_names
=
[]
for
var_name
in
self
.
graph_var_dict
:
metric_variable_names
.
append
(
var_name
)
return
metric_variable_names
def
create_text_cls_task
(
feature
,
num_classes
,
hidden_units
=
None
):
"""
Append a multi-layer perceptron classifier for binary classification base
on input feature
"""
program
=
feature
.
block
.
program
with
fluid
.
program_guard
(
program
):
cls_feats
=
fluid
.
layers
.
dropout
(
x
=
feature
,
dropout_prob
=
0.1
,
dropout_implementation
=
"upscale_in_train"
)
# append fully connected layer according to hidden_units
if
hidden_units
is
not
None
:
for
n_hidden
in
hidden_units
:
cls_feats
=
fluid
.
layers
.
fc
(
input
=
cls_feats
,
size
=
n_hidden
)
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
startup_program
@
property
def
main_program_compiled
(
self
):
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
main_program_compiled
@
property
def
reader
(
self
):
self
.
env
.
reader
=
self
.
_base_data_reader
.
data_generator
(
batch_size
=
self
.
config
.
batch_size
,
phase
=
self
.
phase
)
return
self
.
env
.
reader
@
property
def
loss
(
self
):
if
self
.
is_predict_phase
:
raise
RuntimeError
()
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
loss
@
property
def
label
(
self
):
if
self
.
is_predict_phase
:
raise
RuntimeError
()
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
label
@
property
def
output
(
self
):
if
self
.
is_predict_phase
:
raise
RuntimeError
()
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
output
@
property
def
metrics
(
self
):
if
self
.
is_predict_phase
:
raise
RuntimeError
()
if
not
self
.
env
.
is_inititalized
:
self
.
_build_env
()
return
self
.
env
.
metrics
@
property
def
unique_name_generator
(
self
):
return
self
.
env
.
UNG
@
property
def
feed_list
(
self
):
feed_list
=
[
varname
for
varname
in
self
.
_base_feed_list
]
if
self
.
is_train_phase
or
self
.
is_test_phase
:
feed_list
+=
[
self
.
label
.
name
]
return
feed_list
@
property
def
feed_var_list
(
self
):
vars
=
self
.
main_program
.
global_block
().
vars
return
[
vars
[
varname
]
for
varname
in
self
.
feed_list
]
@
property
def
fetch_list
(
self
):
if
self
.
is_train_phase
or
self
.
is_test_phase
:
return
[
metric
.
name
for
metric
in
self
.
metrics
]
+
[
self
.
loss
.
name
]
return
[
self
.
output
.
name
]
def
_build_env_start_event
(
self
):
pass
def
_build_env_end_event
(
self
):
pass
def
_eval_start_event
(
self
):
logger
.
info
(
"Evaluation on {} dataset start"
.
format
(
self
.
phase
))
def
_eval_end_event
(
self
,
run_state
):
logger
.
info
(
"[%s dataset evaluation result] [step/sec: %.2f]"
%
(
self
.
phase
,
run_state
.
run_speed
))
def
_log_interval_event
(
self
,
run_state
):
logger
.
info
(
"step %d: [step/sec: %.2f]"
%
(
self
.
current_step
,
run_state
.
run_speed
))
def
_save_ckpt_interval_event
(
self
):
self
.
save_checkpoint
(
self
.
current_epoch
,
self
.
current_step
)
def
_eval_interval_event
(
self
):
self
.
eval
(
phase
=
"dev"
)
def
_run_step_event
(
self
,
run_state
):
if
self
.
is_predict_phase
:
yield
run_state
.
run_results
def
_finetune_start_event
(
self
):
logger
.
info
(
"PaddleHub finetune start"
)
def
_finetune_end_event
(
self
,
run_state
):
logger
.
info
(
"PaddleHub finetune finished."
)
def
_build_net
(
self
):
raise
NotImplementedError
def
_add_loss
(
self
):
raise
NotImplementedError
def
_add_label
(
self
):
raise
NotImplementedError
def
_add_metrics
(
self
):
raise
NotImplementedError
# NOTE: current saved checkpoint machanism is not completed,
# it can't restore dataset training status
def
save_checkpoint
(
self
,
epoch
,
step
):
save_checkpoint
(
checkpoint_dir
=
self
.
config
.
checkpoint_dir
,
current_epoch
=
self
.
current_epoch
,
global_step
=
self
.
current_step
,
exe
=
self
.
exe
,
main_program
=
self
.
main_program
)
def
load_checkpoint
(
self
,
load_best_model
=
False
):
self
.
env
.
current_epoch
,
self
.
env
.
current_step
=
load_checkpoint
(
self
.
config
.
checkpoint_dir
,
self
.
exe
,
main_program
=
self
.
main_program
,
startup_program
=
self
.
_base_startup_program
)
if
load_best_model
:
model_saved_dir
=
os
.
path
.
join
(
self
.
config
.
checkpoint_dir
,
"best_model"
)
if
os
.
path
.
exists
(
model_saved_dir
):
fluid
.
io
.
load_persistables
(
executor
=
self
.
exe
,
dirname
=
model_saved_dir
,
main_program
=
self
.
main_program
)
def
finetune_and_eval
(
self
):
self
.
finetune
(
do_eval
=
True
)
def
finetune
(
self
,
do_eval
=
False
):
# Start to finetune
with
self
.
phase_guard
(
phase
=
"train"
):
self
.
init_if_necessary
()
self
.
_finetune_start_event
()
run_states
=
[]
if
self
.
current_epoch
<=
self
.
config
.
num_epoch
:
while
self
.
current_epoch
<=
self
.
config
.
num_epoch
:
run_states
=
self
.
_run
(
do_eval
=
do_eval
)
self
.
env
.
current_epoch
+=
1
# Save checkpoint after finetune
self
.
save_checkpoint
(
self
.
current_epoch
+
1
,
self
.
current_step
)
# Final evaluation
self
.
eval
(
phase
=
"dev"
)
self
.
eval
(
phase
=
"test"
)
self
.
_finetune_end_event
(
run_states
)
def
eval
(
self
,
phase
=
"dev"
):
with
self
.
phase_guard
(
phase
=
phase
):
self
.
init_if_necessary
()
self
.
_eval_start_event
()
run_states
=
self
.
_run
()
self
.
_eval_end_event
(
run_states
)
def
predict
(
self
,
data
,
load_best_model
=
True
):
with
self
.
phase_guard
(
phase
=
phase
):
self
.
init_if_necessary
()
for
run_state
in
self
.
_run
():
yield
run_state
.
run_results
def
_run
(
self
,
do_eval
=
False
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
if
self
.
config
.
use_pyreader
:
return
self
.
_run_with_py_reader
(
do_eval
=
do_eval
)
return
self
.
_run_with_data_feeder
(
do_eval
=
do_eval
)
def
_run_with_data_feeder
(
self
,
do_eval
=
False
):
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
self
.
feed_list
,
place
=
self
.
place
)
global_run_states
=
[]
period_run_states
=
[]
for
run_step
,
batch
in
enumerate
(
self
.
reader
(),
start
=
1
):
step_run_state
=
RunState
(
len
(
self
.
fetch_list
))
step_run_state
.
run_step
=
1
num_batch_examples
=
len
(
batch
)
fetch_result
=
self
.
exe
.
run
(
self
.
main_program_compiled
,
feed
=
data_feeder
.
feed
(
batch
),
fetch_list
=
self
.
fetch_list
)
for
index
,
result
in
enumerate
(
fetch_result
):
step_run_state
.
run_results
[
index
]
=
result
step_run_state
.
run_examples
+=
num_batch_examples
step_run_state
.
update
()
period_run_states
+=
[
step_run_state
]
if
self
.
is_train_phase
:
self
.
env
.
current_step
+=
1
if
self
.
current_step
%
self
.
config
.
log_interval
==
0
:
self
.
_log_interval_event
(
period_run_states
)
global_run_states
+=
period_run_states
period_run_states
=
[]
if
self
.
config
.
save_ckpt_interval
and
self
.
current_step
%
self
.
config
.
save_ckpt_interval
==
0
:
self
.
_save_ckpt_interval_event
()
if
do_eval
and
self
.
current_step
%
self
.
config
.
eval_interval
==
0
:
self
.
_eval_interval_event
()
self
.
_run_step_event
(
step_run_state
)
global_run_states
+=
period_run_states
return
global_run_states
def
_run_with_py_reader
(
self
,
do_eval
=
False
):
global_run_states
=
[]
period_run_states
=
[]
self
.
py_reader
.
decorate_paddle_reader
(
self
.
reader
)
self
.
py_reader
.
start
()
try
:
while
True
:
num_batch_examples
=
self
.
config
.
batch_size
step_run_state
=
RunState
(
len
(
self
.
fetch_list
))
step_run_state
.
run_step
=
1
fetch_result
=
self
.
exe
.
run
(
self
.
main_program_compiled
,
fetch_list
=
self
.
fetch_list
)
for
index
,
result
in
enumerate
(
fetch_result
):
step_run_state
.
run_results
[
index
]
=
result
step_run_state
.
run_examples
+=
num_batch_examples
step_run_state
.
update
()
period_run_states
+=
[
step_run_state
]
if
self
.
is_train_phase
:
self
.
env
.
current_step
+=
1
if
self
.
current_step
%
self
.
config
.
log_interval
==
0
:
self
.
_log_interval_event
(
period_run_states
)
global_run_states
+=
period_run_states
period_run_states
=
[]
if
self
.
config
.
save_ckpt_interval
and
self
.
current_step
%
self
.
config
.
save_ckpt_interval
==
0
:
self
.
_save_ckpt_interval_event
()
if
do_eval
and
self
.
current_step
%
self
.
config
.
eval_interval
==
0
:
self
.
_eval_interval_event
()
self
.
_run_step_event
(
step_run_state
)
except
fluid
.
core
.
EOFException
:
self
.
py_reader
.
reset
()
global_run_states
+=
period_run_states
return
global_run_states
class
ClassifierTask
(
BasicTask
):
def
__init__
(
self
,
data_reader
,
feature
,
num_classes
,
feed_list
,
startup_program
=
None
,
config
=
None
,
hidden_units
=
None
):
main_program
=
feature
.
block
.
program
super
(
ClassifierTask
,
self
).
__init__
(
data_reader
=
data_reader
,
main_program
=
main_program
,
feed_list
=
feed_list
,
startup_program
=
startup_program
,
config
=
config
)
self
.
feature
=
feature
self
.
num_classes
=
num_classes
self
.
hidden_units
=
hidden_units
self
.
best_accuracy
=
-
1
def
_build_net
(
self
):
cls_feats
=
self
.
feature
if
self
.
hidden_units
is
not
None
:
for
n_hidden
in
self
.
hidden_units
:
cls_feats
=
fluid
.
layers
.
fc
(
input
=
cls_feats
,
size
=
n_hidden
,
act
=
"relu"
)
logits
=
fluid
.
layers
.
fc
(
input
=
cls_feats
,
size
=
num_classes
,
size
=
self
.
num_classes
,
param_attr
=
fluid
.
ParamAttr
(
name
=
"cls_out_w"
,
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
scale
=
0.02
)),
...
...
@@ -98,56 +586,108 @@ def create_text_cls_task(feature, num_classes, hidden_units=None):
name
=
"cls_out_b"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.
)),
act
=
"softmax"
)
inference_program
=
fluid
.
default_main_program
().
clone
(
for_test
=
True
)
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
dtype
=
"int64"
,
shape
=
[
1
])
ce_loss
=
fluid
.
layers
.
cross_entropy
(
input
=
logits
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
x
=
ce_loss
)
return
logits
def
_add_label
(
self
):
return
fluid
.
layers
.
data
(
name
=
"label"
,
dtype
=
"int64"
,
shape
=
[
1
])
def
_add_loss
(
self
):
ce_loss
=
fluid
.
layers
.
cross_entropy
(
input
=
self
.
output
,
label
=
self
.
label
)
return
fluid
.
layers
.
mean
(
x
=
ce_loss
)
def
_add_metrics
(
self
):
return
[
fluid
.
layers
.
accuracy
(
input
=
self
.
output
,
label
=
self
.
label
)]
def
_build_env_end_event
(
self
):
with
self
.
log_writer
.
mode
(
self
.
phase
)
as
logw
:
self
.
env
.
loss_scalar
=
logw
.
scalar
(
tag
=
"Loss [{}]"
.
format
(
self
.
phase
))
self
.
env
.
acc_scalar
=
logw
.
scalar
(
tag
=
"Accuracy [{}]"
.
format
(
self
.
phase
))
def
_calculate_metrics
(
self
,
run_states
):
loss_sum
=
acc_sum
=
run_examples
=
0
run_step
=
run_time_used
=
0
for
run_state
in
run_states
:
run_examples
+=
run_state
.
run_examples
run_step
+=
run_state
.
run_step
loss_sum
+=
np
.
mean
(
run_state
.
run_results
[
-
1
])
*
run_state
.
run_examples
acc_sum
+=
np
.
mean
(
run_state
.
run_results
[
0
])
*
run_state
.
run_examples
run_time_used
=
time
.
time
()
-
run_states
[
0
].
run_time_begin
avg_loss
=
loss_sum
/
run_examples
avg_acc
=
acc_sum
/
run_examples
run_speed
=
run_step
/
run_time_used
return
avg_loss
,
avg_acc
,
run_speed
def
_log_interval_event
(
self
,
run_states
):
avg_loss
,
avg_acc
,
run_speed
=
self
.
_calculate_metrics
(
run_states
)
self
.
env
.
loss_scalar
.
add_record
(
self
.
current_step
,
avg_loss
)
self
.
env
.
acc_scalar
.
add_record
(
self
.
current_step
,
avg_acc
)
logger
.
info
(
"step %d: loss=%.5f acc=%.5f [step/sec: %.2f]"
%
(
self
.
current_step
,
avg_loss
,
avg_acc
,
run_speed
))
def
_eval_end_event
(
self
,
run_states
):
eval_loss
,
eval_acc
,
run_speed
=
self
.
_calculate_metrics
(
run_states
)
logger
.
info
(
"[%s dataset evaluation result] loss=%.5f acc=%.5f [step/sec: %.2f]"
%
(
self
.
phase
,
eval_loss
,
eval_acc
,
run_speed
))
if
self
.
phase
in
[
"dev"
,
"val"
]
and
eval_acc
>
self
.
best_accuracy
:
self
.
env
.
loss_scalar
.
add_record
(
self
.
current_step
,
eval_loss
)
self
.
env
.
acc_scalar
.
add_record
(
self
.
current_step
,
eval_acc
)
self
.
best_accuracy
=
eval_acc
model_saved_dir
=
os
.
path
.
join
(
self
.
config
.
checkpoint_dir
,
"best_model"
)
logger
.
info
(
"best model saved to %s [best accuracy=%.5f]"
%
(
model_saved_dir
,
self
.
best_accuracy
))
save_result
=
fluid
.
io
.
save_persistables
(
executor
=
self
.
exe
,
dirname
=
model_saved_dir
,
main_program
=
self
.
main_program
)
ImageClassifierTask
=
ClassifierTask
class
TextClassifierTask
(
ClassifierTask
):
def
__init__
(
self
,
data_reader
,
feature
,
num_classes
,
feed_list
,
startup_program
=
None
,
config
=
None
,
hidden_units
=
None
):
main_program
=
feature
.
block
.
program
super
(
TextClassifierTask
,
self
).
__init__
(
data_reader
=
data_reader
,
feature
=
feature
,
num_classes
=
num_classes
,
feed_list
=
feed_list
,
startup_program
=
startup_program
,
config
=
config
,
hidden_units
=
hidden_units
)
def
_build_net
(
self
):
cls_feats
=
fluid
.
layers
.
dropout
(
x
=
self
.
feature
,
dropout_prob
=
0.1
,
dropout_implementation
=
"upscale_in_train"
)
if
self
.
hidden_units
is
not
None
:
for
n_hidden
in
self
.
hidden_units
:
cls_feats
=
fluid
.
layers
.
fc
(
input
=
cls_feats
,
size
=
n_hidden
,
act
=
"relu"
)
num_example
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
accuracy
=
fluid
.
layers
.
accuracy
(
input
=
logits
,
label
=
label
,
total
=
num_example
)
graph_var_dict
=
{
"loss"
:
loss
,
"accuracy"
:
accuracy
,
"num_example"
:
num_example
,
"label"
:
label
,
"probs"
:
logits
}
task
=
Task
(
"text_classification"
,
graph_var_dict
,
fluid
.
default_main_program
(),
fluid
.
default_startup_program
(),
inference_program
=
inference_program
)
return
task
def
create_img_cls_task
(
feature
,
num_classes
,
hidden_units
=
None
):
"""
Create the transfer learning task for image classification.
Args:
feature:
Return:
Task
Raise:
None
"""
program
=
feature
.
block
.
program
with
fluid
.
program_guard
(
program
):
cls_feats
=
feature
# append fully connected layer according to hidden_units
if
hidden_units
is
not
None
:
for
n_hidden
in
hidden_units
:
cls_feats
=
fluid
.
layers
.
fc
(
input
=
cls_feats
,
size
=
n_hidden
)
probs
=
fluid
.
layers
.
fc
(
logits
=
fluid
.
layers
.
fc
(
input
=
cls_feats
,
size
=
num_classes
,
size
=
self
.
num_classes
,
param_attr
=
fluid
.
ParamAttr
(
name
=
"cls_out_w"
,
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
scale
=
0.02
)),
...
...
@@ -155,41 +695,39 @@ def create_img_cls_task(feature, num_classes, hidden_units=None):
name
=
"cls_out_b"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.
)),
act
=
"softmax"
)
inference_program
=
fluid
.
default_main_program
().
clone
(
for_test
=
True
)
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
dtype
=
"int64"
,
shape
=
[
1
])
ce_loss
=
fluid
.
layers
.
cross_entropy
(
input
=
probs
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
x
=
ce_loss
)
num_example
=
fluid
.
layers
.
create_tensor
(
dtype
=
'int64'
)
accuracy
=
fluid
.
layers
.
accuracy
(
input
=
probs
,
label
=
label
,
total
=
num_example
)
graph_var_dict
=
{
"loss"
:
loss
,
"probs"
:
probs
,
"accuracy"
:
accuracy
,
"num_example"
:
num_example
,
"label"
:
label
,
"probs"
:
probs
}
task
=
Task
(
"image_classification"
,
graph_var_dict
,
fluid
.
default_main_program
(),
fluid
.
default_startup_program
(),
inference_program
=
inference_program
)
return
task
def
create_seq_label_task
(
feature
,
max_seq_len
,
num_classes
):
program
=
feature
.
block
.
program
with
fluid
.
program_guard
(
program
):
logits
=
fluid
.
layers
.
fc
(
input
=
feature
,
size
=
num_classes
,
return
logits
class
SequenceLabelTask
(
BasicTask
):
def
__init__
(
self
,
feature
,
max_seq_len
,
num_classes
,
data_reader
,
feed_list
,
startup_program
=
None
,
config
=
None
,
):
main_program
=
feature
.
block
.
program
super
(
SequenceLabelTask
,
self
).
__init__
(
data_reader
=
data_reader
,
main_program
=
main_program
,
feed_list
=
feed_list
,
startup_program
=
startup_program
,
config
=
config
)
self
.
feature
=
feature
self
.
max_seq_len
=
max_seq_len
self
.
num_classes
=
num_classes
self
.
best_f1
=
-
1
def
_build_net
(
self
):
self
.
logits
=
fluid
.
layers
.
fc
(
input
=
self
.
feature
,
size
=
self
.
num_classes
,
num_flatten_dims
=
2
,
param_attr
=
fluid
.
ParamAttr
(
name
=
"cls_seq_label_out_w"
,
...
...
@@ -198,37 +736,95 @@ def create_seq_label_task(feature, max_seq_len, num_classes):
name
=
"cls_seq_label_out_b"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.
)))
ret_infers
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
argmax
(
logits
,
axis
=
2
),
shape
=
[
-
1
,
1
])
logits
=
self
.
logits
logits
=
fluid
.
layers
.
flatten
(
logits
,
axis
=
2
)
logits
=
fluid
.
layers
.
softmax
(
logits
)
self
.
num_labels
=
logits
.
shape
[
1
]
return
logits
inference_program
=
fluid
.
default_main_program
().
clone
(
for_test
=
True
)
seq_len
=
fluid
.
layers
.
data
(
name
=
"seq_len"
,
shape
=
[
1
],
dtype
=
'int64'
)
def
_add_label
(
self
):
label
=
fluid
.
layers
.
data
(
name
=
"label"
,
shape
=
[
max_seq_len
,
1
],
dtype
=
'int64'
)
ret
_labels
=
fluid
.
layers
.
reshape
(
x
=
label
,
shape
=
[
-
1
,
1
])
name
=
"label"
,
shape
=
[
self
.
max_seq_len
,
1
],
dtype
=
'int64'
)
ret
urn
label
labels
=
fluid
.
layers
.
flatten
(
label
,
axis
=
2
)
ce_loss
=
fluid
.
layers
.
cross_entropy
(
input
=
logits
,
label
=
labels
)
def
_add_loss
(
self
):
labels
=
fluid
.
layers
.
flatten
(
self
.
label
,
axis
=
2
)
ce_loss
=
fluid
.
layers
.
cross_entropy
(
input
=
self
.
output
,
label
=
labels
)
loss
=
fluid
.
layers
.
mean
(
x
=
ce_loss
)
return
loss
graph_var_dict
=
{
"loss"
:
loss
,
"probs"
:
logits
,
"labels"
:
ret_labels
,
"infers"
:
ret_infers
,
"seq_len"
:
seq_len
,
"label"
:
label
}
task
=
Task
(
"sequence_labeling"
,
graph_var_dict
,
fluid
.
default_main_program
(),
fluid
.
default_startup_program
(),
inference_program
=
inference_program
)
return
task
def
_add_metrics
(
self
):
ret_labels
=
fluid
.
layers
.
reshape
(
x
=
self
.
label
,
shape
=
[
-
1
,
1
])
ret_infers
=
fluid
.
layers
.
reshape
(
x
=
fluid
.
layers
.
argmax
(
self
.
logits
,
axis
=
2
),
shape
=
[
-
1
,
1
])
self
.
seq_len
=
fluid
.
layers
.
data
(
name
=
"seq_len"
,
shape
=
[
1
],
dtype
=
'int64'
)
seq_len
=
fluid
.
layers
.
assign
(
self
.
seq_len
)
return
[
ret_labels
,
ret_infers
,
seq_len
]
def
_build_env_end_event
(
self
):
with
self
.
log_writer
.
mode
(
self
.
phase
)
as
logw
:
self
.
env
.
loss_scalar
=
logw
.
scalar
(
tag
=
"Loss [{}]"
.
format
(
self
.
phase
))
self
.
env
.
f1_scalar
=
logw
.
scalar
(
tag
=
"F1 [{}]"
.
format
(
self
.
phase
))
self
.
env
.
precision_scalar
=
logw
.
scalar
(
tag
=
"Precision [{}]"
.
format
(
self
.
phase
))
self
.
env
.
recall_scalar
=
logw
.
scalar
(
tag
=
"Recall [{}]"
.
format
(
self
.
phase
))
def
_calculate_metrics
(
self
,
run_states
):
total_infer
=
total_label
=
total_correct
=
loss_sum
=
0
run_step
=
run_time_used
=
run_examples
=
0
for
run_state
in
run_states
:
loss_sum
+=
np
.
mean
(
run_state
.
run_results
[
-
1
])
np_labels
=
run_state
.
run_results
[
0
]
np_infers
=
run_state
.
run_results
[
1
]
np_lens
=
run_state
.
run_results
[
2
]
label_num
,
infer_num
,
correct_num
=
chunk_eval
(
np_labels
,
np_infers
,
np_lens
,
self
.
num_labels
,
self
.
device_count
)
total_infer
+=
infer_num
total_label
+=
label_num
total_correct
+=
correct_num
run_examples
+=
run_state
.
run_examples
run_step
+=
run_state
.
run_step
run_time_used
=
time
.
time
()
-
run_states
[
0
].
run_time_begin
run_speed
=
run_step
/
run_time_used
avg_loss
=
loss_sum
/
run_examples
precision
,
recall
,
f1
=
calculate_f1
(
total_label
,
total_infer
,
total_correct
)
return
precision
,
recall
,
f1
,
avg_loss
,
run_speed
def
_log_interval_event
(
self
,
run_states
):
precision
,
recall
,
f1
,
avg_loss
,
run_speed
=
self
.
_calculate_metrics
(
run_states
)
self
.
env
.
loss_scalar
.
add_record
(
self
.
current_step
,
avg_loss
)
logger
.
info
(
"step %d: loss=%.5f [step/sec: %.2f]"
%
(
self
.
current_step
,
avg_loss
,
run_speed
))
def
_eval_end_event
(
self
,
run_states
):
precision
,
recall
,
f1
,
avg_loss
,
run_speed
=
self
.
_calculate_metrics
(
run_states
)
self
.
env
.
f1_scalar
.
add_record
(
self
.
current_step
,
f1
)
self
.
env
.
precision_scalar
.
add_record
(
self
.
current_step
,
precision
)
self
.
env
.
recall_scalar
.
add_record
(
self
.
current_step
,
recall
)
logger
.
info
(
"[%s dataset evaluation result] [step/sec: %.2f]"
%
(
self
.
phase
,
run_speed
))
logger
.
info
(
"[%s evaluation] F1-Score=%f, precision=%f, recall=%f [step/sec: %.2f]"
%
(
self
.
phase
,
f1
,
precision
,
recall
,
run_speed
))
if
self
.
phase
in
[
"dev"
,
"val"
]
and
f1
>
self
.
best_f1
:
self
.
best_f1
=
f1
model_saved_dir
=
os
.
path
.
join
(
self
.
config
.
checkpoint_dir
,
"best_model"
)
logger
.
info
(
"best model saved to %s [best F1=%.5f]"
%
(
model_saved_dir
,
self
.
best_f1
))
fluid
.
io
.
save_persistables
(
self
.
exe
,
dirname
=
model_saved_dir
)
@
property
def
feed_list
(
self
):
feed_list
=
[
varname
for
varname
in
self
.
_base_feed_list
]
if
self
.
is_train_phase
or
self
.
is_test_phase
:
feed_list
+=
[
self
.
label
.
name
,
self
.
seq_len
.
name
]
return
feed_list
paddlehub/reader/nlp_reader.py
浏览文件 @
3b2cceb2
...
...
@@ -19,16 +19,17 @@ from __future__ import print_function
import
csv
import
json
import
numpy
as
np
import
platform
import
six
import
sys
from
collections
import
namedtuple
import
paddle
import
numpy
as
np
from
paddlehub.reader
import
tokenization
from
paddlehub.common.logger
import
logger
from
paddlehub.dataset.dataset
import
InputExample
from
.batching
import
pad_batch_data
import
paddlehub
as
hub
...
...
@@ -100,7 +101,11 @@ class BaseReader(object):
else
:
tokens_b
.
pop
()
def
_convert_example_to_record
(
self
,
example
,
max_seq_length
,
tokenizer
):
def
_convert_example_to_record
(
self
,
example
,
max_seq_length
,
tokenizer
,
phase
=
None
):
"""Converts a single `Example` into a single `Record`."""
text_a
=
tokenization
.
convert_to_unicode
(
example
.
text_a
)
...
...
@@ -171,11 +176,24 @@ class BaseReader(object):
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
,
'label_id'
])
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
,
label_id
=
label_id
)
if
phase
!=
"predict"
:
Record
=
namedtuple
(
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
,
'label_id'
])
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
,
label_id
=
label_id
)
else
:
Record
=
namedtuple
(
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
])
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
)
return
record
def
_prepare_batch_data
(
self
,
examples
,
batch_size
,
phase
=
None
):
...
...
@@ -185,7 +203,7 @@ class BaseReader(object):
if
phase
==
"train"
:
self
.
current_example
=
index
record
=
self
.
_convert_example_to_record
(
example
,
self
.
max_seq_len
,
self
.
tokenizer
)
self
.
tokenizer
,
phase
)
max_len
=
max
(
max_len
,
len
(
record
.
token_ids
))
if
self
.
in_tokens
:
to_append
=
(
len
(
batch_records
)
+
1
)
*
max_len
<=
batch_size
...
...
@@ -194,11 +212,11 @@ class BaseReader(object):
if
to_append
:
batch_records
.
append
(
record
)
else
:
yield
self
.
_pad_batch_records
(
batch_records
)
yield
self
.
_pad_batch_records
(
batch_records
,
phase
)
batch_records
,
max_len
=
[
record
],
len
(
record
.
token_ids
)
if
batch_records
:
yield
self
.
_pad_batch_records
(
batch_records
)
yield
self
.
_pad_batch_records
(
batch_records
,
phase
)
def
get_num_examples
(
self
,
phase
):
"""Get number of examples for train, dev or test."""
...
...
@@ -208,20 +226,51 @@ class BaseReader(object):
)
return
self
.
num_examples
[
phase
]
def
data_generator
(
self
,
batch_size
=
1
,
phase
=
'train'
,
shuffle
=
True
):
def
data_generator
(
self
,
batch_size
=
1
,
phase
=
'train'
,
shuffle
=
True
,
data
=
None
):
if
phase
==
'train'
:
shuffle
=
True
examples
=
self
.
get_train_examples
()
self
.
num_examples
[
'train'
]
=
len
(
examples
)
elif
phase
==
'val'
or
phase
==
'dev'
:
shuffle
=
False
examples
=
self
.
get_dev_examples
()
self
.
num_examples
[
'dev'
]
=
len
(
examples
)
elif
phase
==
'test'
:
shuffle
=
False
examples
=
self
.
get_test_examples
()
self
.
num_examples
[
'test'
]
=
len
(
examples
)
elif
phase
==
'predict'
:
shuffle
=
False
examples
=
[]
seq_id
=
0
for
item
in
data
:
# set label in order to run the program
label
=
"0"
if
len
(
item
)
==
1
:
item_i
=
InputExample
(
guid
=
seq_id
,
text_a
=
item
[
0
],
label
=
label
)
elif
len
(
item
)
==
2
:
item_i
=
InputExample
(
guid
=
seq_id
,
text_a
=
item
[
0
],
text_b
=
item
[
1
],
label
=
label
)
else
:
raise
ValueError
(
"The length of input_text is out of handling, which must be 1 or 2!"
)
examples
.
append
(
item_i
)
seq_id
+=
1
else
:
raise
ValueError
(
"Unknown phase, which should be in ['train', 'dev', 'test']."
)
"Unknown phase, which should be in ['train', 'dev', 'test', 'predict']."
)
def
wrapper
():
if
shuffle
:
...
...
@@ -235,20 +284,11 @@ class BaseReader(object):
class
ClassifyReader
(
BaseReader
):
def
_pad_batch_records
(
self
,
batch_records
):
def
_pad_batch_records
(
self
,
batch_records
,
phase
=
None
):
batch_token_ids
=
[
record
.
token_ids
for
record
in
batch_records
]
batch_text_type_ids
=
[
record
.
text_type_ids
for
record
in
batch_records
]
batch_position_ids
=
[
record
.
position_ids
for
record
in
batch_records
]
batch_labels
=
[
record
.
label_id
for
record
in
batch_records
]
batch_labels
=
np
.
array
(
batch_labels
).
astype
(
"int64"
).
reshape
([
-
1
,
1
])
# if batch_records[0].qid:
# batch_qids = [record.qid for record in batch_records]
# batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1])
# else:
# batch_qids = np.array([]).astype("int64").reshape([-1, 1])
# padding
padded_token_ids
,
input_mask
=
pad_batch_data
(
batch_token_ids
,
max_seq_len
=
self
.
max_seq_len
,
...
...
@@ -263,20 +303,29 @@ class ClassifyReader(BaseReader):
max_seq_len
=
self
.
max_seq_len
,
pad_idx
=
self
.
pad_id
)
return_list
=
[
padded_token_ids
,
padded_position_ids
,
padded_text_type_ids
,
input_mask
,
batch_labels
]
if
phase
!=
"predict"
:
batch_labels
=
[
record
.
label_id
for
record
in
batch_records
]
batch_labels
=
np
.
array
(
batch_labels
).
astype
(
"int64"
).
reshape
(
[
-
1
,
1
])
return_list
=
[
padded_token_ids
,
padded_position_ids
,
padded_text_type_ids
,
input_mask
,
batch_labels
]
else
:
return_list
=
[
padded_token_ids
,
padded_position_ids
,
padded_text_type_ids
,
input_mask
]
return
return_list
class
SequenceLabelReader
(
BaseReader
):
def
_pad_batch_records
(
self
,
batch_records
):
def
_pad_batch_records
(
self
,
batch_records
,
phase
=
None
):
batch_token_ids
=
[
record
.
token_ids
for
record
in
batch_records
]
batch_text_type_ids
=
[
record
.
text_type_ids
for
record
in
batch_records
]
batch_position_ids
=
[
record
.
position_ids
for
record
in
batch_records
]
batch_label_ids
=
[
record
.
label_ids
for
record
in
batch_records
]
# padding
padded_token_ids
,
input_mask
,
batch_seq_lens
=
pad_batch_data
(
...
...
@@ -293,65 +342,115 @@ class SequenceLabelReader(BaseReader):
batch_position_ids
,
max_seq_len
=
self
.
max_seq_len
,
pad_idx
=
self
.
pad_id
)
padded_label_ids
=
pad_batch_data
(
batch_label_ids
,
max_seq_len
=
self
.
max_seq_len
,
pad_idx
=
len
(
self
.
label_map
)
-
1
)
return_list
=
[
padded_token_ids
,
padded_position_ids
,
padded_text_type_ids
,
input_mask
,
padded_label_ids
,
batch_seq_lens
]
if
phase
!=
"predict"
:
batch_label_ids
=
[
record
.
label_ids
for
record
in
batch_records
]
padded_label_ids
=
pad_batch_data
(
batch_label_ids
,
max_seq_len
=
self
.
max_seq_len
,
pad_idx
=
len
(
self
.
label_map
)
-
1
)
return_list
=
[
padded_token_ids
,
padded_position_ids
,
padded_text_type_ids
,
input_mask
,
padded_label_ids
,
batch_seq_lens
]
else
:
return_list
=
[
padded_token_ids
,
padded_position_ids
,
padded_text_type_ids
,
input_mask
,
batch_seq_lens
]
return
return_list
def
_reseg_token_label
(
self
,
tokens
,
labels
,
tokenizer
):
if
len
(
tokens
)
!=
len
(
labels
):
raise
ValueError
(
"The length of tokens must be same with labels"
)
ret_tokens
=
[]
ret_labels
=
[]
for
token
,
label
in
zip
(
tokens
,
labels
):
sub_token
=
tokenizer
.
tokenize
(
token
)
if
len
(
sub_token
)
==
0
:
continue
ret_tokens
.
extend
(
sub_token
)
ret_labels
.
append
(
label
)
if
len
(
sub_token
)
<
2
:
continue
sub_label
=
label
if
label
.
startswith
(
"B-"
):
sub_label
=
"I-"
+
label
[
2
:]
ret_labels
.
extend
([
sub_label
]
*
(
len
(
sub_token
)
-
1
))
if
len
(
ret_tokens
)
!=
len
(
labels
):
raise
ValueError
(
"The length of ret_tokens can't match with labels"
)
return
ret_tokens
,
ret_labels
def
_convert_example_to_record
(
self
,
example
,
max_seq_length
,
tokenizer
):
tokens
=
tokenization
.
convert_to_unicode
(
example
.
text_a
).
split
(
u
""
)
labels
=
tokenization
.
convert_to_unicode
(
example
.
label
).
split
(
u
""
)
tokens
,
labels
=
self
.
_reseg_token_label
(
tokens
,
labels
,
tokenizer
)
def
_reseg_token_label
(
self
,
tokens
,
tokenizer
,
phase
,
labels
=
None
):
if
phase
!=
"predict"
:
if
len
(
tokens
)
!=
len
(
labels
):
raise
ValueError
(
"The length of tokens must be same with labels"
)
ret_tokens
=
[]
ret_labels
=
[]
for
token
,
label
in
zip
(
tokens
,
labels
):
sub_token
=
tokenizer
.
tokenize
(
token
)
if
len
(
sub_token
)
==
0
:
continue
ret_tokens
.
extend
(
sub_token
)
ret_labels
.
append
(
label
)
if
len
(
sub_token
)
<
2
:
continue
sub_label
=
label
if
label
.
startswith
(
"B-"
):
sub_label
=
"I-"
+
label
[
2
:]
ret_labels
.
extend
([
sub_label
]
*
(
len
(
sub_token
)
-
1
))
if
len
(
ret_tokens
)
!=
len
(
labels
):
raise
ValueError
(
"The length of ret_tokens can't match with labels"
)
return
ret_tokens
,
ret_labels
else
:
ret_tokens
=
[]
for
token
in
tokens
:
sub_token
=
tokenizer
.
tokenize
(
token
)
if
len
(
sub_token
)
==
0
:
continue
ret_tokens
.
extend
(
sub_token
)
if
len
(
sub_token
)
<
2
:
continue
return
ret_tokens
def
_convert_example_to_record
(
self
,
example
,
max_seq_length
,
tokenizer
,
phase
=
None
):
if
len
(
tokens
)
>
max_seq_length
-
2
:
tokens
=
tokens
[
0
:(
max_seq_length
-
2
)]
labels
=
labels
[
0
:(
max_seq_length
-
2
)]
tokens
=
tokenization
.
convert_to_unicode
(
example
.
text_a
).
split
(
u
""
)
tokens
=
[
"[CLS]"
]
+
tokens
+
[
"[SEP]"
]
token_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
position_ids
=
list
(
range
(
len
(
token_ids
)))
text_type_ids
=
[
0
]
*
len
(
token_ids
)
no_entity_id
=
len
(
self
.
label_map
)
-
1
label_ids
=
[
no_entity_id
]
+
[
self
.
label_map
[
label
]
for
label
in
labels
]
+
[
no_entity_id
]
if
phase
!=
"predict"
:
labels
=
tokenization
.
convert_to_unicode
(
example
.
label
).
split
(
u
""
)
tokens
,
labels
=
self
.
_reseg_token_label
(
tokens
=
tokens
,
labels
=
labels
,
tokenizer
=
tokenizer
,
phase
=
phase
)
if
len
(
tokens
)
>
max_seq_length
-
2
:
tokens
=
tokens
[
0
:(
max_seq_length
-
2
)]
labels
=
labels
[
0
:(
max_seq_length
-
2
)]
tokens
=
[
"[CLS]"
]
+
tokens
+
[
"[SEP]"
]
token_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
position_ids
=
list
(
range
(
len
(
token_ids
)))
text_type_ids
=
[
0
]
*
len
(
token_ids
)
no_entity_id
=
len
(
self
.
label_map
)
-
1
label_ids
=
[
no_entity_id
]
+
[
self
.
label_map
[
label
]
for
label
in
labels
]
+
[
no_entity_id
]
Record
=
namedtuple
(
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
,
'label_ids'
])
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
,
label_ids
=
label_ids
)
else
:
tokens
=
self
.
_reseg_token_label
(
tokens
=
tokens
,
tokenizer
=
tokenizer
,
phase
=
phase
)
if
len
(
tokens
)
>
max_seq_length
-
2
:
tokens
=
tokens
[
0
:(
max_seq_length
-
2
)]
tokens
=
[
"[CLS]"
]
+
tokens
+
[
"[SEP]"
]
token_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
position_ids
=
list
(
range
(
len
(
token_ids
)))
text_type_ids
=
[
0
]
*
len
(
token_ids
)
Record
=
namedtuple
(
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
])
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
,
)
Record
=
namedtuple
(
'Record'
,
[
'token_ids'
,
'text_type_ids'
,
'position_ids'
,
'label_ids'
])
record
=
Record
(
token_ids
=
token_ids
,
text_type_ids
=
text_type_ids
,
position_ids
=
position_ids
,
label_ids
=
label_ids
)
return
record
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录