Commit b29a0937 authored by zhangxuefei

Fix a bug where GPUs were used in turn

Parent: c40b9011
@@ -148,13 +148,7 @@ if __name__ == '__main__':
     ]
     if args.use_taskid:
-        feed_list = [
-            inputs["input_ids"].name,
-            inputs["position_ids"].name,
-            inputs["segment_ids"].name,
-            inputs["input_mask"].name,
-            inputs["task_ids"].name,
-        ]
+        feed_list.append(inputs["task_ids"].name)
     # Setup runing config for PaddleHub Finetune API
     config = hub.RunConfig(
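The change keeps the feed list that is already built from the common ERNIE/BERT inputs and only appends the extra `task_ids` entry when `--use_taskid` is set, instead of rebuilding the whole list. A minimal, self-contained sketch of that pattern (the dict values below are placeholders, not the demo's real feed variables):

```python
# Placeholder for the demo's `inputs` dict: input name -> feed variable name.
inputs = {
    "input_ids": "input_ids_var",
    "position_ids": "position_ids_var",
    "segment_ids": "segment_ids_var",
    "input_mask": "input_mask_var",
    "task_ids": "task_ids_var",
}
use_taskid = True  # stands in for args.use_taskid

# Base feed list built from the four inputs every run needs.
feed_list = [
    inputs["input_ids"],
    inputs["position_ids"],
    inputs["segment_ids"],
    inputs["input_mask"],
]
if use_taskid:
    # Only the extra task_ids entry differs, so append it rather than
    # rebuilding the whole list.
    feed_list.append(inputs["task_ids"])

print(feed_list)
```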
@@ -166,6 +166,7 @@ class BaseTuningStrategy(object):
         cnt = 0
         solutions_ckptdirs = {}
         mkdir(output_dir)
         for idx, solution in enumerate(solutions):
             cuda = self.is_cuda_free["free"][0]
             ckptdir = output_dir + "/ckpt-" + str(idx)
@@ -174,8 +175,8 @@ class BaseTuningStrategy(object):
             solutions_ckptdirs[tuple(solution)] = ckptdir
             self.is_cuda_free["free"].remove(cuda)
             self.is_cuda_free["busy"].append(cuda)
-            if len(params_cudas_dirs) == self.thread or cnt == int(
-                    self.popsize / self.thread):
+            if len(params_cudas_dirs
+                   ) == self.thread or idx == len(solutions) - 1:
                 tp = ThreadPool(len(params_cudas_dirs))
                 solution_results += tp.map(self.evaluator.run,
                                            params_cudas_dirs)
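This condition change is the core of the commit: the old check compared a running counter against `int(self.popsize / self.thread)`, which could mistime the final, partially filled batch of solutions, and appears to be the "GPUs used in turn" behaviour the commit message describes. Below is a minimal, self-contained sketch of the new dispatch rule; the evaluator, GPU assignment, and checkpoint paths are toy stand-ins, not the real `BaseTuningStrategy` members.

```python
from multiprocessing.dummy import Pool as ThreadPool

def run_solution(params):
    """Stand-in for self.evaluator.run: pretend to train and return a score."""
    solution, cuda, ckptdir = params
    return sum(solution)

solutions = [[0.1], [0.2], [0.3], [0.4], [0.5]]  # five candidate hparam sets
thread = 2                                       # evaluations per batch (one GPU each)

solution_results = []
params_cudas_dirs = []
for idx, solution in enumerate(solutions):
    cuda = idx % thread                          # toy GPU assignment
    params_cudas_dirs.append((solution, cuda, "ckpt-%d" % idx))
    # Dispatch when a full batch has been collected OR the last solution is
    # reached, so a final partial batch is never left waiting.
    if len(params_cudas_dirs) == thread or idx == len(solutions) - 1:
        tp = ThreadPool(len(params_cudas_dirs))
        solution_results += tp.map(run_solution, params_cudas_dirs)
        tp.close()
        tp.join()
        params_cudas_dirs = []

print(solution_results)  # all five candidates evaluated
```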
@@ -245,11 +246,11 @@ class HAZero(BaseTuningStrategy):
         best_hparams = self.evaluator.convert_params(self.best_hparams_all_pop)
         for index, name in enumerate(self.hparams_name_list):
             self.writer.add_scalar(
-                tag="hyperparameter tuning/" + name,
+                tag="hyperparameter_tuning/" + name,
                 scalar_value=best_hparams[index],
                 global_step=self.round)
         self.writer.add_scalar(
-            tag="hyperparameter tuning/best_eval_value",
+            tag="hyperparameter_tuning/best_eval_value",
             scalar_value=self.get_best_eval_value(),
             global_step=self.round)
@@ -368,11 +369,11 @@ class PSHE2(BaseTuningStrategy):
         best_hparams = self.evaluator.convert_params(self.best_hparams_all_pop)
         for index, name in enumerate(self.hparams_name_list):
             self.writer.add_scalar(
-                tag="hyperparameter tuning/" + name,
+                tag="hyperparameter_tuning/" + name,
                 scalar_value=best_hparams[index],
                 global_step=self.round)
         self.writer.add_scalar(
-            tag="hyperparameter tuning/best_eval_value",
+            tag="hyperparameter_tuning/best_eval_value",
             scalar_value=self.get_best_eval_value(),
             global_step=self.round)
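Both hunks rename the scalar tags from "hyperparameter tuning/..." to "hyperparameter_tuning/...", so the tag group name no longer contains a space. A small sketch of the renamed tags with a SummaryWriter-style API (tensorboardX is assumed here, and the log directory and values are dummies):

```python
from tensorboardX import SummaryWriter  # assumed writer; any SummaryWriter-compatible API works

writer = SummaryWriter("./visualization")  # hypothetical log directory
best_hparams = {"learning_rate": 5e-5, "batch_size": 32}
round_idx = 0

for name, value in best_hparams.items():
    # An underscore instead of a space keeps the tag group name well-formed.
    writer.add_scalar(
        tag="hyperparameter_tuning/" + name,
        scalar_value=value,
        global_step=round_idx)
writer.add_scalar(
    tag="hyperparameter_tuning/best_eval_value",
    scalar_value=0.91,  # dummy best metric
    global_step=round_idx)
writer.close()
```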