提交 1491d0cb 编写于 作者: K kinghuin 提交者: wuzewu

Optimize hook (#317)

* optimize hook
上级 f378a281
...@@ -24,6 +24,7 @@ import copy ...@@ -24,6 +24,7 @@ import copy
import logging import logging
import inspect import inspect
from functools import partial from functools import partial
from collections import OrderedDict
import six import six
if six.PY2: if six.PY2:
from inspect import getargspec as get_args from inspect import getargspec as get_args
...@@ -88,44 +89,44 @@ class RunEnv(object): ...@@ -88,44 +89,44 @@ class RunEnv(object):
class TaskHooks(): class TaskHooks():
def __init__(self): def __init__(self):
self._registered_hooks = { self._registered_hooks = {
"build_env_start": {}, "build_env_start_event": OrderedDict(),
"build_env_end": {}, "build_env_end_event": OrderedDict(),
"finetune_start": {}, "finetune_start_event": OrderedDict(),
"finetune_end": {}, "finetune_end_event": OrderedDict(),
"predict_start": {}, "predict_start_event": OrderedDict(),
"predict_end": {}, "predict_end_event": OrderedDict(),
"eval_start": {}, "eval_start_event": OrderedDict(),
"eval_end": {}, "eval_end_event": OrderedDict(),
"log_interval": {}, "log_interval_event": OrderedDict(),
"save_ckpt_interval": {}, "save_ckpt_interval_event": OrderedDict(),
"eval_interval": {}, "eval_interval_event": OrderedDict(),
"run_step": {}, "run_step_event": OrderedDict(),
} }
self._hook_params_num = { self._hook_params_num = {
"build_env_start": 1, "build_env_start_event": 1,
"build_env_end": 1, "build_env_end_event": 1,
"finetune_start": 1, "finetune_start_event": 1,
"finetune_end": 2, "finetune_end_event": 2,
"predict_start": 1, "predict_start_event": 1,
"predict_end": 2, "predict_end_event": 2,
"eval_start": 1, "eval_start_event": 1,
"eval_end": 2, "eval_end_event": 2,
"log_interval": 2, "log_interval_event": 2,
"save_ckpt_interval": 1, "save_ckpt_interval_event": 1,
"eval_interval": 1, "eval_interval_event": 1,
"run_step": 2, "run_step_event": 2,
} }
def add(self, hook_type, name=None, func=None): def add(self, hook_type, name=None, func=None):
if not func or not callable(func): if not func or not callable(func):
raise TypeError( raise TypeError(
"The hook function is empty or it is not a function") "The hook function is empty or it is not a function")
if name and not isinstance(name, str): if name == None:
raise TypeError("The hook name must be a string")
if not name:
name = "hook_%s" % id(func) name = "hook_%s" % id(func)
# check validity # check validity
if not isinstance(name, str) or name.strip() == "":
raise TypeError("The hook name must be a non-empty string")
if hook_type not in self._registered_hooks: if hook_type not in self._registered_hooks:
raise ValueError("hook_type: %s does not exist" % (hook_type)) raise ValueError("hook_type: %s does not exist" % (hook_type))
if name in self._registered_hooks[hook_type]: if name in self._registered_hooks[hook_type]:
...@@ -167,13 +168,13 @@ class TaskHooks(): ...@@ -167,13 +168,13 @@ class TaskHooks():
else: else:
return True return True
def info(self, only_customized=True): def info(self, show_default=False):
# formatted output the source code # formatted output the source code
ret = "" ret = ""
for hook_type, hooks in self._registered_hooks.items(): for hook_type, hooks in self._registered_hooks.items():
already_print_type = False already_print_type = False
for name, func in hooks.items(): for name, func in hooks.items():
if name == "default" and only_customized: if name == "default" and not show_default:
continue continue
if not already_print_type: if not already_print_type:
ret += "hook_type: %s{\n" % hook_type ret += "hook_type: %s{\n" % hook_type
...@@ -186,7 +187,7 @@ class TaskHooks(): ...@@ -186,7 +187,7 @@ class TaskHooks():
if already_print_type: if already_print_type:
ret += "}\n" ret += "}\n"
if not ret: if not ret:
ret = "Not any hooks when only_customized=%s" % only_customized ret = "Not any customized hooks have been defined, you can set show_default=True to see the default hooks information"
return ret return ret
def __getitem__(self, hook_type): def __getitem__(self, hook_type):
...@@ -263,8 +264,8 @@ class BaseTask(object): ...@@ -263,8 +264,8 @@ class BaseTask(object):
self._hooks = TaskHooks() self._hooks = TaskHooks()
for hook_type, event_hooks in self._hooks._registered_hooks.items(): for hook_type, event_hooks in self._hooks._registered_hooks.items():
self._hooks.add(hook_type, "default", self._hooks.add(hook_type, "default",
eval("self._default_%s_event" % hook_type)) eval("self._default_%s" % hook_type))
setattr(BaseTask, "_%s_event" % hook_type, setattr(BaseTask, "_%s" % hook_type,
self.create_event_function(hook_type)) self.create_event_function(hook_type))
# accelerate predict # accelerate predict
...@@ -585,13 +586,18 @@ class BaseTask(object): ...@@ -585,13 +586,18 @@ class BaseTask(object):
return self._hooks.info(only_customized) return self._hooks.info(only_customized)
def add_hook(self, hook_type, name=None, func=None): def add_hook(self, hook_type, name=None, func=None):
if name == None:
name = "hook_%s" % id(func)
self._hooks.add(hook_type, name=name, func=func) self._hooks.add(hook_type, name=name, func=func)
logger.info("Add hook %s:%s successfully" % (hook_type, name))
def delete_hook(self, hook_type, name): def delete_hook(self, hook_type, name):
self._hooks.delete(hook_type, name) self._hooks.delete(hook_type, name)
logger.info("Delete hook %s:%s successfully" % (hook_type, name))
def modify_hook(self, hook_type, name, func): def modify_hook(self, hook_type, name, func):
self._hooks.modify(hook_type, name, func) self._hooks.modify(hook_type, name, func)
logger.info("Modify hook %s:%s successfully" % (hook_type, name))
def _default_build_env_start_event(self): def _default_build_env_start_event(self):
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册