未验证 提交 68b2f915 编写于 作者: K kinghuin 提交者: GitHub

Update hook.md

上级 2046b95c
......@@ -13,15 +13,15 @@ Task定义了[组网事件](./how_to_define_task.md)和[运行事件](./how_to_d
* "finetune_start_event","finetune_end_event","predict_start_event","predict_end_event",
"eval_start_event","eval_end_event"等事件是用于打印相应阶段的日志信息。"save_ckpt_interval_event"事件用于保存当前训练的模型参数。"log_interval_event"事件用于计算模型评价指标以及可视化这些指标。
如果您需要对图中提到的事件的具体实现进行修改,可以通过Task提供的事件回调hook机制进行改写。
如你想要改变任务评价指标,如下示例中将PaddleHub默认的accuracy评价指标改为F1评价指标。同时还想用自定义的可视化工具可视化模型训练过程,如下示例将可视化工具改写为tb-paddle。则你需要改写评估方法log_interval_event。这时候你可以用Hook实现。具体使用方法如下:
如果您需要对图中提到的事件的具体实现进行修改,可以通过Task提供的事件回调hook机制进行改写。如下示例中将PaddleHub log_interval_event默认的accuracy评价指标改为F1评价指标:
```python
import time
from collections import OrderedDict
import numpy as np
import paddlehub as hub
def calculate_f1_np(preds, labels):
# 计算F1分数
......@@ -70,33 +70,32 @@ def calculate_metrics(self, run_states):
return scores, avg_loss, run_speed
# 利用自定义可视化工具tb-paddle记录训练过程中的损失值,评估指标等
from tb_paddle import SummaryWriter
tb_writer = SummaryWriter("PATH/TO/LOG")
def record_value(evaluation_scores, loss, s)
tb_writer.add_scalar(
tag="Loss_{}".format(self.phase),
scalar_value=loss,
global_step=self._envs['train'].current_step)
log_scores = ""
for metric in evaluation_scores:
self.tb_writer.add_scalar(
tag="{}_{}".format(metric, self.phase),
scalar_value=scores[metric],
global_step=self._envs['train'].current_step)
log_scores += "%s=%.5f " % (metric, scores[metric])
print("step %d / %d: loss=%.5f %s[step/sec: %.2f]" %
(self.current_step, self.max_train_steps, avg_loss,
log_scores, run_speed))
# 改写_log_interval_event实现
def new_log_interval_event(self, run_states):
# 改写的事件方法,参数列表务必与PaddleHub内置的相应方法保持一致
print("This is the new log_interval_event!")
scores, avg_loss, run_speed = calculate_metrics(self, run_states)
record_value(scores, avg_loss, run_speed)
formatted_scores = ", ".join(["%s: %.5f"%(key, value) for key, value in scores.items()])
print("[new_log_interval_event] step %d / %d: loss=%.5f %s[step/sec: %.2f]" %
(self.current_step, self.max_train_steps, avg_loss,
formatted_scores, run_speed))
# 最简单的PaddleHub运行样例
module = hub.Module(name="ernie_tiny")
inputs, outputs, program = module.context(
trainable=True, max_seq_len=128)
tokenizer = hub.ErnieTinyTokenizer(
vocab_file=module.get_vocab_path(),
spm_path=module.get_spm_path(),
word_dict_path=module.get_word_dict_path())
dataset = hub.dataset.ChnSentiCorp(
tokenizer=tokenizer, max_seq_len=128)
task = hub.TextClassifierTask(
dataset=dataset,
feature=outputs["pooled_output"],
num_classes=dataset.num_labels,
)
# 利用Hook改写PaddleHub内置_log_interval_event实现,需要2步(假设task已经创建好)
# 1.删除PaddleHub内置_log_interval_event实现
......@@ -111,12 +110,13 @@ task.delete_hook(hook_type="log_interval_event", name="default")
task.add_hook(hook_type="log_interval_event", name="new_log_interval_event", func=new_log_interval_event)
# 输出hook信息
task.hook_info()
print(task.hooks_info())
task.finetune_and_eval()
```
**NOTE:**
* 关于上述提到的run_states参见[RunEnv说明](../reference/task/runenv.md)
* tb-paddle详细信息参见[官方文档](https://github.com/ShenYuhan/tb-paddle)
* 改写的事件方法,参数列表务必与PaddleHub内置的相应方法保持一致。
* 只支持改写/删除以下事件hook类型:
"build_env_start_event","build_env_end_event","finetune_start_event","finetune_end_event",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册