提交 1491d0cb 编写于 作者: K kinghuin 提交者: wuzewu

Optimize hook (#317)

* optimize hook
上级 f378a281
......@@ -24,6 +24,7 @@ import copy
import logging
import inspect
from functools import partial
from collections import OrderedDict
import six
if six.PY2:
from inspect import getargspec as get_args
......@@ -88,44 +89,44 @@ class RunEnv(object):
class TaskHooks():
def __init__(self):
self._registered_hooks = {
"build_env_start": {},
"build_env_end": {},
"finetune_start": {},
"finetune_end": {},
"predict_start": {},
"predict_end": {},
"eval_start": {},
"eval_end": {},
"log_interval": {},
"save_ckpt_interval": {},
"eval_interval": {},
"run_step": {},
"build_env_start_event": OrderedDict(),
"build_env_end_event": OrderedDict(),
"finetune_start_event": OrderedDict(),
"finetune_end_event": OrderedDict(),
"predict_start_event": OrderedDict(),
"predict_end_event": OrderedDict(),
"eval_start_event": OrderedDict(),
"eval_end_event": OrderedDict(),
"log_interval_event": OrderedDict(),
"save_ckpt_interval_event": OrderedDict(),
"eval_interval_event": OrderedDict(),
"run_step_event": OrderedDict(),
}
self._hook_params_num = {
"build_env_start": 1,
"build_env_end": 1,
"finetune_start": 1,
"finetune_end": 2,
"predict_start": 1,
"predict_end": 2,
"eval_start": 1,
"eval_end": 2,
"log_interval": 2,
"save_ckpt_interval": 1,
"eval_interval": 1,
"run_step": 2,
"build_env_start_event": 1,
"build_env_end_event": 1,
"finetune_start_event": 1,
"finetune_end_event": 2,
"predict_start_event": 1,
"predict_end_event": 2,
"eval_start_event": 1,
"eval_end_event": 2,
"log_interval_event": 2,
"save_ckpt_interval_event": 1,
"eval_interval_event": 1,
"run_step_event": 2,
}
def add(self, hook_type, name=None, func=None):
if not func or not callable(func):
raise TypeError(
"The hook function is empty or it is not a function")
if name and not isinstance(name, str):
raise TypeError("The hook name must be a string")
if not name:
if name == None:
name = "hook_%s" % id(func)
# check validity
if not isinstance(name, str) or name.strip() == "":
raise TypeError("The hook name must be a non-empty string")
if hook_type not in self._registered_hooks:
raise ValueError("hook_type: %s does not exist" % (hook_type))
if name in self._registered_hooks[hook_type]:
......@@ -167,13 +168,13 @@ class TaskHooks():
else:
return True
def info(self, only_customized=True):
def info(self, show_default=False):
# formatted output the source code
ret = ""
for hook_type, hooks in self._registered_hooks.items():
already_print_type = False
for name, func in hooks.items():
if name == "default" and only_customized:
if name == "default" and not show_default:
continue
if not already_print_type:
ret += "hook_type: %s{\n" % hook_type
......@@ -186,7 +187,7 @@ class TaskHooks():
if already_print_type:
ret += "}\n"
if not ret:
ret = "Not any hooks when only_customized=%s" % only_customized
ret = "Not any customized hooks have been defined, you can set show_default=True to see the default hooks information"
return ret
def __getitem__(self, hook_type):
......@@ -263,8 +264,8 @@ class BaseTask(object):
self._hooks = TaskHooks()
for hook_type, event_hooks in self._hooks._registered_hooks.items():
self._hooks.add(hook_type, "default",
eval("self._default_%s_event" % hook_type))
setattr(BaseTask, "_%s_event" % hook_type,
eval("self._default_%s" % hook_type))
setattr(BaseTask, "_%s" % hook_type,
self.create_event_function(hook_type))
# accelerate predict
......@@ -585,13 +586,18 @@ class BaseTask(object):
return self._hooks.info(only_customized)
def add_hook(self, hook_type, name=None, func=None):
if name == None:
name = "hook_%s" % id(func)
self._hooks.add(hook_type, name=name, func=func)
logger.info("Add hook %s:%s successfully" % (hook_type, name))
def delete_hook(self, hook_type, name):
self._hooks.delete(hook_type, name)
logger.info("Delete hook %s:%s successfully" % (hook_type, name))
def modify_hook(self, hook_type, name, func):
self._hooks.modify(hook_type, name, func)
logger.info("Modify hook %s:%s successfully" % (hook_type, name))
def _default_build_env_start_event(self):
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册