提交 237542e9 编写于 作者: H hj

fix code style.

上级 e8728fbf
......@@ -79,7 +79,8 @@ class BaseTuningStrategy(object):
# for parallel on mpi
self.mpi = MPIHelper()
if self.mpi.multi_machine:
print("Autofinetune multimachine mode: running on {}".format(self.mpi.gather(self.mpi.name)))
print("Autofinetune multimachine mode: running on {}".format(
self.mpi.gather(self.mpi.name)))
@property
def thread(self):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
class MPIHelper(object):
def __init__(self):
try:
......@@ -43,7 +44,7 @@ class MPIHelper(object):
def bcast(self, data):
if self._multi_machine:
# call real bcast
return self._comm.bcast(data, root = 0)
return self._comm.bcast(data, root=0)
else:
# do nothing
return data
......@@ -51,7 +52,7 @@ class MPIHelper(object):
def gather(self, data):
if self._multi_machine:
# call real gather
return self._comm.gather(data, root = 0)
return self._comm.gather(data, root=0)
else:
# do nothing
return [data]
......@@ -73,7 +74,8 @@ class MPIHelper(object):
return average_count * self._rank, average_count * (self._rank + 1)
else:
if self._rank < array_length % self._size:
return (average_count + 1) * self._rank, (average_count + 1) * (self._rank + 1)
return (average_count + 1) * self._rank, (average_count + 1) * (
self._rank + 1)
else:
start = (average_count + 1) * (array_length % self._size) \
+ average_count * (self._rank - array_length % self._size)
......@@ -83,7 +85,8 @@ class MPIHelper(object):
if __name__ == "__main__":
mpi = MPIHelper()
print("Hello world from process {} of {} at {}.".format(mpi.rank, mpi.size, mpi.name))
print("Hello world from process {} of {} at {}.".format(
mpi.rank, mpi.size, mpi.name))
all_node_names = mpi.gather(mpi.name)
print("all node names using gather: {}".format(all_node_names))
......@@ -106,8 +109,7 @@ if __name__ == "__main__":
# test for split
for i in range(12):
length = i + mpi.size # length should >= mpi.size
length = i + mpi.size # length should >= mpi.size
[start, end] = mpi.split_range(length)
split_result = mpi.gather([start, end])
print("length {}, split_result {}".format(length, split_result))
......@@ -199,14 +199,16 @@ class AutoFineTuneCommand(BaseCommand):
print("%s=%s" % (hparam_name, best_hparams[index]))
f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n")
best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(best_hparams_origin)]
best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(
best_hparams_origin)]
print("The final best eval score is %s." %
autoft.get_best_eval_value())
if autoft.mpi.multi_machine:
print("The final best model parameters are saved as " +
autoft._output_dir + "/best_model on rank " + str(best_hparams_rank) + " .")
autoft._output_dir + "/best_model on rank " +
str(best_hparams_rank) + " .")
else:
print("The final best model parameters are saved as " +
autoft._output_dir + "/best_model .")
......@@ -226,7 +228,8 @@ class AutoFineTuneCommand(BaseCommand):
"\tsaved_params_dir\trank\n")
else:
f.write(
"The final best model parameters are saved as ./best_model .")
"The final best model parameters are saved as ./best_model ."
)
f.write("\t".join(autoft.hparams_name_list) +
"\tsaved_params_dir\n")
......@@ -237,7 +240,8 @@ class AutoFineTuneCommand(BaseCommand):
param = evaluator.convert_params(solution)
param = [str(p) for p in param]
if autoft.mpi.multi_machine:
f.write("\t".join(param) + "\t" + modeldir[0] + "\t" + str(modeldir[1]) + "\n")
f.write("\t".join(param) + "\t" + modeldir[0] + "\t" +
str(modeldir[1]) + "\n")
else:
f.write("\t".join(param) + "\t" + modeldir[0] + "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册