diff --git a/paddlehub/finetune/task/basic_task.py b/paddlehub/finetune/task/basic_task.py index 88a1878f22d321e95e6715edfb71d47b2128f101..6162b65fa2e5774c3915f74e73afdf278626256b 100644 --- a/paddlehub/finetune/task/basic_task.py +++ b/paddlehub/finetune/task/basic_task.py @@ -22,6 +22,9 @@ import contextlib import time import copy import logging +import inspect +from functools import partial + import numpy as np import paddle.fluid as fluid from tb_paddle import SummaryWriter @@ -78,6 +81,117 @@ class RunEnv(object): return self.__dict__[key] +class TaskHooks(): + def __init__(self): + self._registered_hooks = { + "build_env_start": {}, + "build_env_end": {}, + "finetune_start": {}, + "finetune_end": {}, + "predict_start": {}, + "predict_end": {}, + "eval_start": {}, + "eval_end": {}, + "log_interval": {}, + "save_ckpt_interval": {}, + "eval_interval": {}, + "run_step": {}, + } + self._hook_params_num = { + "build_env_start": 1, + "build_env_end": 1, + "finetune_start": 1, + "finetune_end": 2, + "predict_start": 1, + "predict_end": 2, + "eval_start": 1, + "eval_end": 2, + "log_interval": 2, + "save_ckpt_interval": 1, + "eval_interval": 1, + "run_step": 2, + } + + def add(self, hook_type, name=None, func=None): + if not func or not callable(func): + raise TypeError( + "The hook function is empty or it is not a function") + if name and not isinstance(name, str): + raise TypeError("The hook name must be a string") + if not name: + name = "hook_%s" % id(func) + + # check validity + if hook_type not in self._registered_hooks: + raise ValueError("hook_type: %s does not exist" % (hook_type)) + if name in self._registered_hooks[hook_type]: + raise ValueError( + "name: %s has existed in hook_type:%s, use modify method to modify it" + % (name, hook_type)) + else: + args_num = len(inspect.getfullargspec(func).args) + if args_num != self._hook_params_num[hook_type]: + raise ValueError( + "The number of parameters to the hook hook_type:%s should be %i" + % (hook_type, self._hook_params_num[hook_type])) + self._registered_hooks[hook_type][name] = func + + def delete(self, hook_type, name): + if self.exist(hook_type, name): + del self._registered_hooks[hook_type][name] + else: + raise ValueError( + "No hook_type: %s exists or name: %s does not exist in hook_type: %s" + % (hook_type, name, hook_type)) + + def modify(self, hook_type, name, func): + if not (isinstance(name, str) and callable(func)): + raise TypeError( + "The hook name must be a string, and the hook function must be a function" + ) + if self.exist(hook_type, name): + self._registered_hooks[hook_type][name] = func + else: + raise ValueError( + "No hook_type: %s exists or name: %s does not exist in hook_type: %s" + % (hook_type, name, hook_type)) + + def exist(self, hook_type, name): + if hook_type not in self._registered_hooks \ + or name not in self._registered_hooks[hook_type]: + return False + else: + return True + + def info(self, only_customized=True): + # formatted output the source code + ret = "" + for hook_type, hooks in self._registered_hooks.items(): + already_print_type = False + for name, func in hooks.items(): + if name == "default" and only_customized: + continue + if not already_print_type: + ret += "hook_type: %s{\n" % hook_type + already_print_type = True + source = inspect.getsource(func) + ret += " name: %s{\n" % name + for line in source.split("\n"): + ret += " %s\n" % line + ret += " }\n" + if already_print_type: + ret += "}\n" + if not ret: + ret = "Not any hooks when only_customized=%s" % only_customized + return ret + + def __getitem__(self, hook_type): + return self._registered_hooks[hook_type] + + def __repr__(self): + return self.info(only_customized=False) + + class BasicTask(object): def __init__(self, feed_list, @@ -146,6 +260,14 @@ class BasicTask(object): self._envs = {} self._predict_data = None + # event hooks + self._hooks = TaskHooks() + for hook_type, event_hooks in self._hooks._registered_hooks.items(): + self._hooks.add(hook_type, "default", + eval("self._default_%s_event" % hook_type)) + setattr(BasicTask, "_%s_event" % hook_type, + self.create_event_function(hook_type)) + # accelerate predict self.is_best_model_loaded = False @@ -261,10 +383,6 @@ class BasicTask(object): var = self.env.main_program.global_block().vars[var_name] var.persistable = True - # to avoid to print logger two times in result of the logger usage of paddle-fluid 1.6 - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - if self.is_train_phase: with fluid.program_guard(self.env.main_program, self._base_startup_program): @@ -441,29 +559,55 @@ class BasicTask(object): return [metric.name for metric in self.metrics] + [self.loss.name] return [output.name for output in self.outputs] - def _build_env_start_event(self): + def create_event_function(self, hook_type): + def hook_function(self, *args): + for name, func in self._hooks[hook_type].items(): + if inspect.ismethod(func): + func(*args) + else: + partial(func, self)(*args) + + return hook_function + + @property + def hooks(self): + return self._hooks + + def hooks_info(self, only_customized=True): + return self._hooks.info(only_customized) + + def add_hook(self, hook_type, name=None, func=None): + self._hooks.add(hook_type, name=name, func=func) + + def delete_hook(self, hook_type, name): + self._hooks.delete(hook_type, name) + + def modify_hook(self, hook_type, name, func): + self._hooks.modify(hook_type, name, func) + + def _default_build_env_start_event(self): pass - def _build_env_end_event(self): + def _default_build_env_end_event(self): if not self.is_predict_phase: self.env.score_scalar = {} - def _finetune_start_event(self): - logger.train("PaddleHub finetune start") + def _default_finetune_start_event(self): + logger.info("PaddleHub finetune start") - def _finetune_end_event(self, run_states): + def _default_finetune_end_event(self, run_states): logger.info("PaddleHub finetune finished.") - def _predict_start_event(self): + def _default_predict_start_event(self): logger.info("PaddleHub predict start") - def _predict_end_event(self, run_states): + def _default_predict_end_event(self, run_states): logger.info("PaddleHub predict finished.") - def _eval_start_event(self): - logger.eval("Evaluation on {} dataset start".format(self.phase)) + def _default_eval_start_event(self): + logger.info("Evaluation on {} dataset start".format(self.phase)) - def _eval_end_event(self, run_states): + def _default_eval_end_event(self, run_states): eval_scores, eval_loss, run_speed = self._calculate_metrics(run_states) if 'train' in self._envs: self.tb_writer.add_scalar( @@ -505,7 +649,7 @@ class BasicTask(object): dirname=model_saved_dir, main_program=self.main_program) - def _log_interval_event(self, run_states): + def _default_log_interval_event(self, run_states): scores, avg_loss, run_speed = self._calculate_metrics(run_states) self.tb_writer.add_scalar( tag="Loss_{}".format(self.phase), @@ -522,15 +666,14 @@ class BasicTask(object): (self.current_step, self.max_train_steps, avg_loss, log_scores, run_speed)) - def _save_ckpt_interval_event(self): + def _default_save_ckpt_interval_event(self): self.save_checkpoint() - def _eval_interval_event(self): + def _default_eval_interval_event(self): self.eval(phase="dev") - def _run_step_event(self, run_state): - if self.is_predict_phase: - yield run_state.run_results + def _default_run_step_event(self, run_state): + pass def _build_net(self): raise NotImplementedError diff --git a/paddlehub/finetune/task/reading_comprehension_task.py b/paddlehub/finetune/task/reading_comprehension_task.py index fa7ca712efa5ccae46766b76e92e850199637da4..9d0f27a5b037507d607a6ee71127c6c7238a48f9 100644 --- a/paddlehub/finetune/task/reading_comprehension_task.py +++ b/paddlehub/finetune/task/reading_comprehension_task.py @@ -575,7 +575,7 @@ class ReadingComprehensionTask(BasicTask): scores = cmrc2018_evaluate.get_eval(dataset, predictions) return scores, avg_loss, run_speed - def _predict_end_event(self, run_states): + def _default_predict_end_event(self, run_states): all_results = [] RawResult = collections.namedtuple( "RawResult", ["unique_id", "start_logits", "end_logits"])