提交 6f9adbab 编写于 作者: K kinghuin 提交者: wuzewu

Hook (#248)

* add task hook
上级 e9c0af05
...@@ -61,6 +61,7 @@ class Logger(object): ...@@ -61,6 +61,7 @@ class Logger(object):
self.logLevel = "DEBUG" self.logLevel = "DEBUG"
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.DEBUG)
self.logger.propagate = False self.logger.propagate = False
if os.path.exists(os.path.join(CONF_HOME, "config.json")): if os.path.exists(os.path.join(CONF_HOME, "config.json")):
with open(os.path.join(CONF_HOME, "config.json"), "r") as fp: with open(os.path.join(CONF_HOME, "config.json"), "r") as fp:
level = json.load(fp).get("log_level", "DEBUG") level = json.load(fp).get("log_level", "DEBUG")
......
...@@ -22,6 +22,9 @@ import contextlib ...@@ -22,6 +22,9 @@ import contextlib
import time import time
import copy import copy
import logging import logging
import inspect
from functools import partial
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from tb_paddle import SummaryWriter from tb_paddle import SummaryWriter
...@@ -78,6 +81,117 @@ class RunEnv(object): ...@@ -78,6 +81,117 @@ class RunEnv(object):
return self.__dict__[key] return self.__dict__[key]
class TaskHooks():
    """Registry of named hook functions, grouped by task event type.

    Every supported event type (for example ``"finetune_start"`` or
    ``"run_step"``) owns a dict that maps a hook name to a callable. Hooks are
    kept in registration order, and indexing the registry with an event type
    (``hooks["run_step"]``) returns that dict.
    """

    def __init__(self):
        # Number of parameters a hook of each event type must declare, as
        # counted by ``inspect.getfullargspec(func).args`` (which includes the
        # already-bound first parameter of bound methods).
        self._hook_params_num = {
            "build_env_start": 1,
            "build_env_end": 1,
            "finetune_start": 1,
            "finetune_end": 2,
            "predict_start": 1,
            "predict_end": 2,
            "eval_start": 1,
            "eval_end": 2,
            "log_interval": 2,
            "save_ckpt_interval": 1,
            "eval_interval": 1,
            "run_step": 2,
        }
        # hook_type -> {hook name -> hook function}; derived from the table
        # above so the two mappings can never list different event types.
        self._registered_hooks = {
            hook_type: {}
            for hook_type in self._hook_params_num
        }

    def add(self, hook_type, name=None, func=None):
        """Register ``func`` under ``name`` for the event ``hook_type``.

        Args:
            hook_type: One of the event types known to this registry.
            name: Hook name; when empty, ``"hook_<id(func)>"`` is used.
            func: The callable to register.

        Raises:
            TypeError: If ``func`` is not callable or ``name`` is a non-empty
                value that is not a string.
            ValueError: If ``hook_type`` is unknown, ``name`` is already
                registered for it, or ``func`` declares the wrong number of
                parameters.
        """
        # ``callable(None)`` is False, so a missing function is covered too.
        if not callable(func):
            raise TypeError(
                "The hook function is empty or it is not a function")
        if name and not isinstance(name, str):
            raise TypeError("The hook name must be a string")
        if not name:
            name = "hook_%s" % id(func)

        # check validity
        if hook_type not in self._registered_hooks:
            raise ValueError("hook_type: %s does not exist" % (hook_type))
        if name in self._registered_hooks[hook_type]:
            raise ValueError(
                "name: %s has existed in hook_type:%s, use modify method to modify it"
                % (name, hook_type))

        expected_num = self._hook_params_num[hook_type]
        if len(inspect.getfullargspec(func).args) != expected_num:
            raise ValueError(
                "The number of parameters to the hook hook_type:%s should be %i"
                % (hook_type, expected_num))
        self._registered_hooks[hook_type][name] = func

    def delete(self, hook_type, name):
        """Remove the hook ``name`` from ``hook_type``.

        Raises:
            ValueError: If no such hook is registered.
        """
        if not self.exist(hook_type, name):
            raise ValueError(
                "No hook_type: %s exists or name: %s does not exist in hook_type: %s"
                % (hook_type, name, hook_type))
        del self._registered_hooks[hook_type][name]

    def modify(self, hook_type, name, func):
        """Replace the function of the already registered hook ``name``.

        Unlike ``add``, the parameter count of ``func`` is not checked here.

        Raises:
            TypeError: If ``name`` is not a string or ``func`` is not callable.
            ValueError: If no such hook is registered.
        """
        if not (isinstance(name, str) and callable(func)):
            raise TypeError(
                "The hook name must be a string, and the hook function must be a function"
            )
        if not self.exist(hook_type, name):
            raise ValueError(
                "No hook_type: %s exists or name: %s does not exist in hook_type: %s"
                % (hook_type, name, hook_type))
        self._registered_hooks[hook_type][name] = func

    def exist(self, hook_type, name):
        """Return True if ``name`` is registered under ``hook_type``."""
        return (hook_type in self._registered_hooks
                and name in self._registered_hooks[hook_type])

    @staticmethod
    def _hook_source(func):
        """Return the source text of ``func``, or its ``repr`` if unavailable."""
        try:
            return inspect.getsource(func)
        except (OSError, TypeError):
            # No retrievable source (builtins, functools.partial objects, ...):
            # degrade gracefully so info() and repr() never raise.
            return repr(func)

    def info(self, only_customized=True):
        """Return a formatted listing of the registered hooks and their source.

        Args:
            only_customized: When true, hooks named ``"default"`` are omitted.

        Returns:
            The listing as a string, or a short notice when nothing is listed.
        """
        chunks = []
        for hook_type, hooks in self._registered_hooks.items():
            shown = [(name, func) for name, func in hooks.items()
                     if not (only_customized and name == "default")]
            if not shown:
                continue
            chunks.append("hook_type: %s{\n" % hook_type)
            for name, func in shown:
                chunks.append(" name: %s{\n" % name)
                chunks.extend(
                    " %s\n" % line
                    for line in self._hook_source(func).split("\n"))
                chunks.append(" }\n")
            chunks.append("}\n")

        if not chunks:
            return "Not any hooks when only_customized=%s" % only_customized
        return "".join(chunks)

    def __getitem__(self, hook_type):
        """Return the ``{name: func}`` dict of ``hook_type`` (KeyError if unknown)."""
        return self._registered_hooks[hook_type]

    def __repr__(self):
        """Return the full listing, including hooks named ``"default"``."""
        return self.info(only_customized=False)
