提交 b29a0937 编写于 作者: Z zhangxuefei

Fix a bug in how GPUs are assigned in turn (free GPUs were not reclaimed correctly between tuning rounds)

上级 c40b9011
......@@ -148,13 +148,7 @@ if __name__ == '__main__':
]
if args.use_taskid:
feed_list = [
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
inputs["task_ids"].name,
]
feed_list.append(inputs["task_ids"].name)
# Setup runing config for PaddleHub Finetune API
config = hub.RunConfig(
......
......@@ -166,6 +166,7 @@ class BaseTuningStrategy(object):
cnt = 0
solutions_ckptdirs = {}
mkdir(output_dir)
for idx, solution in enumerate(solutions):
cuda = self.is_cuda_free["free"][0]
ckptdir = output_dir + "/ckpt-" + str(idx)
......@@ -174,8 +175,8 @@ class BaseTuningStrategy(object):
solutions_ckptdirs[tuple(solution)] = ckptdir
self.is_cuda_free["free"].remove(cuda)
self.is_cuda_free["busy"].append(cuda)
if len(params_cudas_dirs) == self.thread or cnt == int(
self.popsize / self.thread):
if len(params_cudas_dirs
) == self.thread or idx == len(solutions) - 1:
tp = ThreadPool(len(params_cudas_dirs))
solution_results += tp.map(self.evaluator.run,
params_cudas_dirs)
......@@ -245,11 +246,11 @@ class HAZero(BaseTuningStrategy):
best_hparams = self.evaluator.convert_params(self.best_hparams_all_pop)
for index, name in enumerate(self.hparams_name_list):
self.writer.add_scalar(
tag="hyperparameter tuning/" + name,
tag="hyperparameter_tuning/" + name,
scalar_value=best_hparams[index],
global_step=self.round)
self.writer.add_scalar(
tag="hyperparameter tuning/best_eval_value",
tag="hyperparameter_tuning/best_eval_value",
scalar_value=self.get_best_eval_value(),
global_step=self.round)
......@@ -368,11 +369,11 @@ class PSHE2(BaseTuningStrategy):
best_hparams = self.evaluator.convert_params(self.best_hparams_all_pop)
for index, name in enumerate(self.hparams_name_list):
self.writer.add_scalar(
tag="hyperparameter tuning/" + name,
tag="hyperparameter_tuning/" + name,
scalar_value=best_hparams[index],
global_step=self.round)
self.writer.add_scalar(
tag="hyperparameter tuning/best_eval_value",
tag="hyperparameter_tuning/best_eval_value",
scalar_value=self.get_best_eval_value(),
global_step=self.round)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册