From 1491d0cb25364e24f6630f87eba0ca686f355e3b Mon Sep 17 00:00:00 2001 From: kinghuin Date: Fri, 10 Jan 2020 17:14:05 +0800 Subject: [PATCH] Optimize hook (#317) * optimize hook --- paddlehub/finetune/task/base_task.py | 70 +++++++++++++++------------- 1 file changed, 38 insertions(+), 32 deletions(-) diff --git a/paddlehub/finetune/task/base_task.py b/paddlehub/finetune/task/base_task.py index 51351555..d04ec2bf 100644 --- a/paddlehub/finetune/task/base_task.py +++ b/paddlehub/finetune/task/base_task.py @@ -24,6 +24,7 @@ import copy import logging import inspect from functools import partial +from collections import OrderedDict import six if six.PY2: from inspect import getargspec as get_args @@ -88,44 +89,44 @@ class RunEnv(object): class TaskHooks(): def __init__(self): self._registered_hooks = { - "build_env_start": {}, - "build_env_end": {}, - "finetune_start": {}, - "finetune_end": {}, - "predict_start": {}, - "predict_end": {}, - "eval_start": {}, - "eval_end": {}, - "log_interval": {}, - "save_ckpt_interval": {}, - "eval_interval": {}, - "run_step": {}, + "build_env_start_event": OrderedDict(), + "build_env_end_event": OrderedDict(), + "finetune_start_event": OrderedDict(), + "finetune_end_event": OrderedDict(), + "predict_start_event": OrderedDict(), + "predict_end_event": OrderedDict(), + "eval_start_event": OrderedDict(), + "eval_end_event": OrderedDict(), + "log_interval_event": OrderedDict(), + "save_ckpt_interval_event": OrderedDict(), + "eval_interval_event": OrderedDict(), + "run_step_event": OrderedDict(), } self._hook_params_num = { - "build_env_start": 1, - "build_env_end": 1, - "finetune_start": 1, - "finetune_end": 2, - "predict_start": 1, - "predict_end": 2, - "eval_start": 1, - "eval_end": 2, - "log_interval": 2, - "save_ckpt_interval": 1, - "eval_interval": 1, - "run_step": 2, + "build_env_start_event": 1, + "build_env_end_event": 1, + "finetune_start_event": 1, + "finetune_end_event": 2, + "predict_start_event": 1, + "predict_end_event": 2, + "eval_start_event": 1, + "eval_end_event": 2, + "log_interval_event": 2, + "save_ckpt_interval_event": 1, + "eval_interval_event": 1, + "run_step_event": 2, } def add(self, hook_type, name=None, func=None): if not func or not callable(func): raise TypeError( "The hook function is empty or it is not a function") - if name and not isinstance(name, str): - raise TypeError("The hook name must be a string") - if not name: + if name == None: name = "hook_%s" % id(func) # check validity + if not isinstance(name, str) or name.strip() == "": + raise TypeError("The hook name must be a non-empty string") if hook_type not in self._registered_hooks: raise ValueError("hook_type: %s does not exist" % (hook_type)) if name in self._registered_hooks[hook_type]: @@ -167,13 +168,13 @@ class TaskHooks(): else: return True - def info(self, only_customized=True): + def info(self, show_default=False): # formatted output the source code ret = "" for hook_type, hooks in self._registered_hooks.items(): already_print_type = False for name, func in hooks.items(): - if name == "default" and only_customized: + if name == "default" and not show_default: continue if not already_print_type: ret += "hook_type: %s{\n" % hook_type @@ -186,7 +187,7 @@ class TaskHooks(): if already_print_type: ret += "}\n" if not ret: - ret = "Not any hooks when only_customized=%s" % only_customized + ret = "Not any customized hooks have been defined, you can set show_default=True to see the default hooks information" return ret def __getitem__(self, hook_type): @@ -263,8 +264,8 @@ class BaseTask(object): self._hooks = TaskHooks() for hook_type, event_hooks in self._hooks._registered_hooks.items(): self._hooks.add(hook_type, "default", - eval("self._default_%s_event" % hook_type)) - setattr(BaseTask, "_%s_event" % hook_type, + eval("self._default_%s" % hook_type)) + setattr(BaseTask, "_%s" % hook_type, self.create_event_function(hook_type)) # accelerate predict @@ -585,13 +586,18 @@ class BaseTask(object): return self._hooks.info(only_customized) def add_hook(self, hook_type, name=None, func=None): + if name == None: + name = "hook_%s" % id(func) self._hooks.add(hook_type, name=name, func=func) + logger.info("Add hook %s:%s successfully" % (hook_type, name)) def delete_hook(self, hook_type, name): self._hooks.delete(hook_type, name) + logger.info("Delete hook %s:%s successfully" % (hook_type, name)) def modify_hook(self, hook_type, name, func): self._hooks.modify(hook_type, name, func) + logger.info("Modify hook %s:%s successfully" % (hook_type, name)) def _default_build_env_start_event(self): pass -- GitLab