提交 237542e9 编写于 作者: H hj

fix code style.

上级 e8728fbf
...@@ -79,7 +79,8 @@ class BaseTuningStrategy(object): ...@@ -79,7 +79,8 @@ class BaseTuningStrategy(object):
# for parallel on mpi # for parallel on mpi
self.mpi = MPIHelper() self.mpi = MPIHelper()
if self.mpi.multi_machine: if self.mpi.multi_machine:
print("Autofinetune multimachine mode: running on {}".format(self.mpi.gather(self.mpi.name))) print("Autofinetune multimachine mode: running on {}".format(
self.mpi.gather(self.mpi.name)))
@property @property
def thread(self): def thread(self):
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
class MPIHelper(object): class MPIHelper(object):
def __init__(self): def __init__(self):
try: try:
...@@ -43,7 +44,7 @@ class MPIHelper(object): ...@@ -43,7 +44,7 @@ class MPIHelper(object):
def bcast(self, data): def bcast(self, data):
if self._multi_machine: if self._multi_machine:
# call real bcast # call real bcast
return self._comm.bcast(data, root = 0) return self._comm.bcast(data, root=0)
else: else:
# do nothing # do nothing
return data return data
...@@ -51,7 +52,7 @@ class MPIHelper(object): ...@@ -51,7 +52,7 @@ class MPIHelper(object):
def gather(self, data): def gather(self, data):
if self._multi_machine: if self._multi_machine:
# call real gather # call real gather
return self._comm.gather(data, root = 0) return self._comm.gather(data, root=0)
else: else:
# do nothing # do nothing
return [data] return [data]
...@@ -73,7 +74,8 @@ class MPIHelper(object): ...@@ -73,7 +74,8 @@ class MPIHelper(object):
return average_count * self._rank, average_count * (self._rank + 1) return average_count * self._rank, average_count * (self._rank + 1)
else: else:
if self._rank < array_length % self._size: if self._rank < array_length % self._size:
return (average_count + 1) * self._rank, (average_count + 1) * (self._rank + 1) return (average_count + 1) * self._rank, (average_count + 1) * (
self._rank + 1)
else: else:
start = (average_count + 1) * (array_length % self._size) \ start = (average_count + 1) * (array_length % self._size) \
+ average_count * (self._rank - array_length % self._size) + average_count * (self._rank - array_length % self._size)
...@@ -83,7 +85,8 @@ class MPIHelper(object): ...@@ -83,7 +85,8 @@ class MPIHelper(object):
if __name__ == "__main__": if __name__ == "__main__":
mpi = MPIHelper() mpi = MPIHelper()
print("Hello world from process {} of {} at {}.".format(mpi.rank, mpi.size, mpi.name)) print("Hello world from process {} of {} at {}.".format(
mpi.rank, mpi.size, mpi.name))
all_node_names = mpi.gather(mpi.name) all_node_names = mpi.gather(mpi.name)
print("all node names using gather: {}".format(all_node_names)) print("all node names using gather: {}".format(all_node_names))
...@@ -106,8 +109,7 @@ if __name__ == "__main__": ...@@ -106,8 +109,7 @@ if __name__ == "__main__":
# test for split # test for split
for i in range(12): for i in range(12):
length = i + mpi.size # length should >= mpi.size length = i + mpi.size # length should >= mpi.size
[start, end] = mpi.split_range(length) [start, end] = mpi.split_range(length)
split_result = mpi.gather([start, end]) split_result = mpi.gather([start, end])
print("length {}, split_result {}".format(length, split_result)) print("length {}, split_result {}".format(length, split_result))
...@@ -199,14 +199,16 @@ class AutoFineTuneCommand(BaseCommand): ...@@ -199,14 +199,16 @@ class AutoFineTuneCommand(BaseCommand):
print("%s=%s" % (hparam_name, best_hparams[index])) print("%s=%s" % (hparam_name, best_hparams[index]))
f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n") f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n")
best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(best_hparams_origin)] best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(
best_hparams_origin)]
print("The final best eval score is %s." % print("The final best eval score is %s." %
autoft.get_best_eval_value()) autoft.get_best_eval_value())
if autoft.mpi.multi_machine: if autoft.mpi.multi_machine:
print("The final best model parameters are saved as " + print("The final best model parameters are saved as " +
autoft._output_dir + "/best_model on rank " + str(best_hparams_rank) + " .") autoft._output_dir + "/best_model on rank " +
str(best_hparams_rank) + " .")
else: else:
print("The final best model parameters are saved as " + print("The final best model parameters are saved as " +
autoft._output_dir + "/best_model .") autoft._output_dir + "/best_model .")
...@@ -226,7 +228,8 @@ class AutoFineTuneCommand(BaseCommand): ...@@ -226,7 +228,8 @@ class AutoFineTuneCommand(BaseCommand):
"\tsaved_params_dir\trank\n") "\tsaved_params_dir\trank\n")
else: else:
f.write( f.write(
"The final best model parameters are saved as ./best_model .") "The final best model parameters are saved as ./best_model ."
)
f.write("\t".join(autoft.hparams_name_list) + f.write("\t".join(autoft.hparams_name_list) +
"\tsaved_params_dir\n") "\tsaved_params_dir\n")
...@@ -237,7 +240,8 @@ class AutoFineTuneCommand(BaseCommand): ...@@ -237,7 +240,8 @@ class AutoFineTuneCommand(BaseCommand):
param = evaluator.convert_params(solution) param = evaluator.convert_params(solution)
param = [str(p) for p in param] param = [str(p) for p in param]
if autoft.mpi.multi_machine: if autoft.mpi.multi_machine:
f.write("\t".join(param) + "\t" + modeldir[0] + "\t" + str(modeldir[1]) + "\n") f.write("\t".join(param) + "\t" + modeldir[0] + "\t" +
str(modeldir[1]) + "\n")
else: else:
f.write("\t".join(param) + "\t" + modeldir[0] + "\n") f.write("\t".join(param) + "\t" + modeldir[0] + "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册