Created by: Masterkmp
### 我的版本、环境信息 1)PaddleHub 1.8.2和PaddlePaddle版本1.8.4 2)系统环境:在Windows的AiStudio网页端跑的,python版本为3.7
- 复现信息:
参考这里https://github.com/PaddlePaddle/PaddleHub/wiki/%E5%A6%82%E4%BD%95%E4%BF%AE%E6%94%B9Task%E5%86%85%E7%BD%AE%E6%96%B9%E6%B3%95%EF%BC%9F意图修改paddlehub中的hub.TextClassifierTask的默认的评价指标**(目前是计划先按照文档把默认指标”acc"改成“f1",进而学会怎么改成macro_f1、microf1、召回率)**,这部分代码如下: `import numpy as np
def calculate_f1_np(preds, labels): # 计算F1分数 # preds:预测label # labels: 真实labels # 返回F1分数 preds = np.array(preds) labels = np.array(labels) tp = np.sum((labels == 1) & (preds == 1)) tn = np.sum((labels == 0) & (preds == 0)) fp = np.sum((labels == 0) & (preds == 1)) fn = np.sum((labels == 1) & (preds == 0)) p = tp / (tp + fp) if (tp + fp) else 0 r = tp / (tp + fn) if (tp + fn) else 0 f1 = (2 * p * r) / (p + r) if p + r else 0 return f1
def calculate_metrics(self, run_states): loss_sum = acc_sum = run_examples = 0 run_step = run_time_used = 0 all_labels = np.array([]) all_infers = np.array([]) for run_state in run_states: run_examples += run_state.run_examples run_step += run_state.run_step loss_sum += np.mean( run_state.run_results[-1]) * run_state.run_examples acc_sum += np.mean( run_state.run_results[2]) * run_state.run_examples np_labels = run_state.run_results[0] np_infers = run_state.run_results[1] all_labels = np.hstack((all_labels, np_labels.reshape([-1]))) all_infers = np.hstack((all_infers, np_infers.reshape([-1])))
run_time_used = time.time() - run_states[0].run_time_begin
avg_loss = loss_sum / run_examples
run_speed = run_step / run_time_used
scores = OrderedDict()
f1 = calculate_f1_np(all_infers, all_labels)
scores["f1"] = f1
return scores, avg_loss, run_speed
from tb_paddle import SummaryWriter tb_writer = SummaryWriter("PATH/TO/LOG") def record_value(evaluation_scores, loss, s): tb_writer.add_scalar( tag="Loss_{}".format(self.phase), scalar_value=loss, global_step=self.envs['train'].current_step) log_scores = "" for metric in evaluation_scores: self.tb_writer.add_scalar( tag="{}{}".format(metric, self.phase), scalar_value=scores[metric], global_step=self._envs['train'].current_step) log_scores += "%s=%.5f " % (metric, scores[metric]) print("step %d / %d: loss=%.5f %s[step/sec: %.2f]" % (self.current_step, self.max_train_steps, avg_loss, log_scores, run_speed))
def new_log_interval_event(self, run_states): print("This is the new log_interval_event!") scores, avg_loss, run_speed = calculate_metrics(run_states) record_value(scores, avg_loss, run_speed)
cls_task.delete_hook(hook_type="log_interval_event", name="default") cls_task.add_hook(hook_type="log_interval_event", name="new_log_interval_event", func=new_log_interval_event) cls_task.hook_info()` ### 报错信息如下 This is the new log_interval_event!
---------------------------------------------------------------------------TypeError Traceback (most recent call last) in 1 #run_states = cls_task.finetune_and_eval() ----> 2 cls_task.finetune_and_eval() /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddlehub/finetune/task/base_task.py in finetune_and_eval(self) 943 944 def finetune_and_eval(self): --> 945 return self.finetune(do_eval=True) 946 947 def finetune(self, do_eval=False): /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddlehub/finetune/task/base_task.py in finetune(self, do_eval) 964 while self.current_epoch <= self.config.num_epoch: 965 self.config.strategy.step() --> 966 run_states = self._run(do_eval=do_eval) 967 self.env.current_epoch += 1 968 /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddlehub/finetune/task/base_task.py in _run(self, do_eval) 1222 if self.is_train_phase: 1223 if self.current_step % self.config.log_interval == 0: -> 1224 self._log_interval_event(period_run_states) 1225 global_run_states += period_run_states 1226 period_run_states = [] /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddlehub/finetune/task/base_task.py in hook_function(self, *args) 708 func(*args) 709 else: --> 710 partial(func, self)(*args) 711 712 return hook_function in new_log_interval_event(self, run_states) 73 # 改写的事件方法,参数列表务必与PaddleHub内置的相应方法保持一致 74 print("This is the new log_interval_event!") ---> 75 scores, avg_loss, run_speed = calculate_metrics(run_states) 76 record_value(scores, avg_loss, run_speed) 77 TypeError: calculate_metrics() missing 1 required positional argument: 'run_states'