提交 6f9adbab 编写于 作者: K kinghuin 提交者: wuzewu

Hook (#248)

* add task hook
上级 e9c0af05
......@@ -61,6 +61,7 @@ class Logger(object):
self.logLevel = "DEBUG"
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = False
if os.path.exists(os.path.join(CONF_HOME, "config.json")):
with open(os.path.join(CONF_HOME, "config.json"), "r") as fp:
level = json.load(fp).get("log_level", "DEBUG")
......
......@@ -22,6 +22,9 @@ import contextlib
import time
import copy
import logging
import inspect
from functools import partial
import numpy as np
import paddle.fluid as fluid
from tb_paddle import SummaryWriter
......@@ -78,6 +81,117 @@ class RunEnv(object):
return self.__dict__[key]
class TaskHooks():
def __init__(self):
self._registered_hooks = {
"build_env_start": {},
"build_env_end": {},
"finetune_start": {},
"finetune_end": {},
"predict_start": {},
"predict_end": {},
"eval_start": {},
"eval_end": {},
"log_interval": {},
"save_ckpt_interval": {},
"eval_interval": {},
"run_step": {},
}
self._hook_params_num = {
"build_env_start": 1,
"build_env_end": 1,
"finetune_start": 1,
"finetune_end": 2,
"predict_start": 1,
"predict_end": 2,
"eval_start": 1,
"eval_end": 2,
"log_interval": 2,
"save_ckpt_interval": 1,
"eval_interval": 1,
"run_step": 2,
}
def add(self, hook_type, name=None, func=None):
if not func or not callable(func):
raise TypeError(
"The hook function is empty or it is not a function")
if name and not isinstance(name, str):
raise TypeError("The hook name must be a string")
if not name:
name = "hook_%s" % id(func)
# check validity
if hook_type not in self._registered_hooks:
raise ValueError("hook_type: %s does not exist" % (hook_type))
if name in self._registered_hooks[hook_type]:
raise ValueError(
"name: %s has existed in hook_type:%s, use modify method to modify it"
% (name, hook_type))
else:
args_num = len(inspect.getfullargspec(func).args)
if args_num != self._hook_params_num[hook_type]:
raise ValueError(
"The number of parameters to the hook hook_type:%s should be %i"
% (hook_type, self._hook_params_num[hook_type]))
self._registered_hooks[hook_type][name] = func
def delete(self, hook_type, name):
if self.exist(hook_type, name):
del self._registered_hooks[hook_type][name]
else:
raise ValueError(
"No hook_type: %s exists or name: %s does not exist in hook_type: %s"
% (hook_type, name, hook_type))
def modify(self, hook_type, name, func):
if not (isinstance(name, str) and callable(func)):
raise TypeError(
"The hook name must be a string, and the hook function must be a function"
)
if self.exist(hook_type, name):
self._registered_hooks[hook_type][name] = func
else:
raise ValueError(
"No hook_type: %s exists or name: %s does not exist in hook_type: %s"
% (hook_type, name, hook_type))
def exist(self, hook_type, name):
if hook_type not in self._registered_hooks \
or name not in self._registered_hooks[hook_type]:
return False
else:
return True
def info(self, only_customized=True):
# formatted output the source code
ret = ""
for hook_type, hooks in self._registered_hooks.items():
already_print_type = False
for name, func in hooks.items():
if name == "default" and only_customized:
continue
if not already_print_type:
ret += "hook_type: %s{\n" % hook_type
already_print_type = True
source = inspect.getsource(func)
ret += " name: %s{\n" % name
for line in source.split("\n"):
ret += " %s\n" % line
ret += " }\n"
if already_print_type:
ret += "}\n"
if not ret:
ret = "Not any hooks when only_customized=%s" % only_customized
return ret
def __getitem__(self, hook_type):
return self._registered_hooks[hook_type]
def __repr__(self):
return self.info(only_customized=False)
class BasicTask(object):
def __init__(self,
feed_list,
......@@ -146,6 +260,14 @@ class BasicTask(object):
self._envs = {}
self._predict_data = None
# event hooks
self._hooks = TaskHooks()
for hook_type, event_hooks in self._hooks._registered_hooks.items():
self._hooks.add(hook_type, "default",
eval("self._default_%s_event" % hook_type))
setattr(BasicTask, "_%s_event" % hook_type,
self.create_event_function(hook_type))
# accelerate predict
self.is_best_model_loaded = False
......@@ -261,10 +383,6 @@ class BasicTask(object):
var = self.env.main_program.global_block().vars[var_name]
var.persistable = True
# to avoid to print logger two times in result of the logger usage of paddle-fluid 1.6
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if self.is_train_phase:
with fluid.program_guard(self.env.main_program,
self._base_startup_program):
......@@ -441,29 +559,55 @@ class BasicTask(object):
return [metric.name for metric in self.metrics] + [self.loss.name]
return [output.name for output in self.outputs]
def _build_env_start_event(self):
def create_event_function(self, hook_type):
def hook_function(self, *args):
for name, func in self._hooks[hook_type].items():
if inspect.ismethod(func):
func(*args)
else:
partial(func, self)(*args)
return hook_function
@property
def hooks(self):
return self._hooks
def hooks_info(self, only_customized=True):
return self._hooks.info(only_customized)
def add_hook(self, hook_type, name=None, func=None):
self._hooks.add(hook_type, name=name, func=func)
def delete_hook(self, hook_type, name):
self._hooks.delete(hook_type, name)
def modify_hook(self, hook_type, name, func):
self._hooks.modify(hook_type, name, func)
def _default_build_env_start_event(self):
pass
def _build_env_end_event(self):
def _default_build_env_end_event(self):
if not self.is_predict_phase:
self.env.score_scalar = {}
def _finetune_start_event(self):
logger.train("PaddleHub finetune start")
def _default_finetune_start_event(self):
logger.info("PaddleHub finetune start")
def _finetune_end_event(self, run_states):
def _default_finetune_end_event(self, run_states):
logger.info("PaddleHub finetune finished.")
def _predict_start_event(self):
def _default_predict_start_event(self):
logger.info("PaddleHub predict start")
def _predict_end_event(self, run_states):
def _default_predict_end_event(self, run_states):
logger.info("PaddleHub predict finished.")
def _eval_start_event(self):
logger.eval("Evaluation on {} dataset start".format(self.phase))
def _default_eval_start_event(self):
logger.info("Evaluation on {} dataset start".format(self.phase))
def _eval_end_event(self, run_states):
def _default_eval_end_event(self, run_states):
eval_scores, eval_loss, run_speed = self._calculate_metrics(run_states)
if 'train' in self._envs:
self.tb_writer.add_scalar(
......@@ -505,7 +649,7 @@ class BasicTask(object):
dirname=model_saved_dir,
main_program=self.main_program)
def _log_interval_event(self, run_states):
def _default_log_interval_event(self, run_states):
scores, avg_loss, run_speed = self._calculate_metrics(run_states)
self.tb_writer.add_scalar(
tag="Loss_{}".format(self.phase),
......@@ -522,15 +666,14 @@ class BasicTask(object):
(self.current_step, self.max_train_steps, avg_loss,
log_scores, run_speed))
def _save_ckpt_interval_event(self):
def _default_save_ckpt_interval_event(self):
self.save_checkpoint()
def _eval_interval_event(self):
def _default_eval_interval_event(self):
self.eval(phase="dev")
def _run_step_event(self, run_state):
if self.is_predict_phase:
yield run_state.run_results
def _default_run_step_event(self, run_state):
pass
def _build_net(self):
raise NotImplementedError
......
......@@ -575,7 +575,7 @@ class ReadingComprehensionTask(BasicTask):
scores = cmrc2018_evaluate.get_eval(dataset, predictions)
return scores, avg_loss, run_speed
def _predict_end_event(self, run_states):
def _default_predict_end_event(self, run_states):
all_results = []
RawResult = collections.namedtuple(
"RawResult", ["unique_id", "start_logits", "end_logits"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册