class BasicTask(object): class BasicTask(object):
def __init__(self, def __init__(self,
feed_list, feed_list,
...@@ -146,6 +260,14 @@ class BasicTask(object): ...@@ -146,6 +260,14 @@ class BasicTask(object):
self._envs = {} self._envs = {}
self._predict_data = None self._predict_data = None
# event hooks
self._hooks = TaskHooks()
for hook_type, event_hooks in self._hooks._registered_hooks.items():
self._hooks.add(hook_type, "default",
eval("self._default_%s_event" % hook_type))
setattr(BasicTask, "_%s_event" % hook_type,
self.create_event_function(hook_type))
# accelerate predict # accelerate predict
self.is_best_model_loaded = False self.is_best_model_loaded = False
...@@ -261,10 +383,6 @@ class BasicTask(object): ...@@ -261,10 +383,6 @@ class BasicTask(object):
var = self.env.main_program.global_block().vars[var_name] var = self.env.main_program.global_block().vars[var_name]
var.persistable = True var.persistable = True
# to avoid to print logger two times in result of the logger usage of paddle-fluid 1.6
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if self.is_train_phase: if self.is_train_phase:
with fluid.program_guard(self.env.main_program, with fluid.program_guard(self.env.main_program,
self._base_startup_program): self._base_startup_program):
...@@ -441,29 +559,55 @@ class BasicTask(object): ...@@ -441,29 +559,55 @@ class BasicTask(object):
return [metric.name for metric in self.metrics] + [self.loss.name] return [metric.name for metric in self.metrics] + [self.loss.name]
return [output.name for output in self.outputs] return [output.name for output in self.outputs]
def _build_env_start_event(self): def create_event_function(self, hook_type):
def hook_function(self, *args):
    """Call every hook registered in ``self._hooks`` for ``hook_type``.

    ``hook_type`` is taken from the enclosing scope; ``args`` are forwarded
    unchanged to each hook. Return values of the hooks are discarded.
    """
    for name, func in self._hooks[hook_type].items():
        if inspect.ismethod(func):
            # Bound methods already carry their instance: pass only ``args``.
            func(*args)
        else:
            # Anything else receives this object as its first argument.
            partial(func, self)(*args)
return hook_function
@property
def hooks(self):
    """Return the hook registry stored in ``self._hooks``."""
    return self._hooks
def hooks_info(self, only_customized=True):
    """Return ``self._hooks.info(only_customized)``.

    Args:
        only_customized: Forwarded unchanged to the registry's ``info``.
    """
    return self._hooks.info(only_customized)
def add_hook(self, hook_type, name=None, func=None):
    """Register a hook by delegating to ``self._hooks.add``.

    Args:
        hook_type: Event type to attach the hook to.
        name: Hook name, forwarded as the ``name`` keyword.
        func: Hook callable, forwarded as the ``func`` keyword.
    """
    self._hooks.add(hook_type, name=name, func=func)
def delete_hook(self, hook_type, name):
    """Remove a hook by delegating to ``self._hooks.delete(hook_type, name)``."""
    self._hooks.delete(hook_type, name)
def modify_hook(self, hook_type, name, func):
    """Replace a hook by delegating to ``self._hooks.modify(hook_type, name, func)``."""
    self._hooks.modify(hook_type, name, func)
def _default_build_env_start_event(self):
pass pass
def _build_env_end_event(self): def _default_build_env_end_event(self):
if not self.is_predict_phase: if not self.is_predict_phase:
self.env.score_scalar = {} self.env.score_scalar = {}
def _finetune_start_event(self): def _default_finetune_start_event(self):
logger.train("PaddleHub finetune start") logger.info("PaddleHub finetune start")
def _finetune_end_event(self, run_states): def _default_finetune_end_event(self, run_states):
logger.info("PaddleHub finetune finished.") logger.info("PaddleHub finetune finished.")
def _predict_start_event(self): def _default_predict_start_event(self):
logger.info("PaddleHub predict start") logger.info("PaddleHub predict start")
def _predict_end_event(self, run_states): def _default_predict_end_event(self, run_states):
logger.info("PaddleHub predict finished.") logger.info("PaddleHub predict finished.")
def _eval_start_event(self): def _default_eval_start_event(self):
logger.eval("Evaluation on {} dataset start".format(self.phase)) logger.info("Evaluation on {} dataset start".format(self.phase))
def _eval_end_event(self, run_states): def _default_eval_end_event(self, run_states):
eval_scores, eval_loss, run_speed = self._calculate_metrics(run_states) eval_scores, eval_loss, run_speed = self._calculate_metrics(run_states)
if 'train' in self._envs: if 'train' in self._envs:
self.tb_writer.add_scalar( self.tb_writer.add_scalar(
...@@ -505,7 +649,7 @@ class BasicTask(object): ...@@ -505,7 +649,7 @@ class BasicTask(object):
dirname=model_saved_dir, dirname=model_saved_dir,
main_program=self.main_program) main_program=self.main_program)
def _log_interval_event(self, run_states): def _default_log_interval_event(self, run_states):
scores, avg_loss, run_speed = self._calculate_metrics(run_states) scores, avg_loss, run_speed = self._calculate_metrics(run_states)
self.tb_writer.add_scalar( self.tb_writer.add_scalar(
tag="Loss_{}".format(self.phase), tag="Loss_{}".format(self.phase),
...@@ -522,15 +666,14 @@ class BasicTask(object): ...@@ -522,15 +666,14 @@ class BasicTask(object):
(self.current_step, self.max_train_steps, avg_loss, (self.current_step, self.max_train_steps, avg_loss,
log_scores, run_speed)) log_scores, run_speed))
def _save_ckpt_interval_event(self): def _default_save_ckpt_interval_event(self):
self.save_checkpoint() self.save_checkpoint()
def _eval_interval_event(self): def _default_eval_interval_event(self):
self.eval(phase="dev") self.eval(phase="dev")
def _run_step_event(self, run_state): def _default_run_step_event(self, run_state):
if self.is_predict_phase: pass
yield run_state.run_results
def _build_net(self): def _build_net(self):
raise NotImplementedError raise NotImplementedError
......
...@@ -575,7 +575,7 @@ class ReadingComprehensionTask(BasicTask): ...@@ -575,7 +575,7 @@ class ReadingComprehensionTask(BasicTask):
scores = cmrc2018_evaluate.get_eval(dataset, predictions) scores = cmrc2018_evaluate.get_eval(dataset, predictions)
return scores, avg_loss, run_speed return scores, avg_loss, run_speed
def _predict_end_event(self, run_states): def _default_predict_end_event(self, run_states):
all_results = [] all_results = []
RawResult = collections.namedtuple( RawResult = collections.namedtuple(
"RawResult", ["unique_id", "start_logits", "end_logits"]) "RawResult", ["unique_id", "start_logits", "end_logits"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册