From b29a0937dddc22ebeae3834df4c9d541d6d81fb4 Mon Sep 17 00:00:00 2001
From: zhangxuefei
Date: Thu, 19 Sep 2019 15:42:10 +0800
Subject: [PATCH] fix bug that uses GPUs in turn

---
 demo/text-classification/predict.py |  8 +-------
 paddlehub/autofinetune/autoft.py    | 13 +++++++------
 2 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/demo/text-classification/predict.py b/demo/text-classification/predict.py
index 3e11f9c3..17321920 100644
--- a/demo/text-classification/predict.py
+++ b/demo/text-classification/predict.py
@@ -148,13 +148,7 @@ if __name__ == '__main__':
     ]
 
     if args.use_taskid:
-        feed_list = [
-            inputs["input_ids"].name,
-            inputs["position_ids"].name,
-            inputs["segment_ids"].name,
-            inputs["input_mask"].name,
-            inputs["task_ids"].name,
-        ]
+        feed_list.append(inputs["task_ids"].name)
 
     # Setup runing config for PaddleHub Finetune API
     config = hub.RunConfig(
diff --git a/paddlehub/autofinetune/autoft.py b/paddlehub/autofinetune/autoft.py
index 74178bac..2ab1a857 100644
--- a/paddlehub/autofinetune/autoft.py
+++ b/paddlehub/autofinetune/autoft.py
@@ -166,6 +166,7 @@ class BaseTuningStrategy(object):
         cnt = 0
         solutions_ckptdirs = {}
         mkdir(output_dir)
+
         for idx, solution in enumerate(solutions):
             cuda = self.is_cuda_free["free"][0]
             ckptdir = output_dir + "/ckpt-" + str(idx)
@@ -174,8 +175,8 @@ class BaseTuningStrategy(object):
             solutions_ckptdirs[tuple(solution)] = ckptdir
             self.is_cuda_free["free"].remove(cuda)
             self.is_cuda_free["busy"].append(cuda)
-            if len(params_cudas_dirs) == self.thread or cnt == int(
-                    self.popsize / self.thread):
+            if len(params_cudas_dirs
+                   ) == self.thread or idx == len(solutions) - 1:
                 tp = ThreadPool(len(params_cudas_dirs))
                 solution_results += tp.map(self.evaluator.run,
                                            params_cudas_dirs)
@@ -245,11 +246,11 @@ class HAZero(BaseTuningStrategy):
         best_hparams = self.evaluator.convert_params(self.best_hparams_all_pop)
         for index, name in enumerate(self.hparams_name_list):
             self.writer.add_scalar(
-                tag="hyperparameter tuning/" + name,
+                tag="hyperparameter_tuning/" + name,
                 scalar_value=best_hparams[index],
                 global_step=self.round)
         self.writer.add_scalar(
-            tag="hyperparameter tuning/best_eval_value",
+            tag="hyperparameter_tuning/best_eval_value",
             scalar_value=self.get_best_eval_value(),
             global_step=self.round)
 
@@ -368,11 +369,11 @@ class PSHE2(BaseTuningStrategy):
         best_hparams = self.evaluator.convert_params(self.best_hparams_all_pop)
         for index, name in enumerate(self.hparams_name_list):
             self.writer.add_scalar(
-                tag="hyperparameter tuning/" + name,
+                tag="hyperparameter_tuning/" + name,
                 scalar_value=best_hparams[index],
                 global_step=self.round)
         self.writer.add_scalar(
-            tag="hyperparameter tuning/best_eval_value",
+            tag="hyperparameter_tuning/best_eval_value",
             scalar_value=self.get_best_eval_value(),
             global_step=self.round)
--
GitLab
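
Note on the autoft.py hunk above: the new dispatch condition fires a worker pool whenever self.thread jobs have been collected or the last solution is reached, so a trailing partial batch is no longer left behind and each dispatched GPU returns to the free pool before the next batch. Below is a minimal, self-contained Python sketch of that round-robin dispatch pattern, under assumed names (GPU_IDS, N_THREADS, run_one, and dispatch are illustrative stand-ins, not code from the patch):

    from multiprocessing.pool import ThreadPool

    GPU_IDS = [0, 1]          # assumption: two free GPU ids
    N_THREADS = len(GPU_IDS)  # one worker thread per free GPU

    def run_one(job):
        # Stand-in for the evaluator's run(): score one candidate on one GPU.
        params, gpu_id = job
        return (gpu_id, sum(params))

    def dispatch(solutions):
        free, busy = list(GPU_IDS), []
        batch, results = [], []
        for idx, params in enumerate(solutions):
            gpu = free.pop(0)            # take the next free GPU, in turn
            busy.append(gpu)
            batch.append((params, gpu))
            # Fire when a full batch is collected OR on the last solution,
            # so the final partial batch is never dropped (the fixed condition).
            if len(batch) == N_THREADS or idx == len(solutions) - 1:
                tp = ThreadPool(len(batch))
                results += tp.map(run_one, batch)
                tp.close()
                tp.join()
                free.extend(busy)        # return the GPUs to the free pool
                busy, batch = [], []
        return results

    print(dispatch([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))

With two GPUs and three candidate solutions, the first two run as a full batch and the third runs alone on a freed GPU; under the old count-based trigger that final singleton could go undispatched.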