提交 5a6c0869 编写于 作者: J jingqinghe

update load_program

上级 37203246
......@@ -199,23 +199,29 @@ class JobGenerator(object):
local_job.set_strategy(fl_strategy)
local_job.save(output)
def save_program(self, program_path, input_list, hidden_vars, loss):
def save_program(self,
main_prog,
startup_prog,
program_path,
input_list,
hidden_vars,
loss,
learning_rate=None):
if not os.path.exists(program_path):
os.makedirs(program_path)
main_program_str = fluid.default_main_program(
).desc.serialize_to_string()
startup_program_str = fluid.default_startup_program(
).desc.serialize_to_string()
params = fluid.default_main_program().global_block().all_parameters()
main_program_str = main_prog.desc.serialize_to_string()
startup_program_str = startup_prog.desc.serialize_to_string()
params = main_prog.global_block().all_parameters()
para_info = []
for pa in params:
para_info.append(pa.name)
with open(program_path + '/input_names', 'w') as fout:
for input in input_list:
fout.write("%s\n" % input.name)
with open(program_path + '/hidden_vars', 'w') as fout:
for var in hidden_vars:
fout.write("%s:%s\n" % (var[0], var[1].name))
fout.write("%s\n" % input)
if hidden_vars != None:
with open(program_path + '/hidden_vars', 'w') as fout:
for var in hidden_vars:
fout.write("%s:%s\n" % (var[0], var[1].name))
with open(program_path + '/para_info', 'w') as fout:
for item in para_info:
fout.write("%s\n" % item)
......@@ -225,6 +231,9 @@ class JobGenerator(object):
fout.write(main_program_str)
with open(program_path + '/loss_name', 'w') as fout:
fout.write(loss.name)
if type(learning_rate) == fluid.Variable:
with open(program_path + '/lr_name', 'w') as fout:
fout.write(learning_rate.name)
def generate_fl_job_from_program(self, strategy, endpoints, worker_num,
program_input, output):
......@@ -252,6 +261,12 @@ class JobGenerator(object):
with open(program_input + '/loss_name', 'r') as fin:
loss_name = fin.read()
if os.path.exists(program_input + '/lr_name'):
with open(program_input + '/lr_name', 'r') as fin:
lr_name = fin.read()
else:
lr_name = None
for item in para_list:
para = new_main.global_block().var(item)
para.regularizer = None
......@@ -262,9 +277,20 @@ class JobGenerator(object):
for var in new_main.list_vars():
if var.name == loss_name:
loss = var
if lr_name != None:
if var.name == lr_name:
lr = var
with fluid.program_guard(new_main, new_startup):
optimizer = fluid.optimizer.SGD(learning_rate=0.1,
parameter_list=para_list)
if lr_name != None:
optimizer = fluid.optimizer.MomentumOptimizer(
learning_rate=lr, momentum=0.9, parameter_list=para_list)
else:
optimizer = fluid.optimizer.MomentumOptimizer(
learning_rate=0.00001,
momentum=0.9,
parameter_list=para_list)
exe.run(new_startup)
strategy.minimize(optimizer, loss)
......
......@@ -695,7 +695,8 @@ class FLDistributeTranspiler(object):
opti_to_param = dict()
param_to_opti = dict()
for op in self.optimize_ops:
if (op.type == "sgd") or (op.type == "adam"):
if (op.type == "sgd") or (op.type == "adam") or (
op.type == "momentum"):
origin_name = op.output("ParamOut")
var = self.origin_program.global_block().var(origin_name[0])
new_var_name = "%s.opti.trainer_%d" % (origin_name[0],
......
......@@ -79,6 +79,3 @@ while not trainer.stop():
train_test_feed=feeder)
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save_inference_program(output_folder)
......@@ -25,12 +25,13 @@ sum_cost = fluid.layers.cross_entropy(input=predict, label=label)
accuracy = fluid.layers.accuracy(input=predict, label=label)
avg_cost = fluid.layers.mean(sum_cost, name="loss")
startup_program = fluid.default_startup_program()
main_program = fluid.default_main_program()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
job_generator = JobGenerator()
program_path = './load_file'
job_generator.save_program(program_path, [input, label],
[['predict', predict], ['accuracy', accuracy]],
avg_cost)
job_generator.save_program(
main_program, startup_program, program_path, [input, label],
[['predict', predict], ['accuracy', accuracy]], avg_cost)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册