Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
68b2f915
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
8 个月 前同步成功
通知
280
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
68b2f915
编写于
9月 17, 2020
作者:
K
kinghuin
提交者:
GitHub
9月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update hook.md
上级
2046b95c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
27 addition
and
27 deletion
+27
-27
docs/Secondary_development/hook.md
docs/Secondary_development/hook.md
+27
-27
未找到文件。
docs/Secondary_development/hook.md
浏览文件 @
68b2f915
...
...
@@ -13,15 +13,15 @@ Task定义了[组网事件](./how_to_define_task.md)和[运行事件](./how_to_d
*
"finetune_start_event","finetune_end_event","predict_start_event","predict_end_event",
"eval_start_event","eval_end_event"等事件是用于打印相应阶段的日志信息。"save_ckpt_interval_event"事件用于保存当前训练的模型参数。"log_interval_event"事件用于计算模型评价指标以及可视化这些指标。
如果您需要对图中提到的事件的具体实现进行修改,可以通过Task提供的事件回调hook机制进行改写。
如你想要改变任务评价指标,如下示例中将PaddleHub默认的accuracy评价指标改为F1评价指标。同时还想用自定义的可视化工具可视化模型训练过程,如下示例将可视化工具改写为tb-paddle。则你需要改写评估方法log_interval_event。这时候你可以用Hook实现。具体使用方法如下:
如果您需要对图中提到的事件的具体实现进行修改,可以通过Task提供的事件回调hook机制进行改写。如下示例中将PaddleHub log_interval_event默认的accuracy评价指标改为F1评价指标:
```
python
import
time
from
collections
import
OrderedDict
import
numpy
as
np
import
paddlehub
as
hub
def
calculate_f1_np
(
preds
,
labels
):
# 计算F1分数
...
...
@@ -70,33 +70,32 @@ def calculate_metrics(self, run_states):
return
scores
,
avg_loss
,
run_speed
# 利用自定义可视化工具tb-paddle记录训练过程中的损失值,评估指标等
from
tb_paddle
import
SummaryWriter
tb_writer
=
SummaryWriter
(
"PATH/TO/LOG"
)
def
record_value
(
evaluation_scores
,
loss
,
s
)
tb_writer
.
add_scalar
(
tag
=
"Loss_{}"
.
format
(
self
.
phase
),
scalar_value
=
loss
,
global_step
=
self
.
_envs
[
'train'
].
current_step
)
log_scores
=
""
for
metric
in
evaluation_scores
:
self
.
tb_writer
.
add_scalar
(
tag
=
"{}_{}"
.
format
(
metric
,
self
.
phase
),
scalar_value
=
scores
[
metric
],
global_step
=
self
.
_envs
[
'train'
].
current_step
)
log_scores
+=
"%s=%.5f "
%
(
metric
,
scores
[
metric
])
print
(
"step %d / %d: loss=%.5f %s[step/sec: %.2f]"
%
(
self
.
current_step
,
self
.
max_train_steps
,
avg_loss
,
log_scores
,
run_speed
))
# 改写_log_interval_event实现
def
new_log_interval_event
(
self
,
run_states
):
# 改写的事件方法,参数列表务必与PaddleHub内置的相应方法保持一致
print
(
"This is the new log_interval_event!"
)
scores
,
avg_loss
,
run_speed
=
calculate_metrics
(
self
,
run_states
)
record_value
(
scores
,
avg_loss
,
run_speed
)
formatted_scores
=
", "
.
join
([
"%s: %.5f"
%
(
key
,
value
)
for
key
,
value
in
scores
.
items
()])
print
(
"[new_log_interval_event] step %d / %d: loss=%.5f %s[step/sec: %.2f]"
%
(
self
.
current_step
,
self
.
max_train_steps
,
avg_loss
,
formatted_scores
,
run_speed
))
# 最简单的PaddleHub运行样例
module
=
hub
.
Module
(
name
=
"ernie_tiny"
)
inputs
,
outputs
,
program
=
module
.
context
(
trainable
=
True
,
max_seq_len
=
128
)
tokenizer
=
hub
.
ErnieTinyTokenizer
(
vocab_file
=
module
.
get_vocab_path
(),
spm_path
=
module
.
get_spm_path
(),
word_dict_path
=
module
.
get_word_dict_path
())
dataset
=
hub
.
dataset
.
ChnSentiCorp
(
tokenizer
=
tokenizer
,
max_seq_len
=
128
)
task
=
hub
.
TextClassifierTask
(
dataset
=
dataset
,
feature
=
outputs
[
"pooled_output"
],
num_classes
=
dataset
.
num_labels
,
)
# 利用Hook改写PaddleHub内置_log_interval_event实现,需要2步(假设task已经创建好)
# 1.删除PaddleHub内置_log_interval_event实现
...
...
@@ -111,12 +110,13 @@ task.delete_hook(hook_type="log_interval_event", name="default")
task
.
add_hook
(
hook_type
=
"log_interval_event"
,
name
=
"new_log_interval_event"
,
func
=
new_log_interval_event
)
# 输出hook信息
task
.
hook_info
()
print
(
task
.
hooks_info
())
task
.
finetune_and_eval
()
```
**NOTE:**
*
关于上述提到的run_states参见
[
RunEnv说明
](
../reference/task/runenv.md
)
*
tb-paddle详细信息参见
[
官方文档
](
https://github.com/ShenYuhan/tb-paddle
)
*
改写的事件方法,参数列表务必与PaddleHub内置的相应方法保持一致。
*
只支持改写/删除以下事件hook类型:
"build_env_start_event","build_env_end_event","finetune_start_event","finetune_end_event",
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